PyTorch中的参数类torch.nn.Parameter()详解

目录
  • 前言
  • 分析
  • ViT中nn.Parameter()的实验
  • 其他解释
  • 参考:
  • 总结

前言

今天来聊一下PyTorch中的torch.nn.Parameter()这个函数,笔者第一次见的时候也是大概能理解函数的用途,但是具体实现原理细节也是云里雾里,在参考了几篇博文,做过几个实验之后算是清晰了,本文在记录的同时希望给后来人一个参考,欢迎留言讨论。

分析

先看其名,parameter,中文意为参数。我们知道,使用PyTorch训练神经网络时,本质上就是训练一个函数,这个函数输入一个数据(如CV中输入一张图像),输出一个预测(如输出这张图像中的物体是属于什么类别)。而在我们给定这个函数的结构(如卷积、全连接等)之后,能学习的就是这个函数的参数了,我们设计一个损失函数,配合梯度下降法,使得我们学习到的函数(神经网络)能够尽量准确地完成预测任务。

通常,我们的参数都是一些常见的结构(卷积、全连接等)里面的计算参数。而当我们的网络有一些其他的设计时,会需要一些额外的参数同样很着整个网络的训练进行学习更新,最后得到最优的值,经典的例子有注意力机制中的权重参数、Vision Transformer中的class token和positional embedding等。

而这里的torch.nn.Parameter()就可以很好地适应这种应用场景。

下面是这篇博客的一个总结,笔者认为讲的比较明白,在这里引用一下:

首先可以把这个函数理解为类型转换函数,将一个不可训练的类型Tensor转换成可以训练的类型parameter并将这个parameter绑定到这个module里面(net.parameter()中就有这个绑定的parameter,所以在参数优化的时候可以进行优化的),所以经过类型转换这个self.v变成了模型的一部分,成为了模型中根据训练可以改动的参数了。使用这个函数的目的也是想让某些变量在学习的过程中不断的修改其值以达到最优化。

ViT中nn.Parameter()的实验

看过这个分析后,我们再看一下Vision Transformer中的用法:

...

self.pos_embedding = nn.Parameter(torch.randn(1, num_patches+1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
...

我们知道在ViT中,positonal embedding和class token是两个需要随着网络训练学习的参数,但是它们又不属于FC、MLP、MSA等运算的参数,在这时,就可以用nn.Parameter()来将这个随机初始化的Tensor注册为可学习的参数Parameter。

为了确定这两个参数确实是被添加到了net.Parameters()内,笔者稍微改动源码,显式地指定这两个参数的初始数值为0.98,并打印迭代器net.Parameters()。

...

self.pos_embedding = nn.Parameter(torch.ones(1, num_patches+1, dim) * 0.98)
self.cls_token = nn.Parameter(torch.ones(1, 1, dim) * 0.98)
...

实例化一个ViT模型并打印net.Parameters():

net_vit = ViT(
        image_size = 256,
        patch_size = 32,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 16,
        mlp_dim = 2048,
        dropout = 0.1,
        emb_dropout = 0.1
    )

for para in net_vit.parameters():
        print(para.data)

输出结果中可以看到,最前两行就是我们显式指定为0.98的两个参数pos_embedding和cls_token:

tensor([[[0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800],
         [0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800],
         [0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800],
         ...,
         [0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800],
         [0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800],
         [0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800]]])
tensor([[[0.9800, 0.9800, 0.9800,  ..., 0.9800, 0.9800, 0.9800]]])
tensor([[-0.0026, -0.0064,  0.0111,  ...,  0.0091, -0.0041, -0.0060],
        [ 0.0003,  0.0115,  0.0059,  ..., -0.0052, -0.0056,  0.0010],
        [ 0.0079,  0.0016, -0.0094,  ...,  0.0174,  0.0065,  0.0001],
        ...,
        [-0.0110, -0.0137,  0.0102,  ...,  0.0145, -0.0105, -0.0167],
        [-0.0116, -0.0147,  0.0030,  ...,  0.0087,  0.0022,  0.0108],
        [-0.0079,  0.0033, -0.0087,  ..., -0.0174,  0.0103,  0.0021]])
...
...

这就可以确定nn.Parameter()添加的参数确实是被添加到了Parameters列表中,会被送入优化器中随训练一起学习更新。

from torch.optim import Adam
opt = Adam(net_vit.parameters(), learning_rate=0.001)

其他解释

以下是国外StackOverflow的一个大佬的解读,笔者自行翻译并放在这里供大家参考,想查看原文的同学请戳这里。

我们知道Tensor相当于是一个高维度的矩阵,它是Variable类的子类。Variable和Parameter之间的差异体现在与Module关联时。当Parameter作为model的属性与module相关联时,它会被自动添加到Parameters列表中,并且可以使用net.Parameters()迭代器进行访问。

最初在Torch中,一个Variable(例如可以是某个中间state)也会在赋值时被添加为模型的Parameter。在某些实例中,需要缓存变量,而不是将它们添加到Parameters列表中。

文档中提到的一种情况是RNN,在这种情况下,您需要保存最后一个hidden state,这样就不必一次又一次地传递它。需要缓存一个Variable,而不是让它自动注册为模型的Parameter,这就是为什么我们有一个显式的方法将参数注册到我们的模型,即nn.Parameter类。

举个例子:

import torch
import torch.nn as nn
from torch.optim import Adam

class NN_Network(nn.Module):
    def __init__(self,in_dim,hid,out_dim):
        super(NN_Network, self).__init__()
        self.linear1 = nn.Linear(in_dim,hid)
        self.linear2 = nn.Linear(hid,out_dim)
        self.linear1.weight = torch.nn.Parameter(torch.zeros(in_dim,hid))
        self.linear1.bias = torch.nn.Parameter(torch.ones(hid))
        self.linear2.weight = torch.nn.Parameter(torch.zeros(in_dim,hid))
        self.linear2.bias = torch.nn.Parameter(torch.ones(hid))

    def forward(self, input_array):
        h = self.linear1(input_array)
        y_pred = self.linear2(h)
        return y_pred

in_d = 5
hidn = 2
out_d = 3
net = NN_Network(in_d, hidn, out_d)

然后检查一下这个模型的Parameters列表:

for param in net.parameters():
    print(type(param.data), param.size())

""" Output
<class 'torch.FloatTensor'> torch.Size([5, 2])
<class 'torch.FloatTensor'> torch.Size([2])
<class 'torch.FloatTensor'> torch.Size([5, 2])
<class 'torch.FloatTensor'> torch.Size([2])
"""

可以轻易地送入到优化器中:

opt = Adam(net.parameters(), learning_rate=0.001)

另外,请注意Parameter的require_grad会自动设定。

各位读者有疑惑或异议的地方,欢迎留言讨论。

参考:

https://www.jb51.net/article/238632.htm

https://stackoverflow.com/questions/50935345/understanding-torch-nn-parameter

总结

到此这篇关于PyTorch中torch.nn.Parameter()的文章就介绍到这了,更多相关PyTorch中torch.nn.Parameter()内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • PyTorch里面的torch.nn.Parameter()详解

    在看过很多博客的时候发现了一个用法self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size)),首先可以把这个函数理解为类型转换函数,将一个不可训练的类型Tensor转换成可以训练的类型parameter并将这个parameter绑定到这个module里面(net.parameter()中就有这个绑定的parameter,所以在参数优化的时候可以进行优化的),所以经过类型转换这个self.v变成了模型的一部分,成为了模型中根据训练可以改动

  • PyTorch中的参数类torch.nn.Parameter()详解

    目录 前言 分析 ViT中nn.Parameter()的实验 其他解释 参考: 总结 前言 今天来聊一下PyTorch中的torch.nn.Parameter()这个函数,笔者第一次见的时候也是大概能理解函数的用途,但是具体实现原理细节也是云里雾里,在参考了几篇博文,做过几个实验之后算是清晰了,本文在记录的同时希望给后来人一个参考,欢迎留言讨论. 分析 先看其名,parameter,中文意为参数.我们知道,使用PyTorch训练神经网络时,本质上就是训练一个函数,这个函数输入一个数据(如CV中输

  • PyTorch中clone()、detach()及相关扩展详解

    clone() 与 detach() 对比 Torch 为了提高速度,向量或是矩阵的赋值是指向同一内存的,这不同于 Matlab.如果需要保存旧的tensor即需要开辟新的存储地址而不是引用,可以用 clone() 进行深拷贝, 首先我们来打印出来clone()操作后的数据类型定义变化: (1). 简单打印类型 import torch a = torch.tensor(1.0, requires_grad=True) b = a.clone() c = a.detach() a.data *=

  • pytorch中tensor.expand()和tensor.expand_as()函数详解

    tensor.expend()函数 >>> import torch >>> a=torch.tensor([[2],[3],[4]]) >>> print(a.size()) torch.Size([3, 1]) >>> a.expand(3,2) tensor([[2, 2], [3, 3], [4, 4]]) >>> a tensor([[2], [3], [4]]) 可以看出expand()函数括号里面为变形

  • python 中Mixin混入类的使用方法详解

    目录 前言 Mixin 与继承的区别 总结 前言 最近在看sanic的源码,发现有很多Mixin的类,大概长成这个样子 class BaseSanic(    RouteMixin,    MiddlewareMixin,    ListenerMixin,    ExceptionMixin,    SignalMixin,    metaclass=SanicMeta, ): 于是对于这种 Mixin 研究了一下,其实也没什么新的东西,Mixin 又称混入,只是一种编程思想的体现,但是在使用

  • C语言中函数参数的入栈顺序详解及实例

    C语言中函数参数的入栈顺序详解及实例 对技术执着的人,比如说我,往往对一些问题,不仅想做到"知其然",还想做到"知其所以然".C语言可谓博大精深,即使我已经有多年的开发经验,可还是有许多问题不知其所以然.某天某地某人问我,C语言中函数参数的入栈顺序如何?从右至左,我随口回答.为什么是从右至左呢?我终究没有给出合理的解释.于是,只好做了个作业,于是有了这篇小博文. #include void foo(int x, int y, int z) { printf(&quo

  • QT中对Mat类的一些操作详解

    目录 一.类型转换 二.保存至数据库 一.类型转换 opencv在QT中的应用通常会涉及到这三者的转换,即Mat.QImage.QPixmap.下面分别给出了 Mat转QImage QImage转Mat Mat转QPixmap 1️⃣:Mat转QImage QImage MainWindow::MatToImage(const Mat &m) //Mat转Image { switch(m.type()) { case CV_8UC1: { QImage img((uchar *)m.data,m

  • Java中的Collections类的使用示例详解

    Collections的常用方法及其简单使用 代码如下: package Collections; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Stack; public class collections { public static void main(String[]args){ int array[]={125,75,56,7}; Li

  • Spring Mvc中传递参数方法之url/requestMapping详解

    前言 相信大家在使用spring的项目中,前台传递参数到后台是经常遇到的事, 我们必须熟练掌握一些常用的参数传递方式和注解的使用,本文将给大家介绍关于Spring Mvc中传递参数方法之url/requestMapping的相关内容,分享出来供大家参考学习,话不多说,直接上正文. 方法如下 1. @requestMapping: 类级别和方法级别的注解, 指明前后台解析的路径. 有value属性(一个参数时默认)指定url路径解析,method属性指定提交方式(默认为get提交) @Reques

  • Pytorch中Tensor与各种图像格式的相互转化详解

    前言 在pytorch中经常会遇到图像格式的转化,例如将PIL库读取出来的图片转化为Tensor,亦或者将Tensor转化为numpy格式的图片.而且使用不同图像处理库读取出来的图片格式也不相同,因此,如何在pytorch中正确转化各种图片格式(PIL.numpy.Tensor)是一个在调试中比较重要的问题. 本文主要说明在pytorch中如何正确将图片格式在各种图像库读取格式以及tensor向量之间转化的问题.以下代码经过测试都可以在Pytorch-0.4.0或0.3.0版本直接使用. 对py

随机推荐