pytorch获取模型某一层参数名及参数值方式
1、Motivation:
I wanna modify the value of some param;
I wanna check the value of some param.
The needed function:
2、state_dict() #generator type
model.modules()#generator type
named_parameters()#OrderDict type
from torch import nn import torch #creat a simple model model = nn.Sequential( nn.Conv3d(1,16,kernel_size=1), nn.Conv3d(16,2,kernel_size=1))#tend to print the W of this layer input = torch.randn([1,1,16,256,256]) if torch.cuda.is_available(): print('cuda is avaliable') model.cuda() input = input.cuda() #打印某一层的参数名 for name in model.state_dict(): print(name) #Then I konw that the name of target layer is '1.weight' #schemem1(recommended) print(model.state_dict()['1.weight']) #scheme2 params = list(model.named_parameters())#get the index by debuging print(params[2][0])#name print(params[2][1].data)#data #scheme3 params = {}#change the tpye of 'generator' into dict for name,param in model.named_parameters(): params[name] = param.detach().cpu().numpy() print(params['0.weight']) #scheme4 for layer in model.modules(): if(isinstance(layer,nn.Conv3d)): print(layer.weight) #打印每一层的参数名和参数值 #schemem1(recommended) for name,param in model.named_parameters(): print(name,param) #scheme2 for name in model.state_dict(): print(name) print(model.state_dict()[name])
以上这篇pytorch获取模型某一层参数名及参数值方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。
相关推荐
-
pytorch构建网络模型的4种方法
利用pytorch来构建网络模型有很多种方法,以下简单列出其中的四种. 假设构建一个网络模型如下: 卷积层-->Relu层-->池化层-->全连接层-->Relu层-->全连接层 首先导入几种方法用到的包: import torch import torch.nn.functional as F from collections import OrderedDict 第一种方法 # Method 1 --------------------------------------
-
pytorch中获取模型input/output shape实例
Pytorch官方目前无法像tensorflow, caffe那样直接给出shape信息,详见 https://github.com/pytorch/pytorch/pull/3043 以下代码算一种workaround.由于CNN, RNN等模块实现不一样,添加其他模块支持可能需要改代码. 例如RNN中bias是bool类型,其权重也不是存于weight属性中,不过我们只关注shape够用了. 该方法必须构造一个输入调用forward后(model(x)调用)才可获取shape #coding
-
pytorch 加载(.pth)格式的模型实例
有一些非常流行的网络如 resnet.squeezenet.densenet等在pytorch里面都有,包括网络结构和训练好的模型. pytorch自带模型网址:https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-models/ 按官网加载预训练好的模型: import torchvision.models as models # pretrained=True就可以使用预训练的模型 resnet18 = mod
-
基于pytorch的保存和加载模型参数的方法
当我们花费大量的精力训练完网络,下次预测数据时不想再(有时也不必再)训练一次时,这时候torch.save(),torch.load()就要登场了. 保存和加载模型参数有两种方式: 方式一: torch.save(net.state_dict(),path): 功能:保存训练完的网络的各层参数(即weights和bias) 其中:net.state_dict()获取各层参数,path是文件存放路径(通常保存文件格式为.pt或.pth) net2.load_state_dict(torch.loa
-
pytorch 求网络模型参数实例
用pytorch训练一个神经网络时,我们通常会很关心模型的参数总量.下面分别介绍来两种方法求模型参数 一 .求得每一层的模型参数,然后自然的可以计算出总的参数. 1.先初始化一个网络模型model 比如我这里是 model=cliqueNet(里面是些初始化的参数) 2.调用model的Parameters类获取参数列表 一个典型的操作就是将参数列表传入优化器里.如下 optimizer = optim.Adam(model.parameters(), lr=opt.lr) 言归正传,继续回到参
-
Pytorch之保存读取模型实例
pytorch保存数据 pytorch保存数据的格式为.t7文件或者.pth文件,t7文件是沿用torch7中读取模型权重的方式.而pth文件是python中存储文件的常用格式.而在keras中则是使用.h5文件. # 保存模型示例代码 print('===> Saving models...') state = { 'state': model.state_dict(), 'epoch': epoch # 将epoch一并保存 } if not os.path.isdir('checkpoin
-
pytorch获取模型某一层参数名及参数值方式
1.Motivation: I wanna modify the value of some param; I wanna check the value of some param. The needed function: 2.state_dict() #generator type model.modules()#generator type named_parameters()#OrderDict type from torch import nn import torch #creat
-
Java获取代码中方法参数名信息的方法
前言 大家都知道随着java8的使用,在相应的方法签名中增加了新的对象Parameter,用于表示特定的参数信息,通过它的getName可以获取相应的参数名.即像在代码中编写的,如命名为username,那么在前台进行传参时,即不需要再编写如@Parameter("username")类的注解,而直接就能进行按名映射. 如下的代码参考所示: public class T { private interface T2 { void method(String username, Stri
-
pytorch 实现模型不同层设置不同的学习率方式
在目标检测的模型训练中, 我们通常都会有一个特征提取网络backbone, 例如YOLO使用的darknet SSD使用的VGG-16. 为了达到比较好的训练效果, 往往会加载预训练的backbone模型参数, 然后在此基础上训练检测网络, 并对backbone进行微调, 这时候就需要为backbone设置一个较小的lr. class net(torch.nn.Module): def __init__(self): super(net, self).__init__() # backbone
-
Spring Aop 如何获取参数名参数值
前言: 有时候我们在用Spring Aop面向切面编程,需要获取连接点(JoinPoint)方法参数名.参数值. 环境: Mac OSX Intellij IDEA Spring Boot 2x Jdk 1.8x Code: package com.example.aopdemo.aop; import lombok.extern.slf4j.Slf4j; import org.aspectj.lang.ProceedingJoinPoint; import org.aspectj.lang.a
-
Shell脚本通过参数名传递参数的实现代码
平常在写shell脚本都是用$1,$2-这种方式来接收参数,然而这种接收参数的方式不但容易忘记且不易于理解和维护.Linux常用的命令都可指定参数名和参数值,然而我们怎样才能给自己的shell脚本也采用参数名和参数值这样的方式来获取参数值呢?而不是通过$1,$2这种方式进行获取.下面的例子定义了短参数名和长参数名两种获取参数值的方式.其实是根据getopt提供的特性进行整理而来. #!/bin/bash while getopts i:o:p:s:t: OPT; do case ${OPT} i
-
vue获取参数的几种方式总结
目录 路由基础 1.SPA与路由 1.1 SPA介绍 1.2 路由介绍 1.3前端路由原理 2.vue-router基础用法 2.1 下载VueRouter 2.2 一般使用过程 2.3 修改路由模式 2.4 为当前导航添加样式 2.5 路由重定向 2.6 router-link标签中的tag属性 3.前端路由嵌套 4.路由之间的传参 4.1.query传参 4.2.params方式传参 4.3 获取当前路由路径 4.4 实现图书详情 5.编程式导航 5.1 $router.push添加 / 跳
-
JavaScript函数参数使用带参数名的方式赋值传入的方法
本文实例讲述了JavaScript函数参数使用带参数名的方式赋值传入的方法.分享给大家供大家参考.具体分析如下: 这里其实就是在给函数传递参数的时候,可以使用 参数名:参数值的方式传递,这样不会传递错.不过下面的代码是通过字典来实现的,不像python原封就支持这样的方法 function foo({ name:name, project:project}) { Print( project ); Print( name ); } 调用方法 foo({ name:'soubok', projec
-
JavaScript获取function所有参数名的方法
我写了一个 JavaScript函数来解析函数的参数名称, 代码如下: function getArgs(func) { // 先用正则匹配,取得符合参数模式的字符串. // 第一个分组是这个: ([^)]*) 非右括号的任意字符 var args = func.toString().match(/function\s.*?\(([^)]*)\)/)[1]; // 用逗号来分隔参数(arguments string). return args.split(",").map(functi
-
tensorflow 获取模型所有参数总和数量的方法
实例如下所示: from functools import reduce from operator import mul def get_num_params(): num_params = 0 for variable in tf.trainable_variables(): shape = variable.get_shape() num_params += reduce(mul, [dim.value for dim in shape], 1) return num_params 以上这
-
画pytorch模型图,以及参数计算的方法
刚入pytorch的坑,代码还没看太懂.之前用keras用习惯了,第一次使用pytorch还有些不适应,希望广大老司机多多指教. 首先说说,我们如何可视化模型.在keras中就一句话,keras.summary(),或者plot_model(),就可以把模型展现的淋漓尽致. 但是pytorch中好像没有这样一个api让我们直观的看到模型的样子.但是有网友提供了一段代码,可以把模型画出来,对我来说简直就是如有神助啊. 话不多说,上代码吧. import torch from torch.autog
随机推荐
- NodeJS框架Express的模板视图机制分析
- python保存字符串到文件的方法
- 借助RubyGnome2库进行GTK下的Ruby GUI编程的基本方法
- PowerShell中使用正则表达式筛选数组实例
- jQuery 3.0 的变化及使用方法
- php打乱数组二维数组多维数组的简单实例
- PHP使用数组实现矩阵数学运算的方法示例
- VML绘图板②脚本--VMLgraph.js、XMLtool.js
- PHP数学运算与数据处理实例分析
- 解析PHP计算页面执行时间的实现代码
- oracle中left join和right join的区别浅谈
- vue2.0中click点击当前li实现动态切换class
- 控制台报错object is not a function的解决方法
- Java项目安全处理方法
- android使用PullToRefresh框架实现ListView下拉刷新上拉加载更多
- Django自定义manage命令实例代码
- WCF中使用nettcp协议进行通讯的方法
- Python+OpenCv制作证件图片生成器的操作方法
- Java使用I/O流读取文件内容的方法详解
- python中yield的用法详解——最简单,最清晰的解释