pytorch 一行代码查看网络参数总量的实现
大家还是直接看代码吧~
netG = Generator() print('# generator parameters:', sum(param.numel() for param in netG.parameters())) netD = Discriminator() print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))
补充:PyTorch查看网络模型的参数量PARAMS和FLOPS等
在PyTorch中,可以使用torchstat这个库来查看网络模型的一些信息,包括总的参数量params、MAdd、显卡内存占用量和FLOPs等。
示例代码如下:
from torchstat import stat from torchvision.models import resnet50, resnet101, resnet152, resnext101_32x8d model = resnet50() stat(model, (3, 224, 224))
打印信息如下:
以上为个人经验,希望能给大家一个参考,也希望大家多多支持我们。如有错误或未考虑完全的地方,望不吝赐教。
相关推荐
-
pytorch 网络参数 weight bias 初始化详解
权重初始化对于训练神经网络至关重要,好的初始化权重可以有效的避免梯度消失等问题的发生. 在pytorch的使用过程中有几种权重初始化的方法供大家参考. 注意:第一种方法不推荐.尽量使用后两种方法. # not recommend def weights_init(m): classname = m.__class__.__name__ if classname.find('Conv') != -1: m.weight.data.normal_(0.0, 0.02) elif classname.
-
pytorch 实现查看网络中的参数
可以通过model.state_dict()或者model.named_parameters()函数查看现在的全部可训练参数(包括通过继承得到的父类中的参数) 可示例代码如下: params = list(model.named_parameters()) (name, param) = params[28] print(name) print(param.grad) print('-------------------------------------------------') (name
-
关于pytorch中网络loss传播和参数更新的理解
相比于2018年,在ICLR2019提交论文中,提及不同框架的论文数量发生了极大变化,网友发现,提及tensorflow的论文数量从2018年的228篇略微提升到了266篇,keras从42提升到56,但是pytorch的数量从87篇提升到了252篇. TensorFlow: 228--->266 Keras: 42--->56 Pytorch: 87--->252 在使用pytorch中,自己有一些思考,如下: 1. loss计算和反向传播 import torch.nn as nn
-
pytorch 实现在一个优化器中设置多个网络参数的例子
我就废话不多说了,直接上代码吧! 其实也不难,使用tertools.chain将参数链接起来即可 import itertools ... self.optimizer = optim.Adam(itertools.chain(self.encoder.parameters(), self.decoder.parameters()), lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) ... 以上这篇pytorch 实现在一个优化器中设置多个网络参数的
-
pytorch 求网络模型参数实例
用pytorch训练一个神经网络时,我们通常会很关心模型的参数总量.下面分别介绍来两种方法求模型参数 一 .求得每一层的模型参数,然后自然的可以计算出总的参数. 1.先初始化一个网络模型model 比如我这里是 model=cliqueNet(里面是些初始化的参数) 2.调用model的Parameters类获取参数列表 一个典型的操作就是将参数列表传入优化器里.如下 optimizer = optim.Adam(model.parameters(), lr=opt.lr) 言归正传,继续回到参
-
pytorch查看网络参数显存占用量等操作
1.使用torchstat pip install torchstat from torchstat import stat import torchvision.models as models model = models.resnet152() stat(model, (3, 224, 224)) 关于stat函数的参数,第一个应该是模型,第二个则是输入尺寸,3为通道数.我没有调研该函数的详细参数,也不知道为什么使用的时候并不提示相应的参数. 2.使用torchsummary pip in
-
pytorch 一行代码查看网络参数总量的实现
大家还是直接看代码吧~ netG = Generator() print('# generator parameters:', sum(param.numel() for param in netG.parameters())) netD = Discriminator() print('# discriminator parameters:', sum(param.numel() for param in netD.parameters())) 补充:PyTorch查看网络模型的参数量PARA
-
Pytorch反向求导更新网络参数的方法
方法一:手动计算变量的梯度,然后更新梯度 import torch from torch.autograd import Variable # 定义参数 w1 = Variable(torch.FloatTensor([1,2,3]),requires_grad = True) # 定义输出 d = torch.mean(w1) # 反向求导 d.backward() # 定义学习率等参数 lr = 0.001 # 手动更新参数 w1.data.zero_() # BP求导更新参数之前,需先对导
-
在pytorch中如何查看模型model参数parameters
目录 pytorch查看模型model参数parameters pytorch查看模型参数总结 1:DNN_printer 2:parameters 3:get_model_complexity_info() 4:torchstat pytorch查看模型model参数parameters 示例1:pytorch自带的faster r-cnn模型 import torch import torchvision model = torchvision.models.detection.faster
-
Okhttp、Retrofit进度获取的方法(一行代码搞定)
起因 对于广大Android开发者来说,最近用的最多的网络库,莫过于Okhttp啦(Retrofit依赖Okhttp). Okhttp不像SDK内置的HttpUrlConnection一样,可以明确的获取数据读写的过程,我们需要执行一些操作. 介绍 Retrofit依赖Okhttp.Okhttp依赖于Okio.那么Okio又是什么鬼?别急,看官方介绍: Okio is a library that complements java.io and java.nio to make it much
-
PyTorch 编写代码遇到的问题及解决方案
PyTorch编写代码遇到的问题 错误提示:no module named xxx xxx为自定义文件夹的名字 因为搜索不到,所以将当前路径加入到包的搜索目录 解决方法: import sys sys.path.append('..') #将上层目录加入到搜索路径中 sys.path.append('/home/xxx') # 绝对路径 import os sys.path.append(os.getcwd()) # #将当前工作路径加入到搜索路径中 还可以在当前终端的命令行设置 export
-
pytorch加载自定义网络权重的实现
在将自定义的网络权重加载到网络中时,报错: AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead. 我们一步一步分析. 模型网络权重保存额代码是:torch.sa
随机推荐
- 正则表达式环视概念与用法分析
- Java 正则表达式入门详解(基础进阶)
- Powershell中显示隐藏文件的方法
- linux touch,chattr指令详解及用法
- python算法学习之桶排序算法实例(分块排序)
- ajaxpro.dll 控件实现异步刷新页面
- Android中RecyclerView嵌套滑动冲突解决的代码片段
- C++中实现把表的数据导出到EXCEL并打印实例代码
- IE下写xml文件的两种方式(fso/saveAs)
- Java UrlRewriter伪静态技术运用深入分析
- linux下mysql 5.7.16 免安装版本图文教程
- php使用qr生成二维码的示例分享
- 整理Javascript基础入门学习笔记
- 获取table中的rowIndex和cellIndex
- Tomcat日志文件定时清理备份的脚本
- 详解java中Reference的实现与相应的执行过程
- 轻松实现Android语音识别功能
- Android超实用的Toast提示框优化分享
- Android使用Sensor感应器获取用户移动方向(指南针原理)
- Hibernate的Session_flush与隔离级别代码详解