pytorch加载自定义网络权重的实现

在将自定义的网络权重加载到网络中时,报错:

AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.

我们一步一步分析。

模型网络权重保存额代码是:torch.save(net.state_dict(),'net.pkl')

(1)查看获取模型权重的源码:

pytorch源码:net.state_dict()

def state_dict(self, destination=None, prefix='', keep_vars=False):
  r"""Returns a dictionary containing a whole state of the module.

  Both parameters and persistent buffers (e.g. running averages) are
  included. Keys are corresponding parameter and buffer names.

  Returns:
    dict:
      a dictionary containing a whole state of the module

  Example::

    >>> module.state_dict().keys()
    ['bias', 'weight']

  """

将网络中所有的状态保存到一个字典中了,我自己构建的就是一个字典,没问题!

(2)查看保存模型权重的源码:

pytorch源码:torch.save()

def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
  """Saves an object to a disk file.

  See also: :ref:`recommend-saving-models`

  Args:
    obj: saved object
    f: a file-like object (has to implement write and flush) or a string
      containing a file name
    pickle_module: module used for pickling metadata and objects
    pickle_protocol: can be specified to override the default protocol

  .. warning::
    If you are using Python 2, torch.save does NOT support StringIO.StringIO
    as a valid file-like object. This is because the write method should return
    the number of bytes written; StringIO.write() does not do this.

    Please use something like io.BytesIO instead.

函数功能是将字典保存为磁盘文件(二进制数据),那么我们在torch.load()时,就是在内存中加载二进制数据,这就是报错点。

解决方案:将字典保存为BytesIO文件之后,模型再net.load_state_dict()

#b为自定义的字典
torch.save(b,'new.pkl')
net.load_state_dict(torch.load(b))

解决方法很简单,主要记录解决思路。

以上这篇pytorch加载自定义网络权重的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • pytorch自定义初始化权重的方法

    在常见的pytorch代码中,我们见到的初始化方式都是调用init类对每层所有参数进行初始化.但是,有时我们有些特殊需求,比如用某一层的权重取优化其它层,或者手动指定某些权重的初始值. 核心思想就是构造和该层权重同一尺寸的矩阵去对该层权重赋值.但是,值得注意的是,pytorch中各层权重的数据类型是nn.Parameter,而不是Tensor或者Variable. import torch import torch.nn as nn import torch.optim as optim imp

  • pytorch动态网络以及权重共享实例

    pytorch 动态网络+权值共享 pytorch以动态图著称,下面以一个栗子来实现动态网络和权值共享技术: # -*- coding: utf-8 -*- import random import torch class DynamicNet(torch.nn.Module): def __init__(self, D_in, H, D_out): """ 这里构造了几个向前传播过程中用到的线性函数 """ super(DynamicNet,

  • Pytorch 实现权重初始化

    在TensorFlow中,权重的初始化主要是在声明张量的时候进行的. 而PyTorch则提供了另一种方法:首先应该声明张量,然后修改张量的权重.通过调用torch.nn.init包中的多种方法可以将权重初始化为直接访问张量的属性. 1.不初始化的效果 在Pytorch中,定义一个tensor,不进行初始化,打印看看结果: w = torch.Tensor(3,4) print (w) 可以看到这时候的初始化的数值都是随机的,而且特别大,这对网络的训练必定不好,最后导致精度提不上,甚至损失无法收敛

  • Pytorch: 自定义网络层实例

    自定义Autograd函数 对于浅层的网络,我们可以手动的书写前向传播和反向传播过程.但是当网络变得很大时,特别是在做深度学习时,网络结构变得复杂.前向传播和反向传播也随之变得复杂,手动书写这两个过程就会存在很大的困难.幸运地是在pytorch中存在了自动微分的包,可以用来解决该问题.在使用自动求导的时候,网络的前向传播会定义一个计算图(computational graph),图中的节点是张量(tensor),两个节点之间的边对应了两个张量之间变换关系的函数.有了计算图的存在,张量的梯度计算也

  • pytorch 自定义数据集加载方法

    pytorch 官网给出的例子中都是使用了已经定义好的特殊数据集接口来加载数据,而且其使用的数据都是官方给出的数据.如果我们有自己收集的数据集,如何用来训练网络呢?此时需要我们自己定义好数据处理接口.幸运的是pytroch给出了一个数据集接口类(torch.utils.data.Dataset),可以方便我们继承并实现自己的数据集接口. torch.utils.data torch的这个文件包含了一些关于数据集处理的类. class torch.utils.data.Dataset: 一个抽象类

  • pytorch加载自定义网络权重的实现

    在将自定义的网络权重加载到网络中时,报错: AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead. 我们一步一步分析. 模型网络权重保存额代码是:torch.sa

  • pytorch加载语音类自定义数据集的方法教程

    前言 pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合 torch.utils.data.Dataset:所有继承他的子类都应该重写  __len()__  , __getitem()__ 这两个方法 __len()__ :返回数据集中数据的数量 __getitem()__ :返回支持下标索引方式获取的一个数据 torch.utils.data.DataLoad

  • Pytorch加载部分预训练模型的参数实例

    前言 自从从深度学习框架caffe转到Pytorch之后,感觉Pytorch的优点妙不可言,各种设计简洁,方便研究网络结构修改,容易上手,比TensorFlow的臃肿好多了.对于深度学习的初学者,Pytorch值得推荐.今天主要主要谈谈Pytorch是如何加载预训练模型的参数以及代码的实现过程. 直接加载预选脸模型 如果我们使用的模型和预训练模型完全一样,那么我们就可以直接加载别人的模型,还有一种情况,我们在训练自己模型的过程中,突然中断了,但只要我们保存了之前的模型的参数也可以使用下面的代码直

  • pytorch 加载(.pth)格式的模型实例

    有一些非常流行的网络如 resnet.squeezenet.densenet等在pytorch里面都有,包括网络结构和训练好的模型. pytorch自带模型网址:https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-models/ 按官网加载预训练好的模型: import torchvision.models as models # pretrained=True就可以使用预训练的模型 resnet18 = mod

  • 解决Pytorch 加载训练好的模型 遇到的error问题

    这是一个非常愚蠢的错误 debug的时候要好好看error信息 提醒自己切记好好对待error!切记!切记! -----------------------分割线---------------- pytorch 已经非常友好了 保存模型和加载模型都只需要一条简单的命令 #保存整个网络和参数 torch.save(your_net, 'save_name.pkl') #加载保存的模型 net = torch.load('save_name.pkl') 因为我比较懒我就想直接把整个网络都保存下来,然

  • PyTorch加载自己的数据集实例详解

    数据预处理在解决深度学习问题的过程中,往往需要花费大量的时间和精力. 数据处理的质量对训练神经网络来说十分重要,良好的数据处理不仅会加速模型训练, 更会提高模型性能.为解决这一问题,PyTorch提供了几个高效便捷的工具, 以便使用者进行数据处理或增强等操作,同时可通过并行化加速数据加载. 数据集存放大致有以下两种方式: (1)所有数据集放在一个目录下,文件名上附有标签名,数据集存放格式如下: root/cat_dog/cat.01.jpg root/cat_dog/cat.02.jpg ...

  • pytorch加载自己的图像数据集实例

    之前学习深度学习算法,都是使用网上现成的数据集,而且都有相应的代码.到了自己开始写论文做实验,用到自己的图像数据集的时候,才发现无从下手 ,相信很多新手都会遇到这样的问题. 参考文章https://www.jb51.net/article/177613.htm 下面代码实现了从文件夹内读取所有图片,进行归一化和标准化操作并将图片转化为tensor.最后读取第一张图片并显示. # 数据处理 import os import torch from torch.utils import data fr

  • pytorch加载自己的图片数据集的2种方法详解

    目录 ImageFolder 加载数据集 使用pytorch提供的Dataset类创建自己的数据集. Dataset加载数据集 总结 pytorch加载图片数据集有两种方法. 1.ImageFolder 适合于分类数据集,并且每一个类别的图片在同一个文件夹, ImageFolder加载的数据集, 训练数据为文件件下的图片, 训练标签是对应的文件夹, 每个文件夹为一个类别 导入ImageFolder()包 from torchvision.datasets import ImageFolder 在

  • pytorch加载自己的数据集源码分享

    目录 一.标准的数据集流程梳理 数据来源 二.实现加载自己的数据集 1. 保存在txt文件中(生成训练集和测试集,其实这里的训练集以及测试集也都是用文本文件的形式保存下来的) 2. 在继承dataset类LoadData的三个函数里调用train.txt以及test.txt实现相关功能 三.源码 一.标准的数据集流程梳理 分为几个步骤数据准备以及加载数据库–>数据加载器的调用或者设计–>批量调用进行训练或者其他作用 数据来源 直接读取了x和y的数据变量,对比后面的就从把对应的路径写进了文本文件

  • Pytorch加载数据集的方式总结及补充

    目录 前言 一.自己重写定义(Dataset.DataLoader) 二.用Pytorch自带的类(ImageFolder.datasets.DataLoader) 2.1 加载自己的数据集 2.1.1 ImageFolder介绍 2.2.2 ImageFolder加载数据集完整例子 2.2 加载常见的数据集 三.总结 四.transforms变换讲解 五.DataLoader的补充 总结 前言 在用Pytorch加载数据集时,看GitHub上的代码经常会用到ImageFolder.DataLo

随机推荐