pytorch 实现打印模型的参数值
对于简单的网络
例如全连接层Linear
可以使用以下方法打印linear层:
fc = nn.Linear(3, 5) params = list(fc.named_parameters()) print(params.__len__()) print(params[0]) print(params[1])
输出如下:
由于Linear默认是偏置bias的,所有参数列表的长度是2。第一个存的是全连接矩阵,第二个存的是偏置。
对于稍微复杂的网络
例如MLP
mlp = nn.Sequential( nn.Dropout(p=0.3), nn.Linear(1024, 256), nn.Linear(256, 64), nn.Linear(64, 16), nn.Linear(16, 1) ) params = list(mlp.named_parameters()) print(params.__len__()) print(params[0]) print(params[1]) print(params[2]) print(params[3])
输出:
可以发现,堆叠起来的网络,参数是依次放置的。先是全连接的权重,然后偏置。然后是下一层网络的权重+偏置。依次进行下去。
这里有4层fc,4*2=8.所以一共有8个参数矩阵。
以上这篇pytorch 实现打印模型的参数值就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。
相关推荐
-
详解PyTorch手写数字识别(MNIST数据集)
MNIST 手写数字识别是一个比较简单的入门项目,相当于深度学习中的 Hello World,可以让我们快速了解构建神经网络的大致过程.虽然网上的案例比较多,但还是要自己实现一遍.代码采用 PyTorch 1.0 编写并运行. 导入相关库 import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, t
-
pytorch 数据集图片显示方法
图片显示 pytorch 载入的数据集是元组tuple 形式,里面包括了数据及标签(train_data,label),其中的train_data数据可以转换为torch.Tensor形式,方便后面计算使用. 同样给一些刚入门的同学在使用载入的数据显示图片的时候带来一些难以理解的地方,这里主要是将Tensor与numpy转换的过程,理解了这些就可以就行转换了 CIAFA10数据集 首先载入数据集,这里做了一些数据处理,包括图片尺寸.数据归一化等 import torch from torch.a
-
pytorch 自定义数据集加载方法
pytorch 官网给出的例子中都是使用了已经定义好的特殊数据集接口来加载数据,而且其使用的数据都是官方给出的数据.如果我们有自己收集的数据集,如何用来训练网络呢?此时需要我们自己定义好数据处理接口.幸运的是pytroch给出了一个数据集接口类(torch.utils.data.Dataset),可以方便我们继承并实现自己的数据集接口. torch.utils.data torch的这个文件包含了一些关于数据集处理的类. class torch.utils.data.Dataset: 一个抽象类
-
pytorch 批次遍历数据集打印数据的例子
我就废话不多说了,直接上代码吧! from os import listdir import os from time import time import torch.utils.data as data import torchvision.transforms as transforms from torch.utils.data import DataLoader def printProgressBar(iteration, total, prefix='', suffix='', d
-
PyTorch读取Cifar数据集并显示图片的实例讲解
首先了解一下需要的几个类所在的package from torchvision import transforms, datasets as ds from torch.utils.data import DataLoader import matplotlib.pyplot as plt import numpy as np #transform = transforms.Compose是把一系列图片操作组合起来,比如减去像素均值等. #DataLoader读入的数据类型是PIL.Image
-
pytorch 实现打印模型的参数值
对于简单的网络 例如全连接层Linear 可以使用以下方法打印linear层: fc = nn.Linear(3, 5) params = list(fc.named_parameters()) print(params.__len__()) print(params[0]) print(params[1]) 输出如下: 由于Linear默认是偏置bias的,所有参数列表的长度是2.第一个存的是全连接矩阵,第二个存的是偏置. 对于稍微复杂的网络 例如MLP mlp = nn.Sequential
-
PyTorch搭建多项式回归模型(三)
PyTorch基础入门三:PyTorch搭建多项式回归模型 1)理论简介 对于一般的线性回归模型,由于该函数拟合出来的是一条直线,所以精度欠佳,我们可以考虑多项式回归来拟合更多的模型.所谓多项式回归,其本质也是线性回归.也就是说,我们采取的方法是,提高每个属性的次数来增加维度数.比如,请看下面这样的例子: 如果我们想要拟合方程: 对于输入变量和输出值,我们只需要增加其平方项.三次方项系数即可.所以,我们可以设置如下参数方程: 可以看到,上述方程与线性回归方程并没有本质区别.所以我们可以采用线性回
-
深入理解Pytorch微调torchvision模型
目录 一.简介 二.导入相关包 三.数据输入 四.辅助函数 1.模型训练和验证 2.设置模型参数的'.requires_grad属性' 一.简介 在本小节,深入探讨如何对torchvision进行微调和特征提取.所有模型都已经预先在1000类的magenet数据集上训练完成. 本节将深入介绍如何使用几个现代的CNN架构,并将直观展示如何微调任意的PyTorch模型. 本节将执行两种类型的迁移学习: 微调:从预训练模型开始,更新我们新任务的所有模型参数,实质上是重新训练整个模型. 特征提取:从预训
-
pytorch中获取模型input/output shape实例
Pytorch官方目前无法像tensorflow, caffe那样直接给出shape信息,详见 https://github.com/pytorch/pytorch/pull/3043 以下代码算一种workaround.由于CNN, RNN等模块实现不一样,添加其他模块支持可能需要改代码. 例如RNN中bias是bool类型,其权重也不是存于weight属性中,不过我们只关注shape够用了. 该方法必须构造一个输入调用forward后(model(x)调用)才可获取shape #coding
-
pytorch构建多模型实例
pytorch构建双模型 第一部分:构建"se_resnet152","DPN92()"双模型 import numpy as np from functools import partial import torch from torch import nn import torch.nn.functional as F from torch.optim import SGD,Adam from torch.autograd import Variable fro
-
Pytorch实现将模型的所有参数的梯度清0
有两种方式直接把模型的参数梯度设成0: model.zero_grad() optimizer.zero_grad()#当optimizer=optim.Optimizer(model.parameters())时,两者等效 如果想要把某一Variable的梯度置为0,只需用以下语句: Variable.grad.data.zero_() 补充知识:PyTorch中在反向传播前为什么要手动将梯度清零?optimizer.zero_grad()的意义 optimizer.zero_grad()意思
-
PyTorch深度学习模型的保存和加载流程详解
一.模型参数的保存和加载 torch.save(module.state_dict(), path):使用module.state_dict()函数获取各层已经训练好的参数和缓冲区,然后将参数和缓冲区保存到path所指定的文件存放路径(常用文件格式为.pt..pth或.pkl). torch.nn.Module.load_state_dict(state_dict):从state_dict中加载参数和缓冲区到Module及其子类中 . torch.nn.Module.state_dict()函数
-
pytorch如何获得模型的计算量和参数量
方法1 自带 pytorch自带方法,计算模型参数总量 total = sum([param.nelement() for param in model.parameters()]) print("Number of parameter: %.2fM" % (total/1e6)) 或者 total = sum(p.numel() for p in model.parameters()) print("Total params: %.2fM" % (total/1e
-
pytorch 如何打印网络回传梯度
需求: 打印梯度,检查网络学习情况 net = your_network().cuda() def train(): ... outputs = net(inputs) loss = criterion(outputs, targets) loss.backward() for name, parms in net.named_parameters(): print('-->name:', name, '-->grad_requirs:',parms.requires_grad, \ ' --
-
Pytorch中求模型准确率的两种方法小结
方法一:直接在epoch过程中求取准确率 简介:此段代码是LeNet5中截取的. def train_model(model,train_loader): optimizer = torch.optim.Adam(model.parameters()) loss_func = nn.CrossEntropyLoss() EPOCHS = 5 for epoch in range(EPOCHS): correct = 0 for batch_idx,(X_batch,y_batch) in enu
随机推荐
- 封装ThinkPHP的一个文件上传方法实例
- 详解PHP用substr函数截取字符串中的某部分
- Server.CreateObject的调用失败拒绝对此对象的访问的解决方法
- 简析Java中的util.concurrent.Future接口
- Java获取当地的日出日落时间代码分享
- 苹果公司推出的新编程语言Swift简介和入门教程
- Java随机生成手机短信验证码的方法
- 在asp.NET 中使用SMTP发送邮件的实现代码
- 水晶报表图片不显示两种问题分析及解决方法
- Swift教程之控制流详解
- PHP动态地创建属性和方法, 对象的复制, 对象的比较,加载指定的文件,自动加载类文件,命名空间
- 文件上传类
- C语言WinSock学习笔记第1/2页
- Android 创建与解析XML(四)——详解Pull方式
- A10_DatePicker的对话框设置(使用OnDateSetListener监听器)
- Java实现计算一个月有多少天和多少周
- 没去过上海,不知道上海是这样的 上海,今夜请将我埋藏第1/3页
- shell脚本结合iptables防端口扫描的实现
- 浅析javascript中的事件代理
- javascript中href和replace的比较(详解)