从Pytorch模型pth文件中读取参数成numpy矩阵的操作
目的:
把训练好的pth模型参数提取出来,然后用其他方式部署到边缘设备。
Pytorch给了很方便的读取参数接口:
nn.Module.parameters()
直接看demo:
from torchvision.models.alexnet import alexnet model = alexnet(pretrained=True).eval().cuda() parameters = model.parameters() for p in parameters: numpy_para = p.detach().cpu().numpy() print(type(numpy_para)) print(numpy_para.shape)
上面得到的numpy_para就是numpy参数了~
Note:
model.parameters()是以一个生成器的形式迭代返回每一层的参数。所以用for循环读取到各层的参数,循环次数就表示层数。
而每一层的参数都是torch.nn.parameter.Parameter类型,是Tensor的子类,所以直接用tensor转numpy(即p.detach().cpu().numpy())的方法就可以直接转成numpy矩阵。
方便又好用,爆赞~
补充:pytorch训练好的.pth模型转换为.pt
将python训练好的.pth文件转为.pt
import torch import torchvision from unet import UNet model = UNet(3, 2)#自己定义的网络模型 model.load_state_dict(torch.load("best_weights.pth"))#保存的训练模型 model.eval()#切换到eval() example = torch.rand(1, 3, 320, 480)#生成一个随机输入维度的输入 traced_script_module = torch.jit.trace(model, example) traced_script_module.save("model.pt")
以上为个人经验,希望能给大家一个参考,也希望大家多多支持我们。如有错误或未考虑完全的地方,望不吝赐教。
相关推荐
-
pytorch获取模型某一层参数名及参数值方式
1.Motivation: I wanna modify the value of some param; I wanna check the value of some param. The needed function: 2.state_dict() #generator type model.modules()#generator type named_parameters()#OrderDict type from torch import nn import torch #creat
-
pytorch 实现打印模型的参数值
对于简单的网络 例如全连接层Linear 可以使用以下方法打印linear层: fc = nn.Linear(3, 5) params = list(fc.named_parameters()) print(params.__len__()) print(params[0]) print(params[1]) 输出如下: 由于Linear默认是偏置bias的,所有参数列表的长度是2.第一个存的是全连接矩阵,第二个存的是偏置. 对于稍微复杂的网络 例如MLP mlp = nn.Sequential
-
pytorch 加载(.pth)格式的模型实例
有一些非常流行的网络如 resnet.squeezenet.densenet等在pytorch里面都有,包括网络结构和训练好的模型. pytorch自带模型网址:https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-models/ 按官网加载预训练好的模型: import torchvision.models as models # pretrained=True就可以使用预训练的模型 resnet18 = mod
-
从Pytorch模型pth文件中读取参数成numpy矩阵的操作
目的: 把训练好的pth模型参数提取出来,然后用其他方式部署到边缘设备. Pytorch给了很方便的读取参数接口: nn.Module.parameters() 直接看demo: from torchvision.models.alexnet import alexnet model = alexnet(pretrained=True).eval().cuda() parameters = model.parameters() for p in parameters: numpy_para =
-
tensorflow实现从.ckpt文件中读取任意变量
思路有些混乱,希望大家能理解我的意思. 看了faster rcnn的tensorflow代码,关于fix_variables的作用我不是很明白,所以写了以下代码,读取了预训练模型vgg16得fc6和fc7的参数,以及faster rcnn中heat_to_tail中的fc6和fc7,将它们做了对比,发现结果不一样,说明vgg16的fc6和fc7只是初始化了faster rcnn中heat_to_tail中的fc6和fc7,之后后者被训练. 具体读取任意变量的代码如下: import tensor
-
PyTorch 模型 onnx 文件导出及调用详情
目录 前言 基本用法 高级 API 前言 Open Neural Network Exchange (ONNX,开放神经网络交换) 格式,是一个用于表示深度学习模型的标准,可使模型在不同框架之间进行转移 PyTorch 所定义的模型为动态图,其前向传播是由类方法定义和实现的 但是 Python 代码的效率是比较底下的,试想把动态图转化为静态图,模型的推理速度应当有所提升 PyTorch 框架中,torch.onnx.export 可以将父类为 nn.Module 的模型导出到 onnx 文件中,
-
Java将对象保存到文件中/从文件中读取对象的方法
1.保存对象到文件中 Java语言只能将实现了Serializable接口的类的对象保存到文件中,利用如下方法即可: public static void writeObjectToFile(Object obj) { File file =new File("test.dat"); FileOutputStream out; try { out = new FileOutputStream(file); ObjectOutputStream objOut=new ObjectOutp
-
从Java的jar文件中读取数据的方法
本文实例讲述了从Java的jar文件中读取数据的方法.分享给大家供大家参考.具体如下: Java 档案 (Java Archive, JAR) 文件是基于 Java 技术的打包方案.它们允许开发人员把所有相关的内容 (.class.图片.声音和支持文件等) 打包到一个单一的文件中.JAR 文件格式支持压缩.身份验证和版本,以及许多其它特性. 从 JAR 文件中得到它所包含的文件内容是件棘手的事情,但也不是不可以做到.这篇技巧就将告诉你如何从 JAR 文件中取得一个文件.我们会先取得这个 JAR
-
Python从文件中读取数据的方法步骤
一.读取整个文件内容 在读取文件之前,我们先创建一个文本文件resource.txt作为源文件. resource.txt my name is joker, I am 18 years old, How about you? 如何读取文件全部内容,我们编写到reader.py文件中. reader.py with open('resource.txt') as file_obj: content = file_obj.read() print(content) 需要注意的是需要将resourc
-
Python从csv文件中读取数据及提取数据的方法
目录 1.从csv文件中读取数据 2.数据切割 数据保存在csv文件中 1.从csv文件中读取数据 参数header=None的有无 (1)没有header=None--直接将csv表中的第一行当作表头 # 读取数据 import pandas as pd data = pd.read_csv("data1.csv") print(data) 打印结果为: (2)有header=None--自动添加第一行当作表头 # 读取数据 import pandas as pd data = pd
-
PHP如何从txt文件中读取数据详解
目录 一.打开/关闭文件 二.读写文件 1.读取整个文件 2.读取一行数据 3.读取一个字符 4.读取任意长度的字符串 总结 一.打开/关闭文件 1.对文件操作时首先要打开文件,打开文件用 fopen()函数,语法是: fopen(filename,mode,include_path,context); 2.对文件操作结束后应该关闭这个文件,使用函数 fclose(); 例如: 二.读写文件 1.读取整个文件 有三个函数可以使用,分别是:readfile()函数.file()函数.file_ge
-
Perl从文件中读取字符串的两种实现方法
1. 一次性将文件中的所有内容读入一个数组中(该方法适合小文件): 复制代码 代码如下: open(FILE,"filename")||die"can not open the file: $!";@filelist=<FILE>; foreach $eachline (@filelist) { chomp $eachline;}close FILE;@filelist=<FILE>; 当文件很大时,可能会出现"out
-
Python3实现从文件中读取指定行的方法
本文实例讲述了Python3实现从文件中读取指定行的方法.分享给大家供大家参考.具体实现方法如下: # Python的标准库linecache模块非常适合这个任务 import linecache the_line = linecache.getline('d:/FreakOut.cpp', 222) print (the_line) # linecache读取并缓存文件中所有的文本, # 若文件很大,而只读一行,则效率低下. # 可显示使用循环, 注意enumerate从0开始计数,而line
随机推荐
- React入门教程之Hello World以及环境搭建详解
- 解决Ubuntu 16.04下提示boot分区空间不足的办法
- java解析xml常用的几种方式总结
- python实现简单的socket server实例
- JavaScript笔记之数据属性和存储器属性
- Bootstrap每天必学之模态框(Modal)插件
- asp.net 在DNN模块开发中遇到的resx怪问题
- php array_pop()数组函数将数组最后一个单元弹出(出栈)
- 深入解析C#编程中struct所定义的结构
- php中的观察者模式
- python fabric使用笔记
- jquery制作多功能轮播图插件
- 计算字符串和文件MD5值的小例子
- 跟老齐学Python之集成开发环境(IDE)
- CSS中常用的单位
- 更新路由表
- Python爬豆瓣电影实例
- webpack 代码分离优化快速指北
- Python UnboundLocalError和NameError错误根源案例解析
- 微信小程序项目总结之记账小程序功能的实现(包括后端)