简单聊聊PyTorch里面的torch.nn.Parameter()
在刷官方Tutorial的时候发现了一个用法self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size)),看了官方教程里面的解释也是云里雾里,于是在栈溢网看到了一篇解释,并做了几个实验才算完全理解了这个函数。首先可以把这个函数理解为类型转换函数,将一个不可训练的类型Tensor转换成可以训练的类型parameter并将这个parameter绑定到这个module里面(net.parameter()中就有这个绑定的parameter,所以在参数优化的时候可以进行优化的),所以经过类型转换这个self.v变成了模型的一部分,成为了模型中根据训练可以改动的参数了。使用这个函数的目的也是想让某些变量在学习的过程中不断的修改其值以达到最优化。
出现这个函数的地方
在concat注意力机制中,权值V是不断学习的所以要是parameter类型。
通过做下面的实验发现,linear里面的weight和bias就是parameter类型,且不能够使用tensor类型替换,还有linear里面的weight甚至可能通过指定一个不同于初始化时候的形状进行模型的更改。
做的实验
self.v被绑定到模型中了,所以可以在训练的时候优化
与torch.tensor([1,2,3],requires_grad=True)的区别,这个只是将参数变成可训练的,并没有绑定在module的parameter列表中。
总结
到此这篇关于PyTorch里面的torch.nn.Parameter()的文章就介绍到这了,更多相关PyTorch的torch.nn.Parameter()内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!
相关推荐
-
PyTorch里面的torch.nn.Parameter()详解
在看过很多博客的时候发现了一个用法self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size)),首先可以把这个函数理解为类型转换函数,将一个不可训练的类型Tensor转换成可以训练的类型parameter并将这个parameter绑定到这个module里面(net.parameter()中就有这个绑定的parameter,所以在参数优化的时候可以进行优化的),所以经过类型转换这个self.v变成了模型的一部分,成为了模型中根据训练可以改动
-
简单聊聊PyTorch里面的torch.nn.Parameter()
在刷官方Tutorial的时候发现了一个用法self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size)),看了官方教程里面的解释也是云里雾里,于是在栈溢网看到了一篇解释,并做了几个实验才算完全理解了这个函数.首先可以把这个函数理解为类型转换函数,将一个不可训练的类型Tensor转换成可以训练的类型parameter并将这个parameter绑定到这个module里面(net.parameter()中就有这个绑定的parameter,所以在
-
PyTorch中的参数类torch.nn.Parameter()详解
目录 前言 分析 ViT中nn.Parameter()的实验 其他解释 参考: 总结 前言 今天来聊一下PyTorch中的torch.nn.Parameter()这个函数,笔者第一次见的时候也是大概能理解函数的用途,但是具体实现原理细节也是云里雾里,在参考了几篇博文,做过几个实验之后算是清晰了,本文在记录的同时希望给后来人一个参考,欢迎留言讨论. 分析 先看其名,parameter,中文意为参数.我们知道,使用PyTorch训练神经网络时,本质上就是训练一个函数,这个函数输入一个数据(如CV中输
-
PyTorch基础之torch.nn.Conv2d中自定义权重问题
目录 torch.nn.Conv2d中自定义权重 torch.nn.Conv2d()用法讲解 用法 参数 相关形状 总结 torch.nn.Conv2d中自定义权重 torch.nn.Conv2d函数调用后会自动初始化weight和bias,本文主要涉及 如何自定义weight和bias为需要的数均分布类型: torch.nn.Conv2d.weight.data以及torch.nn.Conv2d.bias.data为torch.tensor类型,因此只要对这两个属性进行操作即可. [sampl
-
pytorch中的torch.nn.Conv2d()函数图文详解
目录 一.官方文档介绍 二.torch.nn.Conv2d()函数详解 参数dilation——扩张卷积(也叫空洞卷积) 参数groups——分组卷积 总结 一.官方文档介绍 官网 nn.Conv2d:对由多个输入平面组成的输入信号进行二维卷积 二.torch.nn.Conv2d()函数详解 参数详解 torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,
-
PyTorch中torch.nn.Linear实例详解
目录 前言 1. nn.Linear的原理: 2. nn.Linear的使用: 3. nn.Linear的源码定义: 补充:许多细节需要声明 总结 前言 在学习transformer时,遇到过非常频繁的nn.Linear()函数,这里对nn.Linear进行一个详解.参考:https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html 1. nn.Linear的原理: 从名称就可以看出来,nn.Linear表示的是线性变
-
Pytorch中torch.nn.Softmax的dim参数用法说明
Pytorch中torch.nn.Softmax的dim参数使用含义 涉及到多维tensor时,对softmax的参数dim总是很迷,下面用一个例子说明 import torch.nn as nn m = nn.Softmax(dim=0) n = nn.Softmax(dim=1) k = nn.Softmax(dim=2) input = torch.randn(2, 2, 3) print(input) print(m(input)) print(n(input)) print(k(inp
-
pytorch torch.nn.AdaptiveAvgPool2d()自适应平均池化函数详解
如题:只需要给定输出特征图的大小就好,其中通道数前后不发生变化.具体如下: AdaptiveAvgPool2d CLASStorch.nn.AdaptiveAvgPool2d(output_size)[SOURCE] Applies a 2D adaptive average pooling over an input signal composed of several input planes. The output is of size H x W, for any input size.
-
Pytorch - TORCH.NN.INIT 参数初始化的操作
路径: https://pytorch.org/docs/master/nn.init.html#nn-init-doc 初始化函数:torch.nn.init # -*- coding: utf-8 -*- """ Created on 2019 @author: fancp """ import torch import torch.nn as nn w = torch.empty(3,5) #1.均匀分布 - u(a,b) #torch.n
-
Pytorch中torch.flatten()和torch.nn.Flatten()实例详解
torch.flatten(x)等于torch.flatten(x,0)默认将张量拉成一维的向量,也就是说从第一维开始平坦化,torch.flatten(x,1)代表从第二维开始平坦化. import torch x=torch.randn(2,4,2) print(x) z=torch.flatten(x) print(z) w=torch.flatten(x,1) print(w) 输出为: tensor([[[-0.9814, 0.8251], [ 0.8197, -1.0426], [-
随机推荐
- 一个关于正则表达式的问题
- struts2+jquery组合验证注册用户是否存在
- Java 中This用法的实例详解
- java获取注册ip实例
- 详解Javascript继承的实现
- ScriptManager.RegisterStartupScript()方法在ajax页面无效的解决方法
- 分享十款最出色的PHP安全开发库中文详细介绍
- Python使用struct处理二进制的实例详解
- C++编程中使用设计模式中的policy策略模式的实例讲解
- CSS 中关于字体处理效果的思考
- 详解如何通过Mysql的二进制日志恢复数据库数据
- thinkphp 抓取网站的内容并且保存到本地的实例详解
- java实现在复制文件时使用进度条(java实现进度条)
- 用while判断输入的数字是否回文数的简单实现
- 网络管理之IP地址篇
- Python切片索引用法示例
- Java设计模式之享元模式
- 易语言次循环和延迟循环实现方法
- Java通过调用FFMPEG获取视频时长
- 使用Python实现租车计费系统的两种方法