pytorch 共享参数的示例

在很多神经网络中,往往会出现多个层共享一个权重的情况,pytorch可以快速地处理权重共享问题。

例子1:

class ConvNet(nn.Module):
  def __init__(self):
    super(ConvNet, self).__init__()
    self.conv_weight = nn.Parameter(torch.randn(3, 3, 5, 5))

  def forward(self, x):
    x = nn.functional.conv2d(x, self.conv_weight, bias=None, stride=1, padding=2, dilation=1, groups=1)
    x = nn.functional.conv2d(x, self.conv_weight.transpose(2, 3).contiguous(), bias=None, stride=1, padding=0, dilation=1,
                 groups=1)
    return x

上边这段程序定义了两个卷积层,这两个卷积层共享一个权重conv_weight,第一个卷积层的权重是conv_weight本身,第二个卷积层是conv_weight的转置。注意在gpu上运行时,transpose()后边必须加上.contiguous()使转置操作连续化,否则会报错。

例子2:

class LinearNet(nn.Module):
  def __init__(self):
    super(LinearNet, self).__init__()
    self.linear_weight = nn.Parameter(torch.randn(3, 3))

  def forward(self, x):
    x = nn.functional.linear(x, self.linear_weight)
    x = nn.functional.linear(x, self.linear_weight.t())

    return x

这个网络实现了一个双层感知器,权重同样是一个parameter的本身及其转置。

例子3:

class LinearNet2(nn.Module):
  def __init__(self):
    super(LinearNet2, self).__init__()
    self.w = nn.Parameter(torch.FloatTensor([[1.1,0,0], [0,1,0], [0,0,1]]))

  def forward(self, x):
    x = x.mm(self.w)
    x = x.mm(self.w.t())
    return x

这个方法直接用mm函数将x与w相乘,与上边的网络效果相同。

以上这篇pytorch 共享参数的示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • python PyTorch参数初始化和Finetune

    前言 这篇文章算是论坛PyTorch Forums关于参数初始化和finetune的总结,也是我在写代码中用的算是"最佳实践"吧.最后希望大家没事多逛逛论坛,有很多高质量的回答. 参数初始化 参数的初始化其实就是对参数赋值.而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了data,grad等借口,这就意味着我们可以直接对这些参数进行操作赋值了.这就是PyTorch简洁高效所在. 所以我们可以进行如下操作进行初始化,当然其实有其他的方法,但是这种方法

  • pytorch 共享参数的示例

    在很多神经网络中,往往会出现多个层共享一个权重的情况,pytorch可以快速地处理权重共享问题. 例子1: class ConvNet(nn.Module): def __init__(self): super(ConvNet, self).__init__() self.conv_weight = nn.Parameter(torch.randn(3, 3, 5, 5)) def forward(self, x): x = nn.functional.conv2d(x, self.conv_w

  • Laravel基础_关于view共享数据的示例讲解

    1.所有视图共享数据(share) 当所有视图都需要同一个数据时,使用视图工厂的share方法. 全局帮助函数view,如果传入参数,则返回Illuminate\View\View实例,不传入参数则返回Illuminate\View\Factory实例.所以我们可以通过在服务提供者(app\Providers\AppServiceProvider.php)的boot方法中使用如下方式实现视图间共享数据: /** * Bootstrap any application services. * *

  • pytorch LayerNorm参数的用法及计算过程

    说明 LayerNorm中不会像BatchNorm那样跟踪统计全局的均值方差,因此train()和eval()对LayerNorm没有影响. LayerNorm参数 torch.nn.LayerNorm( normalized_shape: Union[int, List[int], torch.Size], eps: float = 1e-05, elementwise_affine: bool = True) normalized_shape 如果传入整数,比如4,则被看做只有一个整数的li

  • 人工智能学习Pytorch教程Tensor基本操作示例详解

    目录 一.tensor的创建 1.使用tensor 2.使用Tensor 3.随机初始化 4.其他数据生成 ①torch.full ②torch.arange ③linspace和logspace ④ones, zeros, eye ⑤torch.randperm 二.tensor的索引与切片 1.索引与切片使用方法 ①index_select ②... ③mask 三.tensor维度的变换 1.维度变换 ①torch.view ②squeeze/unsqueeze ③expand,repea

  • 人工智能学习Pytorch梯度下降优化示例详解

    目录 一.激活函数 1.Sigmoid函数 2.Tanh函数 3.ReLU函数 二.损失函数及求导 1.autograd.grad 2.loss.backward() 3.softmax及其求导 三.链式法则 1.单层感知机梯度 2. 多输出感知机梯度 3. 中间有隐藏层的求导 4.多层感知机的反向传播 四.优化举例 一.激活函数 1.Sigmoid函数 函数图像以及表达式如下: 通过该函数,可以将输入的负无穷到正无穷的输入压缩到0-1之间.在x=0的时候,输出0.5 通过PyTorch实现方式

  • Android Activity共享元素动画示例解析

    目录 正文 TransitionManager介绍 Scene(场景) 生成场景 Transition(过渡) OverlayView和ViewGroupOverlay GhostView Activity的共享元素源码分析 我们先以ActivityA打开ActivityB为例 ActivityB返回ActivityA SharedElementCallback回调总结 正文 所谓Activity共享元素动画,就是从ActivityA跳转到ActivityB 通过控制某些元素(View)从Act

  • 基于Pytorch实现分类器的示例详解

    目录 Softmax分类器 定义 训练 测试 感知机分类器 定义 训练 测试 本文实现两个分类器: softmax分类器和感知机分类器 Softmax分类器 Softmax分类是一种常用的多类别分类算法,它可以将输入数据映射到一个概率分布上.Softmax分类首先将输入数据通过线性变换得到一个向量,然后将向量中的每个元素进行指数函数运算,最后将指数运算结果归一化得到一个概率分布.这个概率分布可以被解释为每个类别的概率估计. 定义 定义一个softmax分类器类: class SoftmaxCla

  • pytorch 固定部分参数训练的方法

    需要自己过滤 optimizer.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3) 另外,如果是Variable,则可以初始化时指定 j = Variable(torch.randn(5,5), requires_grad=True) 但是如果是 m = nn.Linear(10,10) 是没有requires_grad传入的 m.requires_grad也没有 需要 for i in m.parameter

  • pytorch 自定义参数不更新方式

    nn.Module中定义参数:不需要加cuda,可以求导,反向传播 class BiFPN(nn.Module): def __init__(self, fpn_sizes): self.w1 = nn.Parameter(torch.rand(1)) print("no---------------------------------------------------",self.w1.data, self.w1.grad) 下面这个例子说明中间变量可能没有梯度,但是最终变量有梯度

  • python调用jenkinsAPI构建jenkins,并传递参数的示例

    安装jenkins 安装jenkins很简单,可以用多种方式安装,这里知道的有: 在官网下载rpm包,手动安装,最费事 centos系统通过yum安装,ubuntu通过apt-get安装(不推荐,因为很多东西都使用了默认的) 直接下载官网上的war包 我这里直接用的下载war包 遇到的坑 在安装之前,公司的服务器上已经有一个版本的jekins在运行了,所有参数都已经被设置过了,所以,重新安装的版本,虽然文件夹,用户都和以前的版本不一样,但是每次jenkins页面都是直接跳转上个版本的,并不会进入

随机推荐