详解model.train()和model.eval()两种模式的原理与用法

一、两种模式

pytorch可以给我们提供两种方式来切换训练和评估(推断)的模式,分别是:model.train() 和 model.eval()。

一般用法是:在训练开始之前写上 model.trian() ,在测试时写上 model.eval() 。

二、功能

1. model.train()

在使用 pytorch 构建神经网络的时候,训练过程中会在程序上方添加一句model.train(),作用是 启用 batch normalization 和 dropout 。

如果模型中有BN层(Batch Normalization)和 Dropout ,需要在 训练时 添加 model.train()。

model.train() 是保证 BN 层能够用到 每一批数据 的均值和方差。对于 Dropout,model.train() 是 随机取一部分 网络连接来训练更新参数。

2. model.eval()

model.eval()的作用是 不启用 Batch Normalization 和 Dropout。

如果模型中有 BN 层(Batch Normalization)和 Dropout,在 测试时 添加 model.eval()。

model.eval() 是保证 BN 层能够用 全部训练数据 的均值和方差,即测试过程中要保证 BN 层的均值和方差不变。对于 Dropout,model.eval() 是利用到了 所有 网络连接,即不进行随机舍弃神经元。

为什么测试时要用 model.eval() ?

训练完 train 样本后,生成的模型 model 要用来测试样本了。在 model(test) 之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是 model 中含有 BN 层和 Dropout 所带来的的性质。

eval() 时,pytorch 会自动把 BN 和 DropOut 固定住,不会取平均,而是用训练好的值。
不然的话,一旦 test 的 batch_size 过小,很容易就会被 BN 层导致生成图片颜色失真极大。
eval() 在非训练的时候是需要加的,没有这句代码,一些网络层的值会发生变动,不会固定,你神经网络每一次生成的结果也是不固定的,生成质量可能好也可能不好。

也就是说,测试过程中使用model.eval(),这时神经网络会 沿用 batch normalization 的值,而并 不使用 dropout。

3. 总结与对比

如果模型中有 BN 层(Batch Normalization)和 Dropout,需要在训练时添加 model.train(),在测试时添加 model.eval()。

其中 model.train() 是保证 BN 层用每一批数据的均值和方差,而 model.eval() 是保证 BN 用全部训练数据的均值和方差;

而对于 Dropout,model.train() 是随机取一部分网络连接来训练更新参数,而 model.eval() 是利用到了所有网络连接。

三、Dropout 简介

dropout 常常用于抑制过拟合。

设置Dropout时,torch.nn.Dropout(0.5),这里的 0.5 是指该层(layer)的神经元在每次迭代训练时会随机有 50% 的可能性被丢弃(失活),不参与训练。也就是将上一层数据减少一半传播。

到此这篇关于详解model.train()和model.eval()两种模式的原理与用法的文章就介绍到这了,更多相关model.train()和model.eval()原理用法内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • pytorch中的model.eval()和BN层的使用

    看代码吧~ class ConvNet(nn.module): def __init__(self, num_class=10): super(ConvNet, self).__init__() self.layer1 = nn.Sequential(nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2), nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2))

  • Pytorch中关于model.eval()的作用及分析

    目录 model.eval()的作用及分析 结论 Pytorch踩坑之model.eval()问题 比较常见的有两方面的原因 1) data 2)model.state_dict() model.eval()   vs   torch.no_grad() 总结 model.eval()的作用及分析 model.eval() 作用等同于 self.train(False) 简而言之,就是评估模式.而非训练模式. 在评估模式下,batchNorm层,dropout层等用于优化训练而添加的网络层会被关

  • 聊聊pytorch测试的时候为何要加上model.eval()

    Do need to use model.eval() when I test? Sure, Dropout works as a regularization for preventing overfitting during training. It randomly zeros the elements of inputs in Dropout layer on forward call. It should be disabled during testing since you may

  • 详解model.train()和model.eval()两种模式的原理与用法

    一.两种模式 pytorch可以给我们提供两种方式来切换训练和评估(推断)的模式,分别是:model.train() 和 model.eval(). 一般用法是:在训练开始之前写上 model.trian() ,在测试时写上 model.eval() . 二.功能 1. model.train() 在使用 pytorch 构建神经网络的时候,训练过程中会在程序上方添加一句model.train(),作用是 启用 batch normalization 和 dropout . 如果模型中有BN层(

  • 详解Nuxt内导航栏的两种实现方式

    方式一 | 通过嵌套路由实现 在pages页面根据nuxt的路由规则,建立页面 1. 创建文件目录及文件 根据规则,如果要创建子路由,子路由的文件夹名字,必须和父路由名字相同 所以,我们的文件夹也为index,index文件夹需要一个默认的页面不然nuxt的路由规则就不能正确匹配页面 一级路由是根路由 二级路由是index,user,默认进入index路由 下面是router页面自动生成的路由 { path: "/", component: _93624e48, children: [

  • 详解Git建立本地仓库的两种方法

    Git是一种分布式版本控制系统,通常这类系统都可以与若干远端代码进行交互.Git项目具有三个主要部分:工作区,暂存目录,暂存区,本地目录: 安装完Git后,要做的第一件事,就是设置用户名和邮件地址.每个Git提交都使用此信息,并且将它永久地烘焙到您开始创建的提交中: $ git config --global user.name "John Doe" $ git config --global user.email johndoe@example.com 之后我们可以建立一个本地仓库.

  • 详解查看Python解释器路径的两种方式

    进入python的安装目录, 查看python解释器 进入bin目录 # ls python(看一下是否有python解释器版本) # pwd (查看当前目录) 复制当前目录即可 1. 通过脚本查看 运行以下脚本,或者进入交互模式手动输入即可. import sys import os print('当前 Python 解释器路径:') print(sys.executable) r""" 当前 Python 解释器路径: C:\Users\jpch89\AppData\Lo

  • 详解shell中脚本参数传递的两种方式

    方式一:$0,$1,$2.. 采用$0,$1,$2..等方式获取脚本命令行传入的参数,值得注意的是,$0获取到的是脚本路径以及脚本名,后面按顺序获取参数,当参数超过10个时(包括10个),需要使用${10},${11}....才能获取到参数,但是一般很少会超过10个参数的情况. 1.1 示例:新建一个test.sh的文件 #!/bin/bash echo "脚本$0" echo "第一个参数$1" echo "第二个参数$2" 在shell中执行

  • 详解matplotlib中pyplot和面向对象两种绘图模式之间的关系

    matplotlib有两种绘图方式,一种是依托matplotlib.pyplot模块实现类似matlab绘图指令的绘图方式,一种是面向对象式绘图,依靠FigureCanvas(画布). Figure (图像). Axes (轴域) 等对象绘图. 这两种方式之间并不是完全独立的,而是通过某种机制进行了联结,pylot绘图模式其实隐式创建了面向对象模式的相关对象,其中的关键是matplotlib._pylab_helpers模块中的单例类Gcf,它的作用是追踪当前活动的画布及图像. 因此,可以说ma

  • 详解pygame捕获键盘事件的两种方式

    方式1:在pygame中使用pygame.event.get()方法捕获键盘事件,使用这个方式捕获的键盘事件必须要是按下再弹起才算一次. 示例示例: for event in pygame.event.get(): # 捕获键盘事件 if event.type == pygame.QUIT: # 判断按键类型 print("按下了退出按键") 方式2:在pygame中可以使用pygame.key.get_pressed()来返回所有按键元组,通过判断键盘常量,可以在元组中判断出那个键被

  • 详解Python实现图像分割增强的两种方法

    方法一 import random import numpy as np from PIL import Image, ImageOps, ImageFilter from skimage.filters import gaussian import torch import math import numbers import random class RandomVerticalFlip(object): def __call__(self, img): if random.random()

  • 详解springmvc 接收json对象的两种方式

    最近学习了springmvc 接收json对象的两种方式,现在整理出来,具体如下: 1.以实体类方式接收 前端 ajax 提交数据: function fAddObj() { var obj = {}; obj['objname'] = "obj"; obj['pid'] = 1 ; $.ajax({ url: 'admin/Obj/addObj.do', method: 'post', contentType: 'application/json', // 这句不加出现415错误:U

  • 详解Eclipse安装SVN插件的两种方法

    eclipse里安装SVN插件,一般来说,有两种方式: 直接下载SVN插件,将其解压到eclipse的对应目录里 使用eclipse 里Help菜单的"Install New Software",通过输入SVN地址,直接下载安装到eclipse里 第一种方式: 1.下载SVN插件 SVN插件下载地址及更新地址,你根据需要选择你需要的版本.现在最新是1.8.x Links for 1.8.x Release: Eclipse update site URL: http://subclip

随机推荐