pytorch中Parameter函数用法示例

目录
  • 用法介绍
  • 代码介绍

用法介绍

pytorch中的Parameter函数可以对某个张量进行参数化。它可以将不可训练的张量转化为可训练的参数类型,同时将转化后的张量绑定到模型可训练参数的列表中,当更新模型的参数时一并将其更新。

torch.nn.parameter.Parameter

  • data (Tensor):表示需要参数化的张量
  • requires_grad (bool, optional):表示是否该张量是否需要梯度,默认值为True

代码介绍

 pytorch中的Parameter函数具体的代码示例如下所示

import torch
import torch.nn as nn
class NeuralNetwork(nn.Module):
	def __init__(self, input_dim, output_dim):
		super(NeuralNetwork, self).__init__()
		self.linear = nn.Linear(input_dim, output_dim)
		self.linear.weight = torch.nn.Parameter(torch.zeros(input_dim, output_dim))
		self.linear.bias = torch.nn.Parameter(torch.ones(output_dim))
	def forward(self, input_array):
		output = self.linear(input_array)
		return output
if __name__ == '__main__':
	net = NeuralNetwork(4, 6)
	for param in net.parameters():
		print(param)

代码的结果如下所示:

 当神经网络的参数不是用Parameter函数参数化直接赋值给权重参数时,则会报错,具体的程序

import torch
import torch.nn as nn
class NeuralNetwork(nn.Module):
	def __init__(self, input_dim, output_dim):
		super(NeuralNetwork, self).__init__()
		self.linear = nn.Linear(input_dim, output_dim)
		self.linear.weight = torch.zeros(input_dim, output_dim)
		self.linear.bias = torch.ones(output_dim)
	def forward(self, input_array):
		output = self.linear(input_array)
		return output
if __name__ == '__main__':
	net = NeuralNetwork(4, 6)
	for param in net.parameters():
		print(param)

代码运行报错结果如下所示:

以上就是pytorch中Parameter函数用法示例的详细内容,更多关于pytorch中Parameter函数的资料请关注我们其它相关文章!

(0)

相关推荐

  • Pytorch之parameters的使用

    1.预构建网络 class Net(nn.Module): def __init__(self): super(Net, self).__init__() # 1 input image channel, 6 output channels, 5*5 square convolution # kernel self.conv1 = nn.Conv2d(1, 6, 5) self.conv2 = nn.Conv2d(6, 16, 5) # an affine operation: y = Wx +

  • 看过就懂的java零拷贝及实现方式详解

    目录 前言 1.什么是零拷贝 2. 传统 IO 的执行流程 3. 零拷贝相关的知识点回顾 3.1 内核空间和用户空间 3.2 什么是用户态.内核态 3.3 什么是上下文切换 3.4 虚拟内存 3.5 DMA技术 4. 零拷贝实现的几种方式 4.1 mmap+write实现的零拷贝 4.2 sendfile实现的零拷贝 4.3 sendfile+DMA scatter/gather实现的零拷贝 5. java提供的零拷贝方式 5.1 Java NIO对mmap的支持 5.2 Java NIO对se

  • PyTorch里面的torch.nn.Parameter()详解

    在看过很多博客的时候发现了一个用法self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size)),首先可以把这个函数理解为类型转换函数,将一个不可训练的类型Tensor转换成可以训练的类型parameter并将这个parameter绑定到这个module里面(net.parameter()中就有这个绑定的parameter,所以在参数优化的时候可以进行优化的),所以经过类型转换这个self.v变成了模型的一部分,成为了模型中根据训练可以改动

  • pytorch: Parameter 的数据结构实例

    一般来说,pytorch 的Parameter是一个tensor,但是跟通常意义上的tensor有些不一样 1) 通常意义上的tensor 仅仅是数据 2) 而Parameter所对应的tensor 除了包含数据之外,还包含一个属性:requires_grad(=True/False) 在Parameter所对应的tensor中获取纯数据,可以通过以下操作: param_data = Parameter.data 测试代码: #-*-coding:utf-8-*- import torch im

  • pytorch中Parameter函数用法示例

    目录 用法介绍 代码介绍 用法介绍 pytorch中的Parameter函数可以对某个张量进行参数化.它可以将不可训练的张量转化为可训练的参数类型,同时将转化后的张量绑定到模型可训练参数的列表中,当更新模型的参数时一并将其更新. torch.nn.parameter.Parameter data (Tensor):表示需要参数化的张量 requires_grad (bool, optional):表示是否该张量是否需要梯度,默认值为True 代码介绍  pytorch中的Parameter函数具

  • pytorch中permute()函数用法补充说明(矩阵维度变化过程)

    目录 一.前言 二.举例解释 1.permute(0,1,2) 2.permute(0,1,2) ⇒ permute(0,2,1) 3.permute(0,2,1) ⇒ permute(1,0,2) 4.permute(1,0,2) ⇒ permute(0,2,1) 三.写在最后 一.前言 之前写了篇torch中permute()函数用法文章,在详细的说一下permute函数里维度变化的详细过程 非常感谢@m0_46225327对本文案例更加细节补充 注意: 本文是这篇torch中permute

  • pytorch中permute()函数用法实例详解

    目录 前言 三维情况 变化一:不改变任何参数 变化二:1与2交换 变化三:0与1交换 变化四:0与2交换 变化五:0与1交换,1与2交换 变化六:0与1交换,0与2交换 总结 前言 本文只讨论二维三维中的permute用法 最近的Attention学习中的一个permute函数让我不理解 这个光说太抽象 我就结合代码与图片解释一下 首先创建一个三维数组小实例 import torch x = torch.linspace(1, 30, steps=30).view(3,2,5) # 设置一个三维

  • PyTorch中topk函数的用法详解

    听名字就知道这个函数是用来求tensor中某个dim的前k大或者前k小的值以及对应的index. 用法 torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor) input:一个tensor数据 k:指明是得到前k个数据以及其index dim: 指定在哪个维度上排序, 默认是最后一个维度 largest:如果为True,按照大到小排序: 如果为False,按照小到大排序

  • pytorch 中pad函数toch.nn.functional.pad()的用法

    padding操作是给图像外围加像素点. 为了实际说明操作过程,这里我们使用一张实际的图片来做一下处理. 这张图片是大小是(256,256),使用pad来给它加上一个黑色的边框.具体代码如下: import torch.nn,functional as F import torch from PIL import Image im=Image.open("heibai.jpg",'r') X=torch.Tensor(np.asarray(im)) print("shape:

  • pytorch中index_select()的用法详解

    pytorch中index_select()的用法 index_select(input, dim, index) 功能:在指定的维度dim上选取数据,不如选取某些行,列 参数介绍 第一个参数input是要索引查找的对象 第二个参数dim是要查找的维度,因为通常情况下我们使用的都是二维张量,所以可以简单的记忆: 0代表行,1代表列 第三个参数index是你要索引的序列,它是一个tensor对象 刚开始学习pytorch,遇到了index_select(),一开始不太明白几个参数的意思,后来查了一

  • C++中memset函数用法详解

    本文实例讲述了C++中memset函数用法.分享给大家供大家参考,具体如下: 功 能: 将s所指向的某一块内存中的每个字节的内容全部设置为ch指定的ASCII值,块的大小由第三个参数指定,这个函数通常为新申请的内存做初始化工作 用 法: void memset(void *s, char ch, unsigned n); 程序示例: #include <string.h> #include <stdio.h> #include <memory.h> int main(v

  • python中hashlib模块用法示例

    我们以前介绍过一篇Python加密的文章:Python 加密的实例详解.今天我们看看python中hashlib模块用法示例,具体如下. hashlib hashlib主要提供字符加密功能,将md5和sha模块整合到了一起,支持md5,sha1, sha224, sha256, sha384, sha512等算法 具体应用 #!/usr/bin/env python # -*- coding: UTF-8 -*- #pyversion:python3.5 #owner:fuzj import h

  • PHP中list()函数用法实例简析

    本文实例讲述了PHP中list()函数用法.分享给大家供大家参考,具体如下: PHP中的list() 函数用于在一次操作中给一组变量赋值. 注意:这里的数组变量只能为数字索引的数组,且假定数字索引从 0 开始. list()函数定义如下: list(var1,var2...) 参数说明: var1      必需.第一个需要赋值的变量. var2,...  可选.更多需要赋值的变量. 示例代码如下: <?php //$arr=array('name'=>'Tom','pwd'=>'123

  • javaScript中slice函数用法实例分析

    本文实例讲述了javaScript中slice函数用法.分享给大家供大家参考.具体分析如下: javaScript 中的 slice 函数,对于array对象的slice函数,返回一个数组的一段.(仍为数组) arrayObj.slice(start, [end]) 参数: arrayObj,必选项.一个 Array 对象.  start,必选项.arrayObj 中所指定的部分的开始元素是从零开始计算的下标.  end,可选项.arrayObj 中所指定的部分的结束元素是从零开始计算的下标.

随机推荐