pytorch 求网络模型参数实例

用pytorch训练一个神经网络时,我们通常会很关心模型的参数总量。下面分别介绍来两种方法求模型参数

一 .求得每一层的模型参数,然后自然的可以计算出总的参数。

1.先初始化一个网络模型model

比如我这里是 model=cliqueNet(里面是些初始化的参数)

2.调用model的Parameters类获取参数列表

一个典型的操作就是将参数列表传入优化器里。如下

 optimizer = optim.Adam(model.parameters(), lr=opt.lr)

言归正传,继续回到参数里面,参数在网络里面就是variable,下面分别求每层的尺寸大小和个数。

函数get_number_of_param( ) 里面的参数就是刚才第一步初始化的model

def get_number_of_param(model):
  """get the number of param for every element"""
  count = 0
  for param in model.parameters():
    param_size = param.size()
    count_of_one_param = 1
    for dis in param_size:
      count_of_one_param *= dis
    print(param.size(), count_of_one_param)
    count += count_of_one_param
    print(count)
  print('total number of the model is %d'%count)

再来看看结果:

torch.Size([64, 1, 3, 3]) 576
576
torch.Size([64]) 64
640
torch.Size([6, 36, 64, 3, 3]) 124416
125056
torch.Size([30, 36, 36, 3, 3]) 349920
474976
torch.Size([12, 36]) 432
475408
torch.Size([6, 36, 216, 3, 3]) 419904
895312
torch.Size([30, 36, 36, 3, 3]) 349920
1245232
torch.Size([12, 36]) 432
1245664
torch.Size([6, 36, 216, 3, 3]) 419904
1665568
torch.Size([30, 36, 36, 3, 3]) 349920
2015488
torch.Size([12, 36]) 432
2015920
torch.Size([6, 36, 216, 3, 3]) 419904
2435824
torch.Size([30, 36, 36, 3, 3]) 349920
2785744
torch.Size([12, 36]) 432
2786176
torch.Size([216, 216, 1, 1]) 46656
2832832
torch.Size([216]) 216
2833048
torch.Size([108, 216]) 23328
2856376
torch.Size([108]) 108
2856484
torch.Size([216, 108]) 23328
2879812
torch.Size([216]) 216
2880028
torch.Size([216, 216, 1, 1]) 46656
2926684
torch.Size([216]) 216
2926900
torch.Size([108, 216]) 23328
2950228
torch.Size([108]) 108
2950336
torch.Size([216, 108]) 23328
2973664
torch.Size([216]) 216
2973880
torch.Size([216, 216, 1, 1]) 46656
3020536
torch.Size([216]) 216
3020752
torch.Size([108, 216]) 23328
3044080
torch.Size([108]) 108
3044188
torch.Size([216, 108]) 23328
3067516
torch.Size([216]) 216
3067732
torch.Size([140, 280, 1, 1]) 39200
3106932
torch.Size([140]) 140
3107072
torch.Size([216, 432, 1, 1]) 93312
3200384
torch.Size([216]) 216
3200600
torch.Size([216, 432, 1, 1]) 93312
3293912
torch.Size([216]) 216
3294128
torch.Size([9, 572, 3, 3]) 46332
3340460
torch.Size([9]) 9
3340469
total number of the model is 3340469

可以通过计算验证一下,发现参数与网络是一致的。

二:一行代码就可以搞定参数总个数问题

2.1 先来看看torch.tensor.numel( )这个函数的功能就是求tensor中的元素个数,在网络里面每层参数就是多维数组组成的tensor。

实际上就是求多维数组的元素个数。看代码。

print('cliqueNet parameters:', sum(param.numel() for param in model.parameters()))

当然上面代码中的model还是上面初始化的网络模型。

看看两种的计算结果

torch.Size([64, 1, 3, 3]) 576
576
torch.Size([64]) 64
640
torch.Size([6, 36, 64, 3, 3]) 124416
125056
torch.Size([30, 36, 36, 3, 3]) 349920
474976
torch.Size([12, 36]) 432
475408
torch.Size([6, 36, 216, 3, 3]) 419904
895312
torch.Size([30, 36, 36, 3, 3]) 349920
1245232
torch.Size([12, 36]) 432
1245664
torch.Size([6, 36, 216, 3, 3]) 419904
1665568
torch.Size([30, 36, 36, 3, 3]) 349920
2015488
torch.Size([12, 36]) 432
2015920
torch.Size([6, 36, 216, 3, 3]) 419904
2435824
torch.Size([30, 36, 36, 3, 3]) 349920
2785744
torch.Size([12, 36]) 432
2786176
torch.Size([216, 216, 1, 1]) 46656
2832832
torch.Size([216]) 216
2833048
torch.Size([108, 216]) 23328
2856376
torch.Size([108]) 108
2856484
torch.Size([216, 108]) 23328
2879812
torch.Size([216]) 216
2880028
torch.Size([216, 216, 1, 1]) 46656
2926684
torch.Size([216]) 216
2926900
torch.Size([108, 216]) 23328
2950228
torch.Size([108]) 108
2950336
torch.Size([216, 108]) 23328
2973664
torch.Size([216]) 216
2973880
torch.Size([216, 216, 1, 1]) 46656
3020536
torch.Size([216]) 216
3020752
torch.Size([108, 216]) 23328
3044080
torch.Size([108]) 108
3044188
torch.Size([216, 108]) 23328
3067516
torch.Size([216]) 216
3067732
torch.Size([140, 280, 1, 1]) 39200
3106932
torch.Size([140]) 140
3107072
torch.Size([216, 432, 1, 1]) 93312
3200384
torch.Size([216]) 216
3200600
torch.Size([216, 432, 1, 1]) 93312
3293912
torch.Size([216]) 216
3294128
torch.Size([9, 572, 3, 3]) 46332
3340460
torch.Size([9]) 9
3340469
total number of the model is 3340469
cliqueNet parameters: 3340469

可以看出两种计算出来的是一模一样的。

以上这篇pytorch 求网络模型参数实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • pytorch构建网络模型的4种方法

    利用pytorch来构建网络模型有很多种方法,以下简单列出其中的四种. 假设构建一个网络模型如下: 卷积层-->Relu层-->池化层-->全连接层-->Relu层-->全连接层 首先导入几种方法用到的包: import torch import torch.nn.functional as F from collections import OrderedDict 第一种方法 # Method 1 --------------------------------------

  • pytorch获取模型某一层参数名及参数值方式

    1.Motivation: I wanna modify the value of some param; I wanna check the value of some param. The needed function: 2.state_dict() #generator type model.modules()#generator type named_parameters()#OrderDict type from torch import nn import torch #creat

  • pytorch中获取模型input/output shape实例

    Pytorch官方目前无法像tensorflow, caffe那样直接给出shape信息,详见 https://github.com/pytorch/pytorch/pull/3043 以下代码算一种workaround.由于CNN, RNN等模块实现不一样,添加其他模块支持可能需要改代码. 例如RNN中bias是bool类型,其权重也不是存于weight属性中,不过我们只关注shape够用了. 该方法必须构造一个输入调用forward后(model(x)调用)才可获取shape #coding

  • Pytorch之保存读取模型实例

    pytorch保存数据 pytorch保存数据的格式为.t7文件或者.pth文件,t7文件是沿用torch7中读取模型权重的方式.而pth文件是python中存储文件的常用格式.而在keras中则是使用.h5文件. # 保存模型示例代码 print('===> Saving models...') state = { 'state': model.state_dict(), 'epoch': epoch # 将epoch一并保存 } if not os.path.isdir('checkpoin

  • pytorch 加载(.pth)格式的模型实例

    有一些非常流行的网络如 resnet.squeezenet.densenet等在pytorch里面都有,包括网络结构和训练好的模型. pytorch自带模型网址:https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-models/ 按官网加载预训练好的模型: import torchvision.models as models # pretrained=True就可以使用预训练的模型 resnet18 = mod

  • 基于pytorch的保存和加载模型参数的方法

    当我们花费大量的精力训练完网络,下次预测数据时不想再(有时也不必再)训练一次时,这时候torch.save(),torch.load()就要登场了. 保存和加载模型参数有两种方式: 方式一: torch.save(net.state_dict(),path): 功能:保存训练完的网络的各层参数(即weights和bias) 其中:net.state_dict()获取各层参数,path是文件存放路径(通常保存文件格式为.pt或.pth) net2.load_state_dict(torch.loa

  • pytorch 求网络模型参数实例

    用pytorch训练一个神经网络时,我们通常会很关心模型的参数总量.下面分别介绍来两种方法求模型参数 一 .求得每一层的模型参数,然后自然的可以计算出总的参数. 1.先初始化一个网络模型model 比如我这里是 model=cliqueNet(里面是些初始化的参数) 2.调用model的Parameters类获取参数列表 一个典型的操作就是将参数列表传入优化器里.如下 optimizer = optim.Adam(model.parameters(), lr=opt.lr) 言归正传,继续回到参

  • Pytorch加载部分预训练模型的参数实例

    前言 自从从深度学习框架caffe转到Pytorch之后,感觉Pytorch的优点妙不可言,各种设计简洁,方便研究网络结构修改,容易上手,比TensorFlow的臃肿好多了.对于深度学习的初学者,Pytorch值得推荐.今天主要主要谈谈Pytorch是如何加载预训练模型的参数以及代码的实现过程. 直接加载预选脸模型 如果我们使用的模型和预训练模型完全一样,那么我们就可以直接加载别人的模型,还有一种情况,我们在训练自己模型的过程中,突然中断了,但只要我们保存了之前的模型的参数也可以使用下面的代码直

  • pytorch forward两个参数实例

    以channel Attention Block为例子 class CAB(nn.Module): def __init__(self, in_channels, out_channels): super(CAB, self).__init__() self.global_pooling = nn.AdaptiveAvgPool2d(output_size=1) self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, st

  • pytorch打印网络结构的实例

    最简单的方法当然可以直接print(net),但是这样网络比较复杂的时候效果不太好,看着比较乱:以前使用caffe的时候有一个网站可以在线生成网络框图,tensorflow可以用tensor board,keras中可以用model.summary().或者plot_model().pytorch没有这样的API,但是可以用代码来完成. (1)安装环境:graphviz conda install -n pytorch python-graphviz 或: sudo apt-get instal

  • Python根据欧拉角求旋转矩阵的实例

    利用numpy和scipy,我们可以很容易根据欧拉角求出旋转矩阵,这里的旋转轴我们你理解成四元数里面的旋转轴 import numpy as np import scipy.linalg as linalg import math #参数分别是旋转轴和旋转弧度值 def rotate_mat(self, axis, radian): rot_matrix = linalg.expm(np.cross(np.eye(3), axis / linalg.norm(axis) * radian)) a

  • mysql数据存储过程参数实例详解

    MySQL 存储过程参数有三种类型:in.out.inout.它们各有什么作用和特点呢? 一.MySQL 存储过程参数(in) MySQL 存储过程 "in" 参数:跟 C 语言的函数参数的值传递类似, MySQL 存储过程内部可能会修改此参数,但对 in 类型参数的修改,对调用者(caller)来说是不可见的(not visible). drop procedure if exists pr_param_in; create procedure pr_param_in ( in id

  • 浅析JS获取url中的参数实例代码

    js获取url中的参数代码如下所示,代码简单易懂,附有注释,写的不好还请见谅! function UrlSearch() { var name, value; var str = location.href; //取得整个地址栏 var num = str.indexOf("?") str = str.substr(num + 1); //取得所有参数 stringvar.substr(start [, length ] var arr = str.split("&&

  • 读取xml文件中的配置参数实例

    paras.xml文件 <?xml version="1.0" encoding="UTF-8"?> <beans xmlns="http://www.springframework.org/schema/beans" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns:aop="http://www.springframework

  • Oracle的out参数实例详解

    Oracle的out参数实例详解 一 概念 1.一般来讲,存储过程和存储函数的区别在于存储函数可以有一个返回值:而存储过程没有返回值. 2.过程和函数都可以通过out指定一个或多个输出行.我们可以利用out参数,在过程和函数中实现返回多个值. 3.存储过程和存储函数都可以有out参数. 4.存储过程和存储函数都可以有多个out参数. 5.存储过程可以通过out参数来实现返回值. 6.如果只有一个返回值,用存储函数:否则,就用存储过程. 二 实例 --out参数:查询某个员工姓名月薪和职位 /*

  • DataTables添加额外的查询参数和删除columns等无用参数实例

    废话不多说,直接上代码 //1.定义全局变量 var iStart = 0, searchParams={}; //2.配置datatable的ajax配置项 "ajax": { "url": "/user/query", "type": "POST", //动态请求参数设置,会应用到每次请求 "data": function (d) { //删除多余请求参数 for(var key i

随机推荐