Pytorch中关于model.eval()的作用及分析

目录
  • model.eval()的作用及分析
    • 结论
  • Pytorch踩坑之model.eval()问题
    • 比较常见的有两方面的原因
    • 1) data
    • 2)model.state_dict()
    • model.eval()   vs   torch.no_grad()
  • 总结

model.eval()的作用及分析

  • model.eval() 作用等同于 self.train(False)

简而言之,就是评估模式。而非训练模式。

在评估模式下,batchNorm层,dropout层等用于优化训练而添加的网络层会被关闭,从而使得评估时不会发生偏移。

结论

在对模型进行评估时,应该配合使用with torch.no_grad() 与 model.eval():

    loop:
        model.train()    # 切换至训练模式
        train……
        model.eval()
        with torch.no_grad():
            Evaluation
    end loop

Pytorch踩坑之model.eval()问题

最近在写代码时遇到一个问题,原本训练好的模型,加载进来进行inference准确率直接掉了5个点,这简直不能忍啊~下意识地感知到我肯定又在哪里写了bug了~~~于是开始到处排查,从model load到data load,最终在一个被我封装好的module的犄角旮旯里找到了问题,于是顺便就在这里总结一下,避免以后再犯。

对于训练好的模型加载进来准确率和原先的不符,

比较常见的有两方面的原因

  • data
  • model.state_dict()

1) data

数据方面,检查前后两次加载的data有没有发生变化。首先检查 transforms.Normalize 使用的均值和方差是否和训练时相同;另外检查在这个过程中数据是否经过了存储形式的改变,这有可能会带来数据精度的变化导致一定的信息丢失。比如我过用的其中一个数据集,原先将图片存储成向量形式,但其对应的是“png”格式的数据(后来在原始文件中发现了相应的描述。),而我进行了一次data-to-img操作,将向量转换成了“jpg”形式,这时加载进来便造成了掉点。

2)model.state_dict()

第一方面造成的掉点一般不会太严重,第二方面造成的掉点就比较严重了,一旦模型的参数加载错了,那就误差大了。

如果是参数没有正确加载进来则比较容易发现,这时准确率非常低,几乎等于瞎猜。

而我这次遇到的情况是,准确率并不是特别低,只掉了几个点,检查了多次,均显示模型参数已经成功加载了。后来仔细查看后发现在其中一次调用模型进行inference时,忘了写 ‘model.eval()’,造成了模型的参数发生变化,再次调用则出现了掉点。于是又回顾了一下model.eval()和model.train()的具体作用。如下:

model.train() 和 model.eval() 一般在模型训练和评价的时候会加上这两句,主要是针对由于model 在训练时和评价时 Batch Normalization 和 Dropout 方法模式不同:

  • a) model.eval(),不启用 BatchNormalization 和 Dropout。此时pytorch会自动把BN和DropOut固定住,不会取平均,而是用训练好的值。不然的话,一旦test的batch_size过小,很容易就会因BN层导致模型performance损失较大;
  • b) model.train() :启用 BatchNormalization 和 Dropout。 在模型测试阶段使用model.train() 让model变成训练模式,此时 dropout和batch normalization的操作在训练q起到防止网络过拟合的问题。

因此,在使用PyTorch进行训练和测试时一定要记得把实例化的model指定train/eval。

model.eval()   vs   torch.no_grad()

虽然二者都是eval的时候使用,但其作用并不相同:

model.eval() 负责改变batchnorm、dropout的工作方式,如在eval()模式下,dropout是不工作的。

见下方代码:

  import torch
  import torch.nn as nn
 
  drop = nn.Dropout()
  x = torch.ones(10)
  
  # Train mode   
  drop.train()
  print(drop(x)) # tensor([2., 2., 0., 2., 2., 2., 2., 0., 0., 2.])   
  
  # Eval mode   
  drop.eval()
  print(drop(x)) # tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

torch.no_grad() 负责关掉梯度计算,节省eval的时间。

只进行inference时,model.eval()是必须使用的,否则会影响结果准确性。 而torch.no_grad()并不是强制的,只影响运行效率。

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • 解决Pytorch中的神坑:关于model.eval的问题

    有时候使用Pytorch训练完模型,在测试数据上面得到的结果令人大跌眼镜. 这个时候需要检查一下定义的Model类中有没有 BN 或 Dropout 层,如果有任何一个存在 那么在测试之前需要加入一行代码: #model是实例化的模型对象 model = model.eval() 表示将模型转变为evaluation(测试)模式,这样就可以排除BN和Dropout对测试的干扰. 因为BN和Dropout在训练和测试时是不同的: 对于BN,训练时通常采用mini-batch,所以每一批中的mean

  • 聊聊pytorch测试的时候为何要加上model.eval()

    Do need to use model.eval() when I test? Sure, Dropout works as a regularization for preventing overfitting during training. It randomly zeros the elements of inputs in Dropout layer on forward call. It should be disabled during testing since you may

  • pytorch:model.train和model.eval用法及区别详解

    使用PyTorch进行训练和测试时一定注意要把实例化的model指定train/eval,eval()时,框架会自动把BN和DropOut固定住,不会取平均,而是用训练好的值,不然的话,一旦test的batch_size过小,很容易就会被BN层导致生成图片颜色失真极大!!!!!! Class Inpaint_Network() ...... Model = Inpaint_Nerwoek() #train: Model.train(mode=True) ..... #test: Model.ev

  • Pytorch中的modle.train,model.eval,with torch.no_grad解读

    目录 modle.train,model.eval,with torch.no_grad解读 model.eval()与torch.no_grad()的作用 model.eval() torch.no_grad() 异同 总结 modle.train,model.eval,with torch.no_grad解读 1. 最近在学习pytorch过程中遇到了几个问题 不理解为什么在训练和测试函数中model.eval(),和model.train()的区别,经查阅后做如下整理 一般情况下,我们训练

  • pytorch中的model.eval()和BN层的使用

    看代码吧~ class ConvNet(nn.module): def __init__(self, num_class=10): super(ConvNet, self).__init__() self.layer1 = nn.Sequential(nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2), nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2))

  • pytorch掉坑记录:model.eval的作用说明

    训练完train_datasets之后,model要来测试样本了.在model(test_datasets)之前,需要加上model.eval(). 否则的话,有输入数据,即使不训练,它也会改变权值. 这是model中含有batch normalization层所带来的的性质. 在做one classification的时候,训练集和测试集的样本分布是不一样的,尤其需要注意这一点. 补充知识:pytorch测试的时候为何要加上model.eval() Do need to use model.e

  • Pytorch中关于model.eval()的作用及分析

    目录 model.eval()的作用及分析 结论 Pytorch踩坑之model.eval()问题 比较常见的有两方面的原因 1) data 2)model.state_dict() model.eval()   vs   torch.no_grad() 总结 model.eval()的作用及分析 model.eval() 作用等同于 self.train(False) 简而言之,就是评估模式.而非训练模式. 在评估模式下,batchNorm层,dropout层等用于优化训练而添加的网络层会被关

  • pytorch中的model=model.to(device)使用说明

    这代表将模型加载到指定设备上. 其中,device=torch.device("cpu")代表的使用cpu,而device=torch.device("cuda")则代表的使用GPU. 当我们指定了设备之后,就需要将模型加载到相应设备中,此时需要使用model=model.to(device),将模型加载到相应的设备中. 将由GPU保存的模型加载到CPU上. 将torch.load()函数中的map_location参数设置为torch.device('cpu')

  • jquery中animate的stop()方法作用实例分析

    本文实例分析了jquery中animate的stop()方法作用.分享给大家供大家参考.具体分析如下: 这里以一个视频中的代码段告诉你stop()的作用: 代码如下: <style type="text/css"> ul li{ width:50px; height:30px; background:#333; margin-bottom:10px; color:#fff;} </style> <ul id="flyul"> &l

  • Python中逗号的三种作用实例分析

    本文实例讲述了Python中逗号的三种作用.分享给大家供大家参考.具体分析如下: 最近研究python  遇到个逗号的问题 一直没弄明白 今天总算搞清楚了 1.逗号在参数传递中的使用: 这种情况不多说  没有什么不解的地方 就是形参或者实参传递的时候参数之间的逗号 例如def  abc(a,b)或者abc(1,2) 2.逗号在类型转化中的使用 主要是元组的转换 例如: >>> a=11 >>> b=(a) >>> b 11 >>> b

  • 详解model.train()和model.eval()两种模式的原理与用法

    一.两种模式 pytorch可以给我们提供两种方式来切换训练和评估(推断)的模式,分别是:model.train() 和 model.eval(). 一般用法是:在训练开始之前写上 model.trian() ,在测试时写上 model.eval() . 二.功能 1. model.train() 在使用 pytorch 构建神经网络的时候,训练过程中会在程序上方添加一句model.train(),作用是 启用 batch normalization 和 dropout . 如果模型中有BN层(

  • Pytorch中的gather使用方法

    官方说明 gather可以对一个Tensor进行聚合,声明为:torch.gather(input, dim, index, out=None) → Tensor 一般来说有三个参数:输入的变量input.指定在某一维上聚合的dim.聚合的使用的索引index,输出为Tensor类型的结果(index必须为LongTensor类型). #参数介绍: input (Tensor) – The source tensor dim (int) – The axis along which to ind

随机推荐