从Pytorch模型pth文件中读取参数成numpy矩阵的操作

目的:

把训练好的pth模型参数提取出来,然后用其他方式部署到边缘设备。

Pytorch给了很方便的读取参数接口:

nn.Module.parameters()

直接看demo:

from torchvision.models.alexnet import alexnet
model = alexnet(pretrained=True).eval().cuda()
parameters = model.parameters()
for p in parameters:
  numpy_para = p.detach().cpu().numpy()
  print(type(numpy_para))
  print(numpy_para.shape)

上面得到的numpy_para就是numpy参数了~

Note:

model.parameters()是以一个生成器的形式迭代返回每一层的参数。所以用for循环读取到各层的参数,循环次数就表示层数。

而每一层的参数都是torch.nn.parameter.Parameter类型,是Tensor的子类,所以直接用tensor转numpy(即p.detach().cpu().numpy())的方法就可以直接转成numpy矩阵。

方便又好用,爆赞~

补充:pytorch训练好的.pth模型转换为.pt

将python训练好的.pth文件转为.pt

import torch
import torchvision
from unet import UNet
model = UNet(3, 2)#自己定义的网络模型
model.load_state_dict(torch.load("best_weights.pth"))#保存的训练模型
model.eval()#切换到eval()
example = torch.rand(1, 3, 320, 480)#生成一个随机输入维度的输入
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model.pt")

以上为个人经验,希望能给大家一个参考,也希望大家多多支持我们。如有错误或未考虑完全的地方,望不吝赐教。

(0)

相关推荐

  • pytorch获取模型某一层参数名及参数值方式

    1.Motivation: I wanna modify the value of some param; I wanna check the value of some param. The needed function: 2.state_dict() #generator type model.modules()#generator type named_parameters()#OrderDict type from torch import nn import torch #creat

  • pytorch 实现打印模型的参数值

    对于简单的网络 例如全连接层Linear 可以使用以下方法打印linear层: fc = nn.Linear(3, 5) params = list(fc.named_parameters()) print(params.__len__()) print(params[0]) print(params[1]) 输出如下: 由于Linear默认是偏置bias的,所有参数列表的长度是2.第一个存的是全连接矩阵,第二个存的是偏置. 对于稍微复杂的网络 例如MLP mlp = nn.Sequential

  • pytorch 加载(.pth)格式的模型实例

    有一些非常流行的网络如 resnet.squeezenet.densenet等在pytorch里面都有,包括网络结构和训练好的模型. pytorch自带模型网址:https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-models/ 按官网加载预训练好的模型: import torchvision.models as models # pretrained=True就可以使用预训练的模型 resnet18 = mod

  • 从Pytorch模型pth文件中读取参数成numpy矩阵的操作

    目的: 把训练好的pth模型参数提取出来,然后用其他方式部署到边缘设备. Pytorch给了很方便的读取参数接口: nn.Module.parameters() 直接看demo: from torchvision.models.alexnet import alexnet model = alexnet(pretrained=True).eval().cuda() parameters = model.parameters() for p in parameters: numpy_para =

  • tensorflow实现从.ckpt文件中读取任意变量

    思路有些混乱,希望大家能理解我的意思. 看了faster rcnn的tensorflow代码,关于fix_variables的作用我不是很明白,所以写了以下代码,读取了预训练模型vgg16得fc6和fc7的参数,以及faster rcnn中heat_to_tail中的fc6和fc7,将它们做了对比,发现结果不一样,说明vgg16的fc6和fc7只是初始化了faster rcnn中heat_to_tail中的fc6和fc7,之后后者被训练. 具体读取任意变量的代码如下: import tensor

  • PyTorch 模型 onnx 文件导出及调用详情

    目录 前言 基本用法 高级 API 前言 Open Neural Network Exchange (ONNX,开放神经网络交换) 格式,是一个用于表示深度学习模型的标准,可使模型在不同框架之间进行转移 PyTorch 所定义的模型为动态图,其前向传播是由类方法定义和实现的 但是 Python 代码的效率是比较底下的,试想把动态图转化为静态图,模型的推理速度应当有所提升 PyTorch 框架中,torch.onnx.export 可以将父类为 nn.Module 的模型导出到 onnx 文件中,

  • Java将对象保存到文件中/从文件中读取对象的方法

    1.保存对象到文件中 Java语言只能将实现了Serializable接口的类的对象保存到文件中,利用如下方法即可: public static void writeObjectToFile(Object obj) { File file =new File("test.dat"); FileOutputStream out; try { out = new FileOutputStream(file); ObjectOutputStream objOut=new ObjectOutp

  • 从Java的jar文件中读取数据的方法

    本文实例讲述了从Java的jar文件中读取数据的方法.分享给大家供大家参考.具体如下: Java 档案 (Java Archive, JAR) 文件是基于 Java 技术的打包方案.它们允许开发人员把所有相关的内容 (.class.图片.声音和支持文件等) 打包到一个单一的文件中.JAR 文件格式支持压缩.身份验证和版本,以及许多其它特性. 从 JAR 文件中得到它所包含的文件内容是件棘手的事情,但也不是不可以做到.这篇技巧就将告诉你如何从 JAR 文件中取得一个文件.我们会先取得这个 JAR

  • Python从文件中读取数据的方法步骤

    一.读取整个文件内容 在读取文件之前,我们先创建一个文本文件resource.txt作为源文件. resource.txt my name is joker, I am 18 years old, How about you? 如何读取文件全部内容,我们编写到reader.py文件中. reader.py with open('resource.txt') as file_obj: content = file_obj.read() print(content) 需要注意的是需要将resourc

  • Python从csv文件中读取数据及提取数据的方法

    目录 1.从csv文件中读取数据 2.数据切割 数据保存在csv文件中 1.从csv文件中读取数据 参数header=None的有无 (1)没有header=None--直接将csv表中的第一行当作表头 # 读取数据 import pandas as pd data = pd.read_csv("data1.csv") print(data) 打印结果为: (2)有header=None--自动添加第一行当作表头 # 读取数据 import pandas as pd data = pd

  • PHP如何从txt文件中读取数据详解

    目录 一.打开/关闭文件 二.读写文件 1.读取整个文件 2.读取一行数据 3.读取一个字符 4.读取任意长度的字符串 总结 一.打开/关闭文件 1.对文件操作时首先要打开文件,打开文件用 fopen()函数,语法是: fopen(filename,mode,include_path,context); 2.对文件操作结束后应该关闭这个文件,使用函数 fclose(); 例如: 二.读写文件 1.读取整个文件 有三个函数可以使用,分别是:readfile()函数.file()函数.file_ge

  • Perl从文件中读取字符串的两种实现方法

    1. 一次性将文件中的所有内容读入一个数组中(该方法适合小文件): 复制代码 代码如下: open(FILE,"filename")||die"can not open the file: $!";@filelist=<FILE>; foreach $eachline (@filelist) {        chomp $eachline;}close FILE;@filelist=<FILE>; 当文件很大时,可能会出现"out

  • Python3实现从文件中读取指定行的方法

    本文实例讲述了Python3实现从文件中读取指定行的方法.分享给大家供大家参考.具体实现方法如下: # Python的标准库linecache模块非常适合这个任务 import linecache the_line = linecache.getline('d:/FreakOut.cpp', 222) print (the_line) # linecache读取并缓存文件中所有的文本, # 若文件很大,而只读一行,则效率低下. # 可显示使用循环, 注意enumerate从0开始计数,而line

随机推荐