对PyTorch中inplace字段的全面理解
例如
torch.nn.ReLU(inplace=True)
inplace=True
表示进行原地操作,对上一层传递下来的tensor直接进行修改,如x=x+3;
inplace=False
表示新建一个变量存储操作结果,如y=x+3,x=y;
inplace=True
可以节省运算内存,不用多存储变量。
补充:PyTorch中网络里面的inplace=True字段的意思
在例如nn.LeakyReLU(inplace=True)中的inplace字段是什么意思呢?有什么用?
inplace=True的意思是进行原地操作,例如x=x+5,对x就是一个原地操作,y=x+5,x=y,完成了与x=x+5同样的功能但是不是原地操作。
上面LeakyReLU中的inplace=True的含义是一样的,是对于Conv2d这样的上层网络传递下来的tensor直接进行修改,好处就是可以节省运算内存,不用多储存变量y。
inplace=True means that it will modify the input directly, without allocating any additional output. It can sometimes slightly decrease the memory usage, but may not always be a valid operation (because the original input is destroyed). However, if you don't see an error, it means that your use case is valid.
以上为个人经验,希望能给大家一个参考,也希望大家多多支持我们。
相关推荐
-
浅谈PyTorch中in-place operation的含义
in-place operation在pytorch中是指改变一个tensor的值的时候,不经过复制操作,而是直接在原来的内存上改变它的值.可以把它成为原地操作符. 在pytorch中经常加后缀"_"来代表原地in-place operation,比如说.add_() 或者.scatter().python里面的+=,*=也是in-place operation. 下面是正常的加操作,执行结束加操作之后x的值没有发生变化: import torch x=torch.rand(2) #t
-
PyTorch中的拷贝与就地操作详解
前言 PyTroch中我们经常使用到Numpy进行数据的处理,然后再转为Tensor,但是关系到数据的更改时我们要注意方法是否是共享地址,这关系到整个网络的更新.本篇就In-palce操作,拷贝操作中的注意点进行总结. In-place操作 pytorch中原地操作的后缀为_,如.add_()或.scatter_(),就地操作是直接更改给定Tensor的内容而不进行复制的操作,即不会为变量分配新的内存.Python操作类似+=或*=也是就地操作.(我加了我自己~) 为什么in-place操作可以
-
pytorch中的自定义数据处理详解
pytorch在数据中采用Dataset的数据保存方式,需要继承data.Dataset类,如果需要自己处理数据的话,需要实现两个基本方法. :.getitem:返回一条数据或者一个样本,obj[index] = obj.getitem(index). :.len:返回样本的数量 . len(obj) = obj.len(). Dataset 在data里,调用的时候使用 from torch.utils import data import os from PIL import Image 数
-
对PyTorch中inplace字段的全面理解
例如 torch.nn.ReLU(inplace=True) inplace=True 表示进行原地操作,对上一层传递下来的tensor直接进行修改,如x=x+3: inplace=False 表示新建一个变量存储操作结果,如y=x+3,x=y: inplace=True 可以节省运算内存,不用多存储变量. 补充:PyTorch中网络里面的inplace=True字段的意思 在例如nn.LeakyReLU(inplace=True)中的inplace字段是什么意思呢?有什么用? inplace=
-
Pytorch中index_select() 函数的实现理解
函数形式: index_select( dim, index ) 参数: dim:表示从第几维挑选数据,类型为int值: index:表示从第一个参数维度中的哪个位置挑选数据,类型为torch.Tensor类的实例: 刚开始学习pytorch,遇到了index_select(),一开始不太明白几个参数的意思,后来查了一下资料,算是明白了一点. a = torch.linspace(1, 12, steps=12).view(3, 4) print(a) b = torch.index_selec
-
pytorch中torch.topk()函数的快速理解
目录 函数作用: 举个栗子: 实例演示 总结 函数作用: 该函数的作用即按字面意思理解,topk:取数组的前k个元素进行排序. 通常该函数返回2个值,第一个值为排序的数组,第二个值为该数组中获取到的元素在原数组中的位置标号. 举个栗子: import numpy as np import torch import torch.utils.data.dataset as Dataset from torch.utils.data import Dataset,DataLoader ########
-
关于pytorch中网络loss传播和参数更新的理解
相比于2018年,在ICLR2019提交论文中,提及不同框架的论文数量发生了极大变化,网友发现,提及tensorflow的论文数量从2018年的228篇略微提升到了266篇,keras从42提升到56,但是pytorch的数量从87篇提升到了252篇. TensorFlow: 228--->266 Keras: 42--->56 Pytorch: 87--->252 在使用pytorch中,自己有一些思考,如下: 1. loss计算和反向传播 import torch.nn as nn
-
对python pandas中 inplace 参数的理解
pandas 中 inplace 参数在很多函数中都会有,它的作用是:是否在原对象基础上进行修改 inplace = True:不创建新的对象,直接对原始对象进行修改: inplace = False:对数据进行修改,创建并返回新的对象承载其修改结果. 默认是False,即创建新的对象进行修改,原对象不变,和深复制和浅复制有些类似. 例: inplace=True情况: import pandas as pd import numpy as np df=pd.DataFrame(np.rand
-
对Pytorch 中的contiguous理解说明
最近遇到这个函数,但查的中文博客里的解释貌似不是很到位,这里翻译一下stackoverflow上的回答并加上自己的理解. 在pytorch中,只有很少几个操作是不改变tensor的内容本身,而只是重新定义下标与元素的对应关系的.换句话说,这种操作不进行数据拷贝和数据的改变,变的是元数据. 这些操作是: narrow(),view(),expand()和transpose() 举个栗子,在使用transpose()进行转置操作时,pytorch并不会创建新的.转置后的tensor,而是修改了ten
-
深入理解PyTorch中的nn.Embedding的使用
目录 一.前置知识 1.1 语料库(Corpus) 1.2 词元(Token) 1.3 词表(Vocabulary) 二.nn.Embedding 基础 2.1 为什么要 embedding? 2.2 基础参数 2.3 nn.Embedding 与 nn.Linear 的区别 2.4 nn.Embedding 的更新问题 三.nn.Embedding 进阶 3.1 全部参数 3.2 使用预训练的词嵌入 四.最后 一.前置知识 1.1 语料库(Corpus) 太长不看版: NLP任务所依赖的语言数
-
对pytorch中x = x.view(x.size(0), -1) 的理解说明
在pytorch的CNN代码中经常会看到 x.view(x.size(0), -1) 首先,在pytorch中的view()函数就是用来改变tensor的形状的,例如将2行3列的tensor变为1行6列,其中-1表示会自适应的调整剩余的维度 a = torch.Tensor(2,3) print(a) # tensor([[0.0000, 0.0000, 0.0000], # [0.0000, 0.0000, 0.0000]]) print(a.view(1,-1)) # tensor([[0.
-
浅谈Pytorch中autograd的若干(踩坑)总结
关于Variable和Tensor 旧版本的Pytorch中,Variable是对Tensor的一个封装:在Pytorch大于v0.4的版本后,Varible和Tensor合并了,意味着Tensor可以像旧版本的Variable那样运行,当然新版本中Variable封装仍旧可以用,但是对Varieble操作返回的将是一个Tensor. import torch as t from torch.autograd import Variable a = t.ones(3,requires_grad=
随机推荐
- AngularJS实践之使用NgModelController进行数据绑定
- 安装SQL2005时出现的版本变更检查SKUUPGRADE=1问题的解决方法
- MongoDB快速入门笔记(三)之MongoDB插入文档操作
- C#中的委托和事件学习(续)
- Android实现拍照、选择图片并裁剪图片功能
- 通过Ajax手动解决WordPress WP-PostViews不计数的问题
- 深入理解Mysql的四种隔离级别
- Bootstrap开发实战之响应式轮播图
- node.js 抓取代理ip实例代码
- 关于Ajax技术原理的3点总结
- CentOS 安装软件出现错误:/lib/ld-linux.so.2: bad ELF interpreter 解决
- python正则匹配抓取豆瓣电影链接和评论代码分享
- Win2008 网络策略设置方法 让访问更安全
- JavaScript学习笔记整理_用于模式匹配的String方法
- Java实现ftp文件上传下载解决慢中文乱码多个文件下载等问题
- 浅析C和C++函数的相互引用
- C语言基础之malloc和free函数详解
- js阻止默认右键的下拉菜单方法
- Android 2018最新手机号验证正则表达式方法
- Nginx服务优化配置方案