PyTorch 多GPU下模型的保存与加载(踩坑笔记)

这几天在一机多卡的环境下,用pytorch训练模型,遇到很多问题。现总结一个实用的做实验方式:

多GPU下训练,创建模型代码通常如下:

os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda
model = MyModel(args)
if torch.cuda.is_available() and args.use_gpu:
  model = torch.nn.DataParallel(model).cuda()

官方建议的模型保存方式,只保存参数:

torch.save(model.module.state_dict(), "model.pkl")

其实,这样很麻烦,我建议直接保存模型(参数+图):

torch.save(model, "model.pkl")

这样做很实用,特别是我们需要反复建模和调试的时候。这种情况下模型的加载很方便,因为模型的图已经和参数保存在一起,我们不需要根据不同的模型设置相应的超参,更换对应的网络结构,如下:

 if not (args.pretrained_model_path is None):
    print('load model from %s ...' % args.pretrained_model_path)
    model = torch.load(args.pretrained_model_path)
    print('success!')

但是需要注意,这种方式加载的是多GPU下模型。如果服务器环境变化不大,或者和训练时候是同一个GPU环境,就不会出现问题。

如果系统环境发生了变化,或者,我们只想加载模型参数,亦或是遇到下面的问题:

AttributeError: 'model' object has no attribute 'copy'

或者

AttributeError: 'DataParallel' object has no attribute 'copy'

或者

RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found

这时候我们可以用下面的方式载入模型,先建立模型,然后加载参数。

os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda
# 建立模型
model = MyModel(args)

if torch.cuda.is_available() and args.use_gpu:
  model = torch.nn.DataParallel(model).cuda()

if not (args.pretrained_model_path is None):
  print('load model from %s ...' % args.pretrained_model_path)
  # 获得模型参数
  model_dict = torch.load(args.pretrained_model_path).module.state_dict()
  # 载入参数
  model.module.load_state_dict(model_dict)
  print('success!')

到此这篇关于PyTorch 多GPU下模型的保存与加载(踩坑笔记)的文章就介绍到这了,更多相关PyTorch 多GPU下模型的保存与加载内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • 将Pytorch模型从CPU转换成GPU的实现方法

    最近将Pytorch程序迁移到GPU上去的一些工作和思考 环境:Ubuntu 16.04.3 Python版本:3.5.2 Pytorch版本:0.4.0 0. 序言 大家知道,在深度学习中使用GPU来对模型进行训练是可以通过并行化其计算来提高运行效率,这里就不多谈了. 最近申请到了实验室的服务器来跑程序,成功将我简陋的程序改成了"高大上"GPU版本. 看到网上总体来说少了很多介绍,这里决定将我的一些思考和工作记录下来. 1. 如何进行迁移 由于我使用的是Pytorch写的模型,网上给

  • 解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题

    背景 在公司用多卡训练模型,得到权值文件后保存,然后回到实验室,没有多卡的环境,用单卡训练,加载模型时出错,因为单卡机器上,没有使用DataParallel来加载模型,所以会出现加载错误. 原因 DataParallel包装的模型在保存时,权值参数前面会带有module字符,然而自己在单卡环境下,没有用DataParallel包装的模型权值参数不带module.本质上保存的权值文件是一个有序字典. 解决方法 1.在单卡环境下,用DataParallel包装模型. 2.自己重写Load函数,灵活.

  • PyTorch 多GPU下模型的保存与加载(踩坑笔记)

    这几天在一机多卡的环境下,用pytorch训练模型,遇到很多问题.现总结一个实用的做实验方式: 多GPU下训练,创建模型代码通常如下: os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda model = MyModel(args) if torch.cuda.is_available() and args.use_gpu: model = torch.nn.DataParallel(model).cuda() 官方建议的模型保存方式,只保存参数: tor

  • PyTorch模型的保存与加载方法实例

    目录 模型的保存与加载 保存和加载模型参数 保存和加载模型参数与结构 总结 模型的保存与加载 首先,需要导入两个包 import torch import torchvision.models as models 保存和加载模型参数 PyTorch模型将学习到的参数存储在一个内部状态字典中,叫做state_dict.这可以通过torch.save方法来实现.我们导入预训练好的VGG16模型,并将其保存.我们将state_dict字典保存在model_weights.pth文件中. model =

  • pytorch模型保存与加载中的一些问题实战记录

    目录 前言 一.torch中模型保存和加载的方式 1.模型参数和模型结构保存和加载 2.只保存模型的参数和加载——这种方式比较安全,但是比较稍微麻烦一点点 二.torch中模型保存和加载出现的问题 1.单卡模型下保存模型结构和参数后加载出现的问题 2.多卡机器单卡训练模型保存后在单卡机器上加载会报错 3.多卡训练模型保存模型结构和参数后加载出现的问题 三.正确的保存模型和加载的方法 总结 前言 最近使用pytorch训练模型,保存模型后再次加载使用出现了一些问题.记录一下解决方案! 一.torc

  • PyTorch模型保存与加载实例详解

    目录 一个简单的例子 保存/加载 state_dict(推荐) 保存/加载整个模型 保存加载用于推理的常规Checkpoint/或继续训练 保存多个模型到一个文件 使用其他模型来预热当前模型 跨设备保存与加载模型 总结 torch.save:保存序列化的对象到磁盘,使用了Python的pickle进行序列化,模型.张量.所有对象的字典. torch.load:使用了pickle的unpacking将pickled的对象反序列化到内存中. torch.nn.Module.load_state_di

  • Tensorflow之MNIST CNN实现并保存、加载模型

    本文实例为大家分享了Tensorflow之MNIST CNN实现并保存.加载模型的具体代码,供大家参考,具体内容如下 废话不说,直接上代码 # TensorFlow and tf.keras import tensorflow as tf from tensorflow import keras # Helper libraries import numpy as np import matplotlib.pyplot as plt import os #download the data mn

  • tensorflow模型保存、加载之变量重命名实例

    话不多说,干就完了. 变量重命名的用处? 简单定义:简单来说就是将模型A中的参数parameter_A赋给模型B中的parameter_B 使用场景:当需要使用已经训练好的模型参数,尤其是使用别人训练好的模型参数时,往往别人模型中的参数命名方式与自己当前的命名方式不同,所以在加载模型参数时需要对参数进行重命名,使得代码更简洁易懂. 实现方法: 1).模型保存 import os import tensorflow as tf weights = tf.Variable(initial_value

  • Android RecyclerView实现下拉刷新和上拉加载

    RecyclerView已经出来很久了,许许多多的项目都开始从ListView转战RecyclerView,那么,上拉加载和下拉刷新是一件很有必要的事情. 在ListView上,我们可以通过自己添加addHeadView和addFootView去添加头布局和底部局实现自定义的上拉和下拉,或者使用一些第三方库来简单的集成,例如Android-pulltorefresh或者android-Ultra-Pull-to-Refresh,后者的自定义更强,但需要自己实现上拉加载. 而在下面我们将用两种方式

  • RecyclerView下拉刷新上拉加载

    一 .前言 最近实在太忙,一个多礼拜没有更新文章了,于是今晚加班加点把demo写出来,现在都12点了才开始写文章. 1.我们的目标 把RecyclerView下拉刷新上拉加载更多加入到我们的开发者头条APP中. 2.效果图 3.实现步骤 找一个带上拉刷新下载加载更多的RecyclerView开源库,我们要站在巨人的肩膀上 下载下来自己先运行下demo,然后看看是不是我们需要的功能,觉得不错就把module依赖进来,整合主项目. 整合进来了之后,我们肯定需要进行修改,例如我这边就有滑动冲突,有多个

  • JS+CSS实现下拉刷新/上拉加载插件

    闲来无事,写了一个当下比较常见的下拉刷新/上拉加载的jquery插件,代码记录在这里,有兴趣将代码写成插件与npm包可以留言. 体验地址:http://owenliang.github.io/pullToRefresh/ 项目地址:https://github.com/owenliang/pullToRefresh 实现注意: 利用transition做动画时,优先使用transform:translate取代top,后者动画流畅度存在问题. 各移动浏览器对手势触摸的处理不同(简单罗列如下),但

  • 使用MUI框架模拟手机端的下拉刷新和上拉加载功能

    mui框架基于htm5plus的XMLHttpRequest,封装了常用的Ajax函数,支持GET.POST请求方式,支持返回json.xml.html.text.script数据类型: 本着极简的设计原则,mui提供了mui.ajax方法,并在mui.ajax方法基础上,进一步简化出最常用的mui.get().mui.getJSON().mui.post()三个方法. 套用mui官方文档的一句话:"开发者只需关心业务逻辑,实现加载更多数据即可".真的是不错的框架. 想更多的了解这个框

随机推荐