关于Pytorch中模型的保存与迁移问题

目录
  • 1 引言
  • 2 模型的保存与复用
    • 2.1 查看网络模型参数
    • 2.2 载入模型进行推断
    • 2.3 载入模型进行训练
    • 2.4 载入模型进行迁移
  • 3 总结

1 引言

各位朋友大家好,欢迎来到月来客栈。今天要和大家介绍的内容是如何在Pytorch框架中对模型进行保存和载入、以及模型的迁移和再训练。一般来说,最常见的场景就是模型完成训练后的推断过程。一个网络模型在完成训练后通常都需要对新样本进行预测,此时就只需要构建模型的前向传播过程,然后载入已训练好的参数初始化网络即可。

第2个场景就是模型的再训练过程。一个模型在一批数据上训练完成之后需要将其保存到本地,并且可能过了一段时间后又收集到了一批新的数据,因此这个时候就需要将之前的模型载入进行在新数据上进行增量训练(或者是在整个数据上进行全量训练)。

第3个应用场景就是模型的迁移学习。这个时候就是将别人已经训练好的预模型拿过来,作为你自己网络模型参数的一部分进行初始化。例如:你自己在Bert模型的基础上加了几个全连接层来做分类任务,那么你就需要将原始BERT模型中的参数载入并以此来初始化你的网络中的BERT部分的权重参数。

在接下来的这篇文章中,笔者就以上述3个场景为例来介绍如何利用Pytorch框架来完成上述过程。

2 模型的保存与复用

在Pytorch中,我们可以通过torch.save()torch.load()来完成上述场景中的主要步骤。下面,笔者将以之前介绍的LeNet5网络模型为例来分别进行介绍。不过在这之前,我们先来看看Pytorch中模型参数的保存形式。

2.1 查看网络模型参数

(1)查看参数

首先定义好LeNet5的网络模型结构,如下代码所示:

class LeNet5(nn.Module):
    def __init__(self, ):
        super(LeNet5, self).__init__()
        self.conv = nn.Sequential(  # [n,1,28,28]
            nn.Conv2d(1, 6, 5, padding=2),  # in_channels, out_channels, kernel_size
            nn.ReLU(),  # [n,6,24,24]
            nn.MaxPool2d(2, 2),  # kernel_size, stride  [n,6,14,14]
            nn.Conv2d(6, 16, 5),  # [n,16,10,10]
            nn.ReLU(),
            nn.MaxPool2d(2, 2))  # [n,16,5,5]
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10))
    def forward(self, img):
        output = self.conv(img)
        output = self.fc(output)
        return output

在定义好LeNet5这个网络结构的类之后,只要我们完成了这个类的实例化操作,那么网络中对应的权重参数也都完成了初始化的工作,即有了一个初始值。同时,我们可以通过如下方式来访问:

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

其输出的结果为:

conv.0.weight torch.Size([6, 1, 5, 5])
conv.0.bias torch.Size([6])
conv.3.weight torch.Size([16, 6, 5, 5])
....
....

可以发现,网络模型中的参数model.state_dict()其实是以字典的形式(实质上是collections模块中的OrderedDict)保存下来的:

print(model.state_dict().keys())
# odict_keys(['conv.0.weight', 'conv.0.bias', 'conv.3.weight', 'conv.3.bias', 'fc.1.weight', 'fc.1.bias', 'fc.3.weight', 'fc.3.bias', 'fc.5.weight', 'fc.5.bias'])

(2)自定义参数前缀

同时,这里值得注意的地方有两点:①参数名中的fcconv前缀是根据你在上面定义nn.Sequential()时的名字所确定的;②参数名中的数字表示每个Sequential()中网络层所在的位置。例如将网络结构定义成如下形式:

class LeNet5(nn.Module):
    def __init__(self, ):
        super(LeNet5, self).__init__()
        self.moon = nn.Sequential(  # [n,1,28,28]
            nn.Conv2d(1, 6, 5, padding=2),  # in_channels, out_channels, kernel_size
            nn.ReLU(),  # [n,6,24,24]
            nn.MaxPool2d(2, 2),  # kernel_size, stride  [n,6,14,14]
            nn.Conv2d(6, 16, 5),  # [n,16,10,10]
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Flatten(),
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10))

那么其参数名则为:

print(model.state_dict().keys())
odict_keys(['moon.0.weight', 'moon.0.bias', 'moon.3.weight', 'moon.3.bias', 'moon.7.weight', 'moon.7.bias', 'moon.9.weight', 'moon.9.bias', 'moon.11.weight', 'moon.11.bias'])

理解了这一点对于后续我们去解析和载入一些预训练模型很有帮助。

除此之外,对于中的优化器等,其同样有对应的state_dict()方法来获取对于的参数,例如:

optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
   print(var_name, "\t", optimizer.state_dict()[var_name])

#
Optimizer's state_dict:
state   {}
param_groups   [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140239245300504, 140239208339784, 140239245311360, 140239245310856, 140239266942480, 140239266942552, 140239266942624, 140239266942696, 140239266942912, 140239267041352]}]

在介绍完模型参数的查看方法后,就可以进入到模型复用阶段的内容介绍了。

2.2 载入模型进行推断

(1) 模型保存

在Pytorch中,对于模型的保存来说是非常简单的,通常来说通过如下两行代码便可以实现:

model_save_path = os.path.join(model_save_dir, 'model.pt')
torch.save(model.state_dict(), model_save_path)

在指定保存的模型名称时Pytorch官方建议的后缀为.pt或者.pth(当然也不是强制的)。最后,只需要在合适的地方加入第2行代码即可完成模型的保存。

同时,如果想要在训练过程中保存某个条件下的最优模型,那么应该通过如下方式:

best_model_state = deepcopy(model.state_dict())
torch.save(best_model_state, model_save_path)

而不是:

best_model_state = model.state_dict()
torch.save(best_model_state, model_save_path)

因为后者best_model_state得到只是model.state_dict()的引用,它依旧会随着训练过程而发生改变。

(2)复用模型进行推断

在推断过程中,首先需要完成网络的初始化,然后再载入已有的模型参数来覆盖网络中的权重参数即可,示例代码如下:

def inference(data_iter, device, model_save_dir='./MODEL'):
    model = LeNet5()  # 初始化现有模型的权重参数
    model.to(device)
    model_save_path = os.path.join(model_save_dir, 'model.pt')
    if os.path.exists(model_save_path):
        loaded_paras = torch.load(model_save_path)
        model.load_state_dict(loaded_paras)  # 用本地已有模型来重新初始化网络权重参数
        model.eval() # 注意不要忘记
    with torch.no_grad():
        acc_sum, n = 0.0, 0
        for x, y in data_iter:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            acc_sum += (logits.argmax(1) == y).float().sum().item()
            n += len(y)
        print("Accuracy in test data is :", acc_sum / n)

在上述代码中,4-7行便是用来载入本地模型参数,并用其覆盖网络模型中原有的参数。这样,便可以进行后续的推断工作:

Accuracy in test data is : 0.8851

2.3 载入模型进行训练

在介绍完模型的保存与复用之后,对于网络的追加训练就很简单了。最简便的一种方式就是在训练过程中只保存网络权重,然后在后续进行追加训练时只载入网络权重参数初始化网络进行训练即可,示例如下(完整代码参见[2]):

 def train(self):
        #......
        model_save_path = os.path.join(self.model_save_dir, 'model.pt')
        if os.path.exists(model_save_path):
            loaded_paras = torch.load(model_save_path)
            self.model.load_state_dict(loaded_paras)
            print("#### 成功载入已有模型,进行追加训练...")
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)  # 定义优化器
       #......
        for epoch in range(self.epochs):
            for i, (x, y) in enumerate(train_iter):
                x, y = x.to(device), y.to(device)
                logits = self.model(x)
                # ......
            print("Epochs[{}/{}]--acc on test {:.4}".format(epoch, self.epochs,
                                              self.evaluate(test_iter, self.model, device)))
            torch.save(self.model.state_dict(), model_save_path)

这样,便完成了模型的追加训练:

#### 成功载入已有模型,进行追加训练...
Epochs[0/5]---batch[938/0]---acc 0.9062---loss 0.2926
Epochs[0/5]---batch[938/100]---acc 0.9375---loss 0.1598
......

除此之外,你也可以在保存参数的时候,将优化器参数、损失值等一同保存下来,然后在恢复模型的时候连同其它参数一起恢复,示例如下:

model_save_path = os.path.join(model_save_dir, 'model.pt')
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, model_save_path)

载入方式如下:

checkpoint = torch.load(model_save_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

2.4 载入模型进行迁移

(1)定义新模型

到目前为止,对于前面两种应用场景的介绍就算完成了,可以发现总体上并不复杂。但是对于第3中场景的应用来说就会略微复杂一点。

假设现在有一个LeNet6网络模型,它是在LeNet5的基础最后多加了一个全连接层,其定义如下:

class LeNet6(nn.Module):
    def __init__(self, ):
        super(LeNet6, self).__init__()
        self.conv = nn.Sequential(  # [n,1,28,28]
            nn.Conv2d(1, 6, 5, padding=2),  # in_channels, out_channels, kernel_size
            nn.ReLU(),  # [n,6,24,24]
            nn.MaxPool2d(2, 2),  # kernel_size, stride  [n,6,14,14]
            nn.Conv2d(6, 16, 5),  # [n,16,10,10]
            nn.ReLU(),
            nn.MaxPool2d(2, 2))  # [n,16,5,5]
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 64),
            nn.ReLU(),
            nn.Linear(64, 10) ) # 新加入的全连接层

接下来,我们需要将在LeNet5上训练得到的权重参数迁移到LeNet6网络中去。从上面LeNet6的定义可以发现,此时尽管只是多加了一个全连接层,但是倒数第2层参数的维度也发生了变换。因此,对于LeNet6来说只能复用LeNet5网络前面4层的权重参数。

(2)查看模型参数

在拿到一个模型参数后,首先我们可以将其载入,然查看相关参数的信息:

model_save_path = os.path.join('./MODEL', 'model.pt')
loaded_paras = torch.load(model_save_path)
for param_tensor in loaded_paras:
    print(param_tensor, "\t", loaded_paras[param_tensor].size())

#---- 可复用部分
conv.0.weight   torch.Size([6, 1, 5, 5])
conv.0.bias   torch.Size([6])
conv.3.weight   torch.Size([16, 6, 5, 5])
conv.3.bias   torch.Size([16])
fc.1.weight   torch.Size([120, 400])
fc.1.bias   torch.Size([120])
fc.3.weight   torch.Size([84, 120])
fc.3.bias   torch.Size([84])
#----- 不可复用部分
fc.5.weight   torch.Size([10, 84])
fc.5.bias   torch.Size([10])

同时,对于LeNet6网络的参数信息为:

model = LeNet6()
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())
#
conv.0.weight   torch.Size([6, 1, 5, 5])
conv.0.bias   torch.Size([6])
conv.3.weight   torch.Size([16, 6, 5, 5])
conv.3.bias   torch.Size([16])
fc.1.weight   torch.Size([120, 400])
fc.1.bias   torch.Size([120])
fc.3.weight   torch.Size([84, 120])
fc.3.bias   torch.Size([84])
#------ 新加入部分
fc.5.weight   torch.Size([64, 84])
fc.5.bias   torch.Size([64])
fc.7.weight   torch.Size([10, 64])
fc.7.bias   torch.Size([10])

在理清楚了新旧模型的参数后,下面就可以将LeNet5中我们需要的参数给取出来,然后再换到LeNet6的网络中。

(3)模型迁移

虽然本地载入的模型参数(上面的loaded_paras)和模型初始化后的参数(上面的model.state_dict())都是一个字典的形式,但是我们并不能够直接改变model.state_dict()中的权重参数。这里需要先构造一个state_dict然后通过model.load_state_dict()方法来重新初始化网络中的参数。

同时,在这个过程中我们需要筛选掉本地模型中不可复用的部分,具体代码如下:

def para_state_dict(model, model_save_dir):
    state_dict = deepcopy(model.state_dict())
    model_save_path = os.path.join(model_save_dir, 'model.pt')
    if os.path.exists(model_save_path):
        loaded_paras = torch.load(model_save_path)
        for key in state_dict:  # 在新的网络模型中遍历对应参数
            if key in loaded_paras and state_dict[key].size() == loaded_paras[key].size():
                print("成功初始化参数:", key)
                state_dict[key] = loaded_paras[key]
    return state_dict

在上述代码中,第2行的作用是先拷贝网络中(LeNet6)原有的参数;第6-9行则是用本地的模型参数(LeNet5)中可以复用的替换掉LeNet6中的对应部分,其中第7行就是判断可用的条件。同时需要注意的是在不同的情况下筛选的方式可能不一样,因此具体情况需要具体分析,但是整体逻辑是一样的。

最后,我们只需要在模型训练之前调用该函数,然后重新初始化LeNet6中的部分权重参数即可[2]:

state_dict = para_state_dict(self.model, self.model_save_dir)
self.model.load_state_dict(state_dict)

训练结果如下:

成功初始化参数: conv.0.weight
成功初始化参数: conv.0.bias
成功初始化参数: conv.3.weight
成功初始化参数: conv.3.bias
成功初始化参数: fc.1.weight
成功初始化参数: fc.1.bias
成功初始化参数: fc.3.weight
成功初始化参数: fc.3.bias
#### 成功载入已有模型,进行追加训练...
Epochs[0/5]---batch[938/0]---acc 0.1094---loss 2.512
Epochs[0/5]---batch[938/100]---acc 0.9375---loss 0.2141
Epochs[0/5]---batch[938/200]---acc 0.9219---loss 0.2729
Epochs[0/5]---batch[938/300]---acc 0.8906---loss 0.2958
......
Epochs[0/5]---batch[938/900]---acc 0.8906---loss 0.2828
Epochs[0/5]--acc on test 0.8808

可以发现,在大约100个batch之后,模型的准确率就提升上来了。

3 总结

在本篇文章中,笔者首先介绍了模型复用的几种典型场景;然后介绍了如何查看Pytorch模型中的相关参数信息;接着介绍了如何载入模型、如何进行追加训练以及进行模型的迁移学习等。有了这部分内容的铺垫,在后续介绍类似BERT这样的预训练模型载入时就会容易很多了。

到此这篇关于Pytorch中模型的保存与迁移的文章就介绍到这了,更多相关Pytorch模型内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • Pytorch训练模型得到输出后计算F1-Score 和AUC的操作

    1.计算F1-Score 对于二分类来说,假设batch size 大小为64的话,那么模型一个batch的输出应该是torch.size([64,2]),所以首先做的是得到这个二维矩阵的每一行的最大索引值,然后添加到一个列表中,同时把标签也添加到一个列表中,最后使用sklearn中计算F1的工具包进行计算,代码如下 import numpy as np import sklearn.metrics import f1_score prob_all = [] lable_all = [] for

  • pytorch加载预训练模型与自己模型不匹配的解决方案

    pytorch中如果自己搭建网络并且加载别人的与训练模型的话,如果模型和参数不严格匹配,就可能会出问题,接下来记录一下我的解决方法. 两个有序字典找不同 模型的参数和pth文件的参数都是有序字典(OrderedDict),把字典中的键转为列表就可以在for循环里迭代找不同了. model = ResNet18(1) model_dict1 = torch.load('resnet18.pth') model_dict2 = model.state_dict() model_list1 = lis

  • 解决Pytorch修改预训练模型时遇到key不匹配的情况

    一.Pytorch修改预训练模型时遇到key不匹配 最近想着修改网络的预训练模型vgg.pth,但是发现当我加载预训练模型权重到新建的模型并保存之后. 在我使用新赋值的网络模型时出现了key不匹配的问题 #加载后保存(未修改网络) base_weights = torch.load(args.save_folder + args.basenet) ssd_net.vgg.load_state_dict(base_weights) torch.save(ssd_net.state_dict(),

  • pytorch 预训练模型读取修改相关参数的填坑问题

    pytorch 预训练模型读取修改相关参数的填坑 修改部分层,仍然调用之前的模型参数. resnet = resnet50(pretrained=False) resnet.load_state_dict(torch.load(args.predir)) res_conv31 = Bottleneck_dilated(1024, 256,dilated_rate = 2) print("---------------------",res_conv31) print("---

  • pytorch 使用半精度模型部署的操作

    背景 pytorch作为深度学习的计算框架正得到越来越多的应用. 我们除了在模型训练阶段应用外,最近也把pytorch应用在了部署上. 在部署时,为了减少计算量,可以考虑使用16位浮点模型,而训练时涉及到梯度计算,需要使用32位浮点,这种精度的不一致经过测试,模型性能下降有限,可以接受. 但是推断时计算量可以降低一半,同等计算资源下,并发度可提升近一倍 具体方法 在pytorch中,一般模型定义都继承torch.nn.Moudle,torch.nn.Module基类的half()方法会把所有参数

  • 关于Pytorch中模型的保存与迁移问题

    目录 1 引言 2 模型的保存与复用 2.1 查看网络模型参数 2.2 载入模型进行推断 2.3 载入模型进行训练 2.4 载入模型进行迁移 3 总结 1 引言 各位朋友大家好,欢迎来到月来客栈.今天要和大家介绍的内容是如何在Pytorch框架中对模型进行保存和载入.以及模型的迁移和再训练.一般来说,最常见的场景就是模型完成训练后的推断过程.一个网络模型在完成训练后通常都需要对新样本进行预测,此时就只需要构建模型的前向传播过程,然后载入已训练好的参数初始化网络即可. 第2个场景就是模型的再训练过

  • Pytorch提取模型特征向量保存至csv的例子

    Pytorch提取模型特征向量 # -*- coding: utf-8 -*- """ dj """ import torch import torch.nn as nn import os from torchvision import models, transforms from torch.autograd import Variable import numpy as np from PIL import Image import to

  • pytorch 中的重要模块化接口nn.Module的使用

    torch.nn 是专门为神经网络设计的模块化接口,nn构建于autgrad之上,可以用来定义和运行神经网络 nn.Module 是nn中重要的类,包含网络各层的定义,以及forward方法 查看源码 初始化部分: def __init__(self): self._backend = thnn_backend self._parameters = OrderedDict() self._buffers = OrderedDict() self._backward_hooks = Ordered

  • Pytorch中实现只导入部分模型参数的方式

    我们在做迁移学习,或者在分割,检测等任务想使用预训练好的模型,同时又有自己修改之后的结构,使得模型文件保存的参数,有一部分是不需要的(don't expected).我们搭建的网络对保存文件来说,有一部分参数也是没有的(missed).如果依旧使用torch.load(model.state_dict())的办法,就会出现 xxx expected,xxx missed类似的错误.那么在这种情况下,该如何导入模型呢? 好在Pytorch中的模型参数使用字典保存的,键是参数的名称,值是参数的具体数

  • 解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题

    背景 在公司用多卡训练模型,得到权值文件后保存,然后回到实验室,没有多卡的环境,用单卡训练,加载模型时出错,因为单卡机器上,没有使用DataParallel来加载模型,所以会出现加载错误. 原因 DataParallel包装的模型在保存时,权值参数前面会带有module字符,然而自己在单卡环境下,没有用DataParallel包装的模型权值参数不带module.本质上保存的权值文件是一个有序字典. 解决方法 1.在单卡环境下,用DataParallel包装模型. 2.自己重写Load函数,灵活.

  • pytorch模型的保存和加载、checkpoint操作

    其实之前笔者写代码的时候用到模型的保存和加载,需要用的时候就去度娘搜一下大致代码,现在有时间就来整理下整个pytorch模型的保存和加载,开始学习把~ pytorch的模型和参数是分开的,可以分别保存或加载模型和参数.所以pytorch的保存和加载对应存在两种方式: 1. 直接保存加载模型 (1)保存和加载整个模型 # 保存模型 torch.save(model, 'model.pth\pkl\pt') #一般形式torch.save(net, PATH) # 加载模型 model = torc

  • PyTorch深度学习模型的保存和加载流程详解

    一.模型参数的保存和加载 torch.save(module.state_dict(), path):使用module.state_dict()函数获取各层已经训练好的参数和缓冲区,然后将参数和缓冲区保存到path所指定的文件存放路径(常用文件格式为.pt..pth或.pkl). torch.nn.Module.load_state_dict(state_dict):从state_dict中加载参数和缓冲区到Module及其子类中 . torch.nn.Module.state_dict()函数

  • PyTorch模型的保存与加载方法实例

    目录 模型的保存与加载 保存和加载模型参数 保存和加载模型参数与结构 总结 模型的保存与加载 首先,需要导入两个包 import torch import torchvision.models as models 保存和加载模型参数 PyTorch模型将学习到的参数存储在一个内部状态字典中,叫做state_dict.这可以通过torch.save方法来实现.我们导入预训练好的VGG16模型,并将其保存.我们将state_dict字典保存在model_weights.pth文件中. model =

  • 在pytorch中如何查看模型model参数parameters

    目录 pytorch查看模型model参数parameters pytorch查看模型参数总结 1:DNN_printer 2:parameters 3:get_model_complexity_info() 4:torchstat pytorch查看模型model参数parameters 示例1:pytorch自带的faster r-cnn模型 import torch import torchvision model = torchvision.models.detection.faster

  • 在Pytorch中计算自己模型的FLOPs方式

    https://github.com/Lyken17/pytorch-OpCounter 安装方法很简单: pip install thop 基本用法: from torchvision.models import resnet50from thop import profile model = resnet50() flops, params = profile(model, input_size=(1, 3, 224,224)) 对自己的module进行特别的计算: class YourMo

随机推荐