Pytorch保存模型用于测试和用于继续训练的区别详解

保存模型

保存模型仅仅是为了测试的时候,只需要

torch.save(model.state_dict, path)

path 为保存的路径

但是有时候模型及数据太多,难以一次性训练完的时候,而且用的还是 Adam优化器的时候, 一定要保存好训练的优化器参数以及epoch

state = { 'model': model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch': epoch }
torch.save(state, path)

因为这里

def adjust_learning_rate(optimizer, epoch):
  lr_t = lr
  lr_t = lr_t * (0.3 ** (epoch // 2))
  for param_group in optimizer.param_groups:
    param_group['lr'] = lr_t

学习率是根据epoch变化的, 如果不保存epoch的话,基本上每次都从epoch为0开始训练,这样学习率就相当于不变了!!

恢复模型

恢复模型只用于测试的时候,

model.load_state_dict(torch.load(path))

path为之前存储模型时的路径

但是如果是用于继续训练的话,

checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']+1

依次恢复出模型 优化器参数以及epoch

以上这篇Pytorch保存模型用于测试和用于继续训练的区别详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • 基于pytorch的保存和加载模型参数的方法

    当我们花费大量的精力训练完网络,下次预测数据时不想再(有时也不必再)训练一次时,这时候torch.save(),torch.load()就要登场了. 保存和加载模型参数有两种方式: 方式一: torch.save(net.state_dict(),path): 功能:保存训练完的网络的各层参数(即weights和bias) 其中:net.state_dict()获取各层参数,path是文件存放路径(通常保存文件格式为.pt或.pth) net2.load_state_dict(torch.loa

  • Pytorch之保存读取模型实例

    pytorch保存数据 pytorch保存数据的格式为.t7文件或者.pth文件,t7文件是沿用torch7中读取模型权重的方式.而pth文件是python中存储文件的常用格式.而在keras中则是使用.h5文件. # 保存模型示例代码 print('===> Saving models...') state = { 'state': model.state_dict(), 'epoch': epoch # 将epoch一并保存 } if not os.path.isdir('checkpoin

  • Pytorch提取模型特征向量保存至csv的例子

    Pytorch提取模型特征向量 # -*- coding: utf-8 -*- """ dj """ import torch import torch.nn as nn import os from torchvision import models, transforms from torch.autograd import Variable import numpy as np from PIL import Image import to

  • Pytorch 保存模型生成图片方式

    三通道数组转成彩色图片 img=np.array(img1) img=img.reshape(3,img1.shape[2],img1.shape[3]) img=(img+0.5)*255##img做过归一化处理,[-0.5,0.5] img_path='/home/isee/wei/image/imageset/result.jpg' img=cv2.merge(img) cv2.imwrite(img_path,img) 单通道数组转化成灰度图 Img_mask=np.array(img_

  • Pytorch保存模型用于测试和用于继续训练的区别详解

    保存模型 保存模型仅仅是为了测试的时候,只需要 torch.save(model.state_dict, path) path 为保存的路径 但是有时候模型及数据太多,难以一次性训练完的时候,而且用的还是 Adam优化器的时候, 一定要保存好训练的优化器参数以及epoch state = { 'model': model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch': epoch } torch.save(state, pat

  • Java内存模型与JVM运行时数据区的区别详解

    首先,这两者是完全不同的概念,绝对不能混为一谈. 1.什么是Java内存模型? Java内存模型是Java语言在多线程并发情况下对于共享变量读写(实际是共享变量对应的内存操作)的规范,主要是为了解决多线程可见性.原子性的问题,解决共享变量的多线程操作冲突问题. 多线程编程的普遍问题是: 所见非所得 无法肉眼检测程序的准确性 不同的运行平台表现不同 错误很难复现 故JVM规范规定了Java虚拟机对多线程内存操作的一些规则,主要集中体现在volatile和synchronized这两个关键字. vo

  • 解决pytorch 保存模型遇到的问题

    今天用pytorch保存模型时遇到bug Can't pickle <class 'torch._C._VariableFunctions'> 在google上查找原因,发现是保存时保存了整个模型的原因,而模型中有一些自定义的参数 将 torch.save(model,save_path) 改为 torch.save(model.state_dict(),save_path) 然后载入模型也做相应的更改就好了 补充:pytorch训练模型的一些坑 1. 图像读取 opencv的python和c

  • 在Pytorch中计算卷积方法的区别详解(conv2d的区别)

    在二维矩阵间的运算: class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True) 对由多个特征平面组成的输入信号进行2D的卷积操作.详解 torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)

  • pytorch:model.train和model.eval用法及区别详解

    使用PyTorch进行训练和测试时一定注意要把实例化的model指定train/eval,eval()时,框架会自动把BN和DropOut固定住,不会取平均,而是用训练好的值,不然的话,一旦test的batch_size过小,很容易就会被BN层导致生成图片颜色失真极大!!!!!! Class Inpaint_Network() ...... Model = Inpaint_Nerwoek() #train: Model.train(mode=True) ..... #test: Model.ev

  • Python测试框架pytest高阶用法全面详解

    目录 前言 1.pytest安装 1.1安装 1.2验证安装 1.3pytest文档 1.4 Pytest运行方式 1.5 Pytest Exit Code 含义清单 1.6 如何获取帮助信息 1.7 控制测试用例执行 1.8 多进程运行cases 1.9 重试运行cases 1.10 显示print内容 2.Pytest的setup和teardown函数 函数级别setup()/teardown() 类级别 3.Pytest配置文件 4 Pytest常用插件 4.1 前置条件: 4.2 Pyt

  • 在Django下测试与调试REST API的方法详解

    对于大多数研发人员来说,都期望能找到一个良好的测试/调试方法,来提高工作效率和快速解决问题.所谓调试,偏重于对某个bug的查找.定位.修复:所谓测试,是检验某个功能是否达到预期效果.测试发现问题后进行调试,从而解决问题. 对于后台研发来说,往往没有客户端研发(Windows/Android等等)那样简单有效的DEBUG方法,比如Step by Step.虽然目前有很多IDE可以实现本地调试,但是因为后台研发的环境复杂,你很难在一台机器上模拟所有的环境,比如线上的数据库只能在内网访问等等,所以很多

  • golang使用 gomodule 在公共测试环境管理go的依赖的实例详解

    背景:调试服务最好的方式就是直接上机实践.对在公司的员工来说,在同一套服务上协同开发比在单独的环境上开发,应该会更有感觉.有问题可以一起发现并解决,也能够一同开发需求. 但是,公司的测试机往往是没办法连外网的,而golang 的大部分工程都需要直接从github 上下载依赖,这就导致 依赖文件需要先提前上传到开发机上.那么当开发机上需要运行多个golang 工程的时候,如何共享这些依赖,减少维护依赖库的工作量呢? 这也是需要大家协作完成的- 最终总结:项目采用 go module + vendo

  • PyTorch中 tensor.detach() 和 tensor.data 的区别详解

    PyTorch0.4中,.data 仍保留,但建议使用 .detach(), 区别在于 .data 返回和 x 的相同数据 tensor, 但不会加入到x的计算历史里,且require s_grad = False, 这样有些时候是不安全的, 因为 x.data 不能被 autograd 追踪求微分 . .detach() 返回相同数据的 tensor ,且 requires_grad=False ,但能通过 in-place 操作报告给 autograd 在进行反向传播的时候. 举例: ten

随机推荐