pytorch实现加载保存查看checkpoint文件

目录
  • 1.保存加载checkpoint文件
  • 2.跨gpu和cpu
  • 3.查看checkpoint文件内容
  • 4.常见问题
  • pytorch保存和加载文件的方法,从断点处继续训练

1.保存加载checkpoint文件

# 方式一:保存加载整个state_dict(推荐)
# 保存
torch.save(model.state_dict(), PATH)
# 加载
model.load_state_dict(torch.load(PATH))
# 测试时不启用 BatchNormalization 和 Dropout
model.eval()
# 方式二:保存加载整个模型
# 保存
torch.save(model, PATH)
# 加载
model = torch.load(PATH)
model.eval()
# 方式三:保存用于继续训练的checkpoint或者多个模型
# 保存
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            ...
            }, PATH)
# 加载
checkpoint = torch.load(PATH)
start_epoch=checkpoint['epoch']
model.load_state_dict(checkpoint['model_state_dict'])
# 测试时
model.eval()
# 或者训练时
model.train()

2.跨gpu和cpu

# GPU上保存,CPU上加载
# 保存
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device('cpu')
model.load_state_dict(torch.load(PATH, map_location=device))
# 如果是多gpu保存,需要去除关键字中的module,见第4部分
# GPU上保存,GPU上加载
# 保存
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device("cuda")
model.load_state_dict(torch.load(PATH))
model.to(device)
# CPU上保存,GPU上加载
# 保存
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device("cuda")
# 选择希望使用的GPU
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  
model.to(device)

3.查看checkpoint文件内容

# 打印模型的 state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

4.常见问题

多gpu

报错为KeyError: ‘unexpected key “module.conv1.weight” in state_dict’

原因:当使用多gpu时,会使用torch.nn.DataParallel,所以checkpoint中有module字样

#解决1:加载时将module去掉

# 创建一个不包含`module.`的新OrderedDict
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # 去掉 `module.`
    new_state_dict[name] = v
# 加载参数
model.load_state_dict(new_state_dict)
# 解决2:保存checkpoint时不保存module
torch.save(model.module.state_dict(), PATH)

pytorch保存和加载文件的方法,从断点处继续训练

'''本文件用于举例说明pytorch保存和加载文件的方法'''
import torch as torch
import torchvision as tv
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import os

# 参数声明
batch_size = 32
epochs = 10
WORKERS = 0  # dataloder线程数
test_flag = False  # 测试标志,True时加载保存好的模型进行测试
ROOT = '/home/pxt/pytorch/cifar'  # MNIST数据集保存路径
log_dir = '/home/pxt/pytorch/logs/cifar_model.pth'  # 模型保存路径
# 加载MNIST数据集
transform = tv.transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

train_data = tv.datasets.CIFAR10(root=ROOT, train=True, download=True, transform=transform)
test_data = tv.datasets.CIFAR10(root=ROOT, train=False, download=False, transform=transform)

train_load = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=WORKERS)
test_load = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=WORKERS)

# 构造模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv4 = nn.Conv2d(256, 256, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(256 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.pool(F.relu(self.conv4(x)))
        x = x.view(-1, x.size()[1] * x.size()[2] * x.size()[3])
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = Net().cpu()

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 模型训练
def train(model, train_loader, epoch):
    model.train()
    train_loss = 0
    for i, data in enumerate(train_loader, 0):
        x, y = data
        x = x.cpu()
        y = y.cpu()

        optimizer.zero_grad()
        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()
        train_loss += loss
        print('正在进行第{}个epoch中的第{}次循环'.format(epoch,i))

    loss_mean = train_loss / (i + 1)
    print('Train Epoch: {}\t Loss: {:.6f}'.format(epoch, loss_mean.item()))

# 模型测试
def test(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for i, data in enumerate(test_loader, 0):
            x, y = data
            x = x.cpu()
            y = y.cpu()

            optimizer.zero_grad()
            y_hat = model(x)
            test_loss += criterion(y_hat, y).item()
            pred = y_hat.max(1, keepdim=True)[1]
            correct += pred.eq(y.view_as(pred)).sum().item()
        test_loss /= (i + 1)
        print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_data), 100. * correct / len(test_data)))

def main():
    # 如果test_flag=True,则加载已保存的模型并进行测试,测试以后不进行此模块以后的步骤
    if test_flag:
        # 加载保存的模型直接进行测试机验证
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        test(model, test_load)
        return

    # 如果有保存的模型,则加载模型,并在其基础上继续训练
    if os.path.exists(log_dir):
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        print('加载 epoch {} 成功!'.format(start_epoch))
    else:
        start_epoch = 0
        print('无保存了的模型,将从头开始训练!')

    for epoch in range(start_epoch+1, epochs):
        train(model, train_load, epoch)
        test(model, test_load)
        # 保存模型
        state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
        torch.save(state, log_dir)

if __name__ == '__main__':
    main()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • pytorch模型的保存和加载、checkpoint操作

    其实之前笔者写代码的时候用到模型的保存和加载,需要用的时候就去度娘搜一下大致代码,现在有时间就来整理下整个pytorch模型的保存和加载,开始学习把~ pytorch的模型和参数是分开的,可以分别保存或加载模型和参数.所以pytorch的保存和加载对应存在两种方式: 1. 直接保存加载模型 (1)保存和加载整个模型 # 保存模型 torch.save(model, 'model.pth\pkl\pt') #一般形式torch.save(net, PATH) # 加载模型 model = torc

  • Pytorch .pth权重文件的使用解析

    pytorch最后的权重文件是.pth格式的. 经常遇到的问题: 进行finutune时,改配置文件中的学习率,发现程序跑起来后竟然保持了以前的学习率, 并没有使用新的学习率. 原因: 首先查看.pth文件中的内容,我们发现它其实是一个字典格式的文件 其中保存了optimizer和scheduler,所以再次加载此文件时会使用之前的学习率. 我们只需要权重,也就是model部分,将其导出就可以了 import torch original = torch.load('path/to/your/c

  • PyTorch模型保存与加载实例详解

    目录 一个简单的例子 保存/加载 state_dict(推荐) 保存/加载整个模型 保存加载用于推理的常规Checkpoint/或继续训练 保存多个模型到一个文件 使用其他模型来预热当前模型 跨设备保存与加载模型 总结 torch.save:保存序列化的对象到磁盘,使用了Python的pickle进行序列化,模型.张量.所有对象的字典. torch.load:使用了pickle的unpacking将pickled的对象反序列化到内存中. torch.nn.Module.load_state_di

  • pytorch实现加载保存查看checkpoint文件

    目录 1.保存加载checkpoint文件 2.跨gpu和cpu 3.查看checkpoint文件内容 4.常见问题 pytorch保存和加载文件的方法,从断点处继续训练 1.保存加载checkpoint文件 # 方式一:保存加载整个state_dict(推荐) # 保存 torch.save(model.state_dict(), PATH) # 加载 model.load_state_dict(torch.load(PATH)) # 测试时不启用 BatchNormalization 和 D

  • Pytorch 数据加载与数据预处理方式

    数据加载分为加载torchvision.datasets中的数据集以及加载自己使用的数据集两种情况. torchvision.datasets中的数据集 torchvision.datasets中自带MNIST,Imagenet-12,CIFAR等数据集,所有的数据集都是torch.utils.data.Dataset的子类,都包含 _ _ len _ (获取数据集长度)和 _ getItem _ _ (获取数据集中每一项)两个子方法. Dataset源码如上,可以看到其中包含了两个没有实现的子

  • jQuery加载及解析XML文件的方法实例分析

    本文实例讲述了jQuery加载及解析XML文件的方法.分享给大家供大家参考,具体如下: 1.简述 XML(eXtensible Markup Language)即可扩展标记语言,与HTML一样,都是属于SGML标准通用语言. 2. Content-Type 很多情况下XML文件不能正常解析都是由于Content-Type没有设置好.如果Content-Type本身就是一个XML文件则不需要设置:如果是由后台程序动态生成的,那么就需要设置Content-Type为"text/xml",否

  • JS实现加载和读取XML文件的方法详解

    本文实例讲述了JS实现加载和读取XML文件的方法.分享给大家供大家参考,具体如下: 有时在开发时用到 JS 加载和读取XML文件的情况,写下提供参考,这里主要是分两步完成: 1. JS加载XML文件 步骤一般为(1),建立 XML DOM 对象:(2),设置加载方式,异步(推荐)或同步: (3)提供XML文件URL然后调用 load 方法:大致如下: var xmlFileName="xxFile.xml"; var xmlDoc=''; if (window.ActiveXObjec

  • Yii2框架加载css和js文件的方法分析

    本文实例讲述了Yii2框架加载css和js文件的方法.分享给大家供大家参考,具体如下: 1.第一步是要把我们的css.js文件放到web目录下 2.第二步修改assets/AppAsset.php文件 <?php /** * @link http://www.yiiframework.com/ * @copyright Copyright (c) 2008 Yii Software LLC * @license http://www.yiiframework.com/license/ */ na

  • Pytorch自己加载单通道图片用作数据集训练的实例

    pytorch 在torchvision包里面有很多的的打包好的数据集,例如minist,Imagenet-12,CIFAR10 和CIFAR100.在torchvision的dataset包里面,用的时候直接调用就行了.具体的调用格式可以去看文档(目前好像只有英文的).网上也有很多源代码. 不过,当我们想利用自己制作的数据集来训练网络模型时,就要有自己的方法了.pytorch在torchvision.dataset包里面封装过一个函数ImageFolder().这个函数功能很强大,只要你直接将

  • pytorch 使用加载训练好的模型做inference

    前提: 模型参数和结构是分别保存的 1. 构建模型(# load model graph) model = MODEL() 2.加载模型参数(# load model state_dict) model.load_state_dict ( { k.replace('module.',''):v for k,v in torch.load(config.model_path, map_location=config.device).items() } ) model = self.model.to

  • java加载属性配置properties文件的方法

    什么是properties文件 属性配置文件,后缀名为 .properties 文件中严格按照key=value进行数据参数的填写 中文参数值需要转为Unicode编码 不区分基本数据类型 一个编辑好的aaa.properties文件如下图所示 username=root flag=true xm =\u4f60\u597d age=18 为什么要使用properties文件 设想这么一种场景,当你项目发布上线后,比如连接mysql数据库的端口号需要调整,难道需要重写改代码,打包,发布么?对于一

  • 基于OpenCv与JVM实现加载保存图像功能(JAVA 图像处理)

    目录 加载图片 保存图片 加载图片 openCv有一个名imread的简单函数,用于从文件中读取图像 imread 函数位于Imgcodecs类的同名包中. 加载图片代码 import org.opencv.core.CvType; import org.opencv.core.Mat; import org.opencv.core.Core; import org.opencv.imgcodecs.Imgcodecs; import origami.Origami; public class

  • 详解Jquery EasyUI tree 的异步加载(遍历指定文件夹,根据文件夹内的文件生成tree)

    Jquery EasyUI tree 的异步加载(遍历指定文件夹,根据文件夹内的文件生成tree)具体代码如下: private void SMT(HttpContext context) { string SqlConnection82 = System.Configuration.ConfigurationManager.AppSettings["LocalConnectionString"]; string path = context.Server.MapPath(@"

随机推荐