从Pytorch模型pth文件中读取参数成numpy矩阵的操作
目的:
把训练好的pth模型参数提取出来,然后用其他方式部署到边缘设备。
Pytorch给了很方便的读取参数接口:
nn.Module.parameters()
直接看demo:
from torchvision.models.alexnet import alexnet model = alexnet(pretrained=True).eval().cuda() parameters = model.parameters() for p in parameters: numpy_para = p.detach().cpu().numpy() print(type(numpy_para)) print(numpy_para.shape)
上面得到的numpy_para就是numpy参数了~
Note:
model.parameters()是以一个生成器的形式迭代返回每一层的参数。所以用for循环读取到各层的参数,循环次数就表示层数。
而每一层的参数都是torch.nn.parameter.Parameter类型,是Tensor的子类,所以直接用tensor转numpy(即p.detach().cpu().numpy())的方法就可以直接转成numpy矩阵。
方便又好用,爆赞~
补充:pytorch训练好的.pth模型转换为.pt
将python训练好的.pth文件转为.pt
import torch import torchvision from unet import UNet model = UNet(3, 2)#自己定义的网络模型 model.load_state_dict(torch.load("best_weights.pth"))#保存的训练模型 model.eval()#切换到eval() example = torch.rand(1, 3, 320, 480)#生成一个随机输入维度的输入 traced_script_module = torch.jit.trace(model, example) traced_script_module.save("model.pt")
以上为个人经验,希望能给大家一个参考,也希望大家多多支持我们。如有错误或未考虑完全的地方,望不吝赐教。
相关推荐
-
pytorch获取模型某一层参数名及参数值方式
1.Motivation: I wanna modify the value of some param; I wanna check the value of some param. The needed function: 2.state_dict() #generator type model.modules()#generator type named_parameters()#OrderDict type from torch import nn import torch #creat
-
pytorch 加载(.pth)格式的模型实例
有一些非常流行的网络如 resnet.squeezenet.densenet等在pytorch里面都有,包括网络结构和训练好的模型. pytorch自带模型网址:https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-models/ 按官网加载预训练好的模型: import torchvision.models as models # pretrained=True就可以使用预训练的模型 resnet18 = mod
-
pytorch 实现打印模型的参数值
对于简单的网络 例如全连接层Linear 可以使用以下方法打印linear层: fc = nn.Linear(3, 5) params = list(fc.named_parameters()) print(params.__len__()) print(params[0]) print(params[1]) 输出如下: 由于Linear默认是偏置bias的,所有参数列表的长度是2.第一个存的是全连接矩阵,第二个存的是偏置. 对于稍微复杂的网络 例如MLP mlp = nn.Sequential
-
从Pytorch模型pth文件中读取参数成numpy矩阵的操作
目的: 把训练好的pth模型参数提取出来,然后用其他方式部署到边缘设备. Pytorch给了很方便的读取参数接口: nn.Module.parameters() 直接看demo: from torchvision.models.alexnet import alexnet model = alexnet(pretrained=True).eval().cuda() parameters = model.parameters() for p in parameters: numpy_para =
-
tensorflow实现从.ckpt文件中读取任意变量
思路有些混乱,希望大家能理解我的意思. 看了faster rcnn的tensorflow代码,关于fix_variables的作用我不是很明白,所以写了以下代码,读取了预训练模型vgg16得fc6和fc7的参数,以及faster rcnn中heat_to_tail中的fc6和fc7,将它们做了对比,发现结果不一样,说明vgg16的fc6和fc7只是初始化了faster rcnn中heat_to_tail中的fc6和fc7,之后后者被训练. 具体读取任意变量的代码如下: import tensor
-
PyTorch 模型 onnx 文件导出及调用详情
目录 前言 基本用法 高级 API 前言 Open Neural Network Exchange (ONNX,开放神经网络交换) 格式,是一个用于表示深度学习模型的标准,可使模型在不同框架之间进行转移 PyTorch 所定义的模型为动态图,其前向传播是由类方法定义和实现的 但是 Python 代码的效率是比较底下的,试想把动态图转化为静态图,模型的推理速度应当有所提升 PyTorch 框架中,torch.onnx.export 可以将父类为 nn.Module 的模型导出到 onnx 文件中,
-
Java将对象保存到文件中/从文件中读取对象的方法
1.保存对象到文件中 Java语言只能将实现了Serializable接口的类的对象保存到文件中,利用如下方法即可: public static void writeObjectToFile(Object obj) { File file =new File("test.dat"); FileOutputStream out; try { out = new FileOutputStream(file); ObjectOutputStream objOut=new ObjectOutp
-
从Java的jar文件中读取数据的方法
本文实例讲述了从Java的jar文件中读取数据的方法.分享给大家供大家参考.具体如下: Java 档案 (Java Archive, JAR) 文件是基于 Java 技术的打包方案.它们允许开发人员把所有相关的内容 (.class.图片.声音和支持文件等) 打包到一个单一的文件中.JAR 文件格式支持压缩.身份验证和版本,以及许多其它特性. 从 JAR 文件中得到它所包含的文件内容是件棘手的事情,但也不是不可以做到.这篇技巧就将告诉你如何从 JAR 文件中取得一个文件.我们会先取得这个 JAR
-
Python从文件中读取数据的方法步骤
一.读取整个文件内容 在读取文件之前,我们先创建一个文本文件resource.txt作为源文件. resource.txt my name is joker, I am 18 years old, How about you? 如何读取文件全部内容,我们编写到reader.py文件中. reader.py with open('resource.txt') as file_obj: content = file_obj.read() print(content) 需要注意的是需要将resourc
-
Python从csv文件中读取数据及提取数据的方法
目录 1.从csv文件中读取数据 2.数据切割 数据保存在csv文件中 1.从csv文件中读取数据 参数header=None的有无 (1)没有header=None--直接将csv表中的第一行当作表头 # 读取数据 import pandas as pd data = pd.read_csv("data1.csv") print(data) 打印结果为: (2)有header=None--自动添加第一行当作表头 # 读取数据 import pandas as pd data = pd
-
PHP如何从txt文件中读取数据详解
目录 一.打开/关闭文件 二.读写文件 1.读取整个文件 2.读取一行数据 3.读取一个字符 4.读取任意长度的字符串 总结 一.打开/关闭文件 1.对文件操作时首先要打开文件,打开文件用 fopen()函数,语法是: fopen(filename,mode,include_path,context); 2.对文件操作结束后应该关闭这个文件,使用函数 fclose(); 例如: 二.读写文件 1.读取整个文件 有三个函数可以使用,分别是:readfile()函数.file()函数.file_ge
-
Perl从文件中读取字符串的两种实现方法
1. 一次性将文件中的所有内容读入一个数组中(该方法适合小文件): 复制代码 代码如下: open(FILE,"filename")||die"can not open the file: $!";@filelist=<FILE>; foreach $eachline (@filelist) { chomp $eachline;}close FILE;@filelist=<FILE>; 当文件很大时,可能会出现"out
-
Python3实现从文件中读取指定行的方法
本文实例讲述了Python3实现从文件中读取指定行的方法.分享给大家供大家参考.具体实现方法如下: # Python的标准库linecache模块非常适合这个任务 import linecache the_line = linecache.getline('d:/FreakOut.cpp', 222) print (the_line) # linecache读取并缓存文件中所有的文本, # 若文件很大,而只读一行,则效率低下. # 可显示使用循环, 注意enumerate从0开始计数,而line
随机推荐
- 简单学习vue指令directive
- 硬盘启动提示verifying DMI Pool Data错误的解决方法
- JavaScript类型系统之Object详解
- Wordpress php 分页代码
- Yii框架批量插入数据扩展类的简单实现方法
- go语言制作的zip压缩程序
- 用ASP+FSO生成JS文件
- Android应用自动跳转到应用市场详情页面的方法
- android bitmap compress(图片压缩)代码
- php输出指定时间以前时间格式的方法
- PHP检测字符串是否为UTF8编码的常用方法
- jQuery select自动选中功能实现方法分析
- SqlServer 实用操作小技巧集合第1/2页
- js中的scroll和offset 使用比较的实例与分析
- 完美实现八种js焦点轮播图(上篇)
- nodejs调用cmd命令实现复制目录
- 世界顶级防火墙Look n Stop中文版
- C# SendKeys使用方法介绍
- Android ListView列表控件的介绍和性能优化
- 比较discuz和ecshop的截取字符串函数php版