Pytorch中的自动求梯度机制和Variable类实例
自动求导机制是每一个深度学习框架中重要的性质,免去了手动计算导数,下面用代码介绍并举例说明Pytorch的自动求导机制。
首先介绍Variable,Variable是对Tensor的一个封装,操作和Tensor是一样的,但是每个Variable都有三个属性:Varibale的Tensor本身的.data,对应Tensor的梯度.grad,以及这个Variable是通过什么方式得到的.grad_fn,根据最新消息,在pytorch0.4更新后,torch和torch.autograd.Variable现在是同一类。torch.Tensor能像Variable那样追踪历史和反向传播。Variable仍能正确工作,但是返回的是Tensor。
我们拥抱这些新特性,看看Pytorch怎么进行自动求梯度。
#encoding:utf-8 import torch x = torch.tensor([2.],requires_grad=True) #新建一个tensor,允许自动求梯度,这一项默认是false. y = (x+2)**2 + 3 #y的表达式中包含x,因此y能进行自动求梯度 y.backward() print(x.grad)
输出结果是:
tensor([8.])
这里添加一个小知识点,即torch.Tensor和torch.tensor的不同。二者均可以生成新的张量,但torch.Tensor()是python类,是默认张量类型torch.FloatTensor()的别名,使用torch.Tensor()会调用构造函数,生成单精度浮点类型的张量。
而torch.tensor()是函数,其中data可以是list,tuple,numpy,ndarray,scalar和其他类型,但只有浮点类型的张量能够自动求梯度。
torch.tensor(data, dtype=None, device=None, requires_grad=False)
言归正传,上一个例子的变量本质上是标量。下面一个例子对矩阵求导。
#encoding:utf-8 import torch x = torch.ones((2,4),requires_grad=True) y = torch.ones((2,1),requires_grad=True) W = torch.ones((4,1),requires_grad=True) J = torch.sum(y - torch.matmul(x,W)) #torch.matmul()表示对矩阵作乘法 J.backward() print(x.grad) print(y.grad) print(W.grad)
输出结果是:
tensor([[-1., -1., -1., -1.], [-1., -1., -1., -1.]]) tensor([[1.], [1.]]) tensor([[-2.], [-2.], [-2.], [-2.]])
以上这篇Pytorch中的自动求梯度机制和Variable类实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。
相关推荐
-
在pytorch中对非叶节点的变量计算梯度实例
在pytorch中一般只对叶节点进行梯度计算,也就是下图中的d,e节点,而对非叶节点,也即是c,b节点则没有显式地去保留其中间计算过程中的梯度(因为一般来说只有叶节点才需要去更新),这样可以节省很大部分的显存,但是在调试过程中,有时候我们需要对中间变量梯度进行监控,以确保网络的有效性,这个时候我们需要打印出非叶节点的梯度,为了实现这个目的,我们可以通过两种手段进行. 注册hook函数 Tensor.register_hook[2] 可以注册一个反向梯度传导时的hook函数,这个hook函数将会在
-
在pytorch中实现只让指定变量向后传播梯度
pytorch中如何只让指定变量向后传播梯度? (或者说如何让指定变量不参与后向传播?) 有以下公式,假如要让L对xvar求导: (1)中,L对xvar的求导将同时计算out1部分和out2部分: (2)中,L对xvar的求导只计算out2部分,因为out1的requires_grad=False: (3)中,L对xvar的求导只计算out1部分,因为out2的requires_grad=False: 验证如下: #!/usr/bin/env python2 # -*- coding: utf-
-
pytorch对梯度进行可视化进行梯度检查教程
目的: 在训练神经网络的时候,有时候需要自己写操作,比如faster_rcnn中的roi_pooling,我们可以可视化前向传播的图像和反向传播的梯度图像,前向传播可以检查流程和计算的正确性,而反向传播则可以大概检查流程的正确性. 实验 可视化rroi_align的梯度 1.pytorch 0.4.1及之前,需要声明需要参数,这里将图片数据声明为variable im_data = Variable(im_data, requires_grad=True) 2.进行前向传播,最后的loss映射为
-
PyTorch的SoftMax交叉熵损失和梯度用法
在PyTorch中可以方便的验证SoftMax交叉熵损失和对输入梯度的计算 关于softmax_cross_entropy求导的过程,可以参考HERE 示例: # -*- coding: utf-8 -*- import torch import torch.autograd as autograd from torch.autograd import Variable import torch.nn.functional as F import torch.nn as nn import nu
-
pytorch梯度剪裁方式
我就废话不多说,看例子吧! import torch.nn as nn outputs = model(data) loss= loss_fn(outputs, target) optimizer.zero_grad() loss.backward() nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2) optimizer.step() nn.utils.clip_grad_norm_ 的参数: param
-
pytorch的梯度计算以及backward方法详解
基础知识 tensors: tensor在pytorch里面是一个n维数组.我们可以通过指定参数reuqires_grad=True来建立一个反向传播图,从而能够计算梯度.在pytorch中一般叫做dynamic computation graph(DCG)--即动态计算图. import torch import numpy as np # 方式一 x = torch.randn(2,2, requires_grad=True) # 方式二 x = torch.autograd.Variabl
-
Pytorch中的自动求梯度机制和Variable类实例
自动求导机制是每一个深度学习框架中重要的性质,免去了手动计算导数,下面用代码介绍并举例说明Pytorch的自动求导机制. 首先介绍Variable,Variable是对Tensor的一个封装,操作和Tensor是一样的,但是每个Variable都有三个属性:Varibale的Tensor本身的.data,对应Tensor的梯度.grad,以及这个Variable是通过什么方式得到的.grad_fn,根据最新消息,在pytorch0.4更新后,torch和torch.autograd.Variab
-
浅谈Pytorch中的自动求导函数backward()所需参数的含义
正常来说backward( )函数是要传入参数的,一直没弄明白backward需要传入的参数具体含义,但是没关系,生命在与折腾,咱们来折腾一下,嘿嘿. 对标量自动求导 首先,如果out.backward()中的out是一个标量的话(相当于一个神经网络有一个样本,这个样本有两个属性,神经网络有一个输出)那么此时我的backward函数是不需要输入任何参数的. import torch from torch.autograd import Variable a = Variable(torch.Te
-
关于PyTorch 自动求导机制详解
自动求导机制 从后向中排除子图 每个变量都有两个标志:requires_grad和volatile.它们都允许从梯度计算中精细地排除子图,并可以提高效率. requires_grad 如果有一个单一的输入操作需要梯度,它的输出也需要梯度.相反,只有所有输入都不需要梯度,输出才不需要.如果其中所有的变量都不需要梯度进行,后向计算不会在子图中执行. >>> x = Variable(torch.randn(5, 5)) >>> y = Variable(torch.rand
-
Java 中利用泛型和反射机制抽象DAO的实例
Java 中利用泛型和反射机制抽象DAO的实例 一般的DAO都有CRUD操作,在每个实体DAO接口中重复定义这些方法,不如提供一个通用的DAO接口,具体的实体DAO可以扩展这个通用DAO以提供特殊的操作,从而将DAO抽象到另一层次,令代码质量有很好的提升 1.通用接口 import java.io.Serializable; import java.util.List; public interface BaseDao<T> { T get(Serializable id); List<
-
C语言中判断素数(求素数)的思路与方法实例
目录 前言 思路1实现: 思路2实现: <C与指针>4.14-2: 补充:判断素数的4种方法实例 总结 前言 素数又称质数.所谓素数是指除了 1 和它本身以外,不能被任何整数整除的数,例如17就是素数,因为它不能被 2~16 的任一整数整除. 思路1):因此判断一个整数m是否是素数,只需把 m 被 2 ~ m-1 之间的每一个整数去除,如果都不能被整除,那么 m 就是一个素数. 思路2):判断方法还可以简化.m 不必被 2 ~ m-1 之间的每一个整数去除,只需被 2 ~ 之间的每一个整数去
-
pytorch如何定义新的自动求导函数
目录 pytorch定义新的自动求导函数 pytorch自动求导与逻辑回归 自动求导 逻辑回归 总结 pytorch定义新的自动求导函数 在pytorch中想自定义求导函数,通过实现torch.autograd.Function并重写forward和backward函数,来定义自己的自动求导运算.参考官网上的demo:传送门 直接上代码,定义一个ReLu来实现自动求导 import torch class MyRelu(torch.autograd.Function): @staticmetho
-
TensorFlow的自动求导原理分析
原理: TensorFlow使用的求导方法称为自动微分(Automatic Differentiation),它既不是符号求导也不是数值求导,而类似于将两者结合的产物. 最基本的原理就是链式法则,关键思想是在基本操作(op)的水平上应用符号求导,并保持中间结果(grad). 基本操作的符号求导定义在\tensorflow\python\ops\math_grad.py文件中,这个文件中的所有函数都用RegisterGradient装饰器包装了起来,这些函数都接受两个参数op和grad,参数op是
-
BatchNorm2d原理、作用及pytorch中BatchNorm2d函数的参数使用
目录 BN原理.作用 函数参数讲解 总结 BN原理.作用 函数参数讲解 BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 1.num_features:一般输入参数的shape为batch_size*num_features*height*width,即为其中特征的数量,即为输入BN层的通道数: 2.eps:分母中添加的一个值,目的是为了计算的稳定性,默认为:1e-5,避免分母为0:
-
Pytorch中关于BatchNorm2d的参数解释
目录 BatchNorm2d中的track_running_stats参数 running_mean和running_var参数 BatchNorm2d参数讲解 总结 BatchNorm2d中的track_running_stats参数 如果BatchNorm2d的参数val,track_running_stats设置False,那么加载预训练后每次模型测试测试集的结果时都不一样: track_running_stats设置为True时,每次得到的结果都一样. running_mean和runn
-
在 pytorch 中实现计算图和自动求导
前言: 今天聊一聊 pytorch 的计算图和自动求导,我们先从一个简单例子来看,下面是一个简单函数建立了 yy 和 xx 之间的关系 然后我们结点和边形式表示上面公式: 上面的式子可以用图的形式表达,接下来我们用 torch 来计算 x 导数,首先我们创建一个 tensor 并且将其requires_grad设置为True表示随后反向传播会对其进行求导. x = torch.tensor(3.,requires_grad=True) 然后写出 y = 3*x**2 + 4*x + 2 y.ba
随机推荐
- 详解Spring中bean生命周期回调方法
- iOS 原生实现扫描二维码和条形码功能限制扫描区域
- 在Mac OS上搭建PHP的Yii框架及相关测试环境
- thinkphp配置连接数据库技巧
- Js控制弹窗实现在任意分辨率下居中显示
- 在Python中使用PIL模块对图片进行高斯模糊处理的教程
- Android如何实现压缩和解压缩文件
- 解决SQL Server无法启动的小技巧
- jquery.validate分组验证代码
- Bootstrap CSS布局之表格
- listview改变字体大小实例讲解
- Windows2003网络服务器安全攻略
- C++实现查找中位数的O(N)算法和Kmin算法
- 一段采集程序代码
- Gbs法国20GB免费空间支持PHP和mysql
- vue中mint-ui的使用方法
- vue配置文件实现代理v2版本的方法
- pycharm设置鼠标悬停查看方法设置
- js 数组详细操作方法及解析合集
- 在Python中居然可以定义两个同名通参数的函数