浅谈pytorch中的BN层的注意事项

最近修改一个代码的时候,当使用网络进行推理的时候,发现每次更改测试集的batch size大小竟然会导致推理结果不同,甚至产生错误结果,后来发现在网络中定义了BN层,BN层在训练过程中,会将一个Batch的中的数据转变成正太分布,在推理过程中使用训练过程中的参数对数据进行处理,然而网络并不知道你是在训练还是测试阶段,因此,需要手动的加上,需要在测试和训练阶段使用如下函数。

model.train() or model.eval()

BN类的定义见pytorch中文参考文档

补充知识:关于pytorch中BN层(具体实现)的一些小细节

最近在做目标检测,需要把训好的模型放到嵌入式设备上跑前向,因此得把各种层的实现都用C手撸一遍,,,此为背景。

其他层没什么好说的,但是BN层这有个小坑。pytorch在打印网络参数的时候,只打出weight和bias这两个参数。咦,说好的BN层有四个参数running_mean、running_var 、gamma 、beta的呢?一开始我以为是pytorch把BN层的计算简化成weight * X + bias,但马上反应过来应该没这么简单,因为pytorch中只有可学习的参数才称为parameter。上网找了一些资料但都没有说到这么细的,毕竟大部分用户使用时只要模型能跑起来就行了,,,于是开始看BN层有哪些属性,果然发现了熟悉的running_mean和running_var,原来pytorch的BN层实现并没有不同。这里吐个槽:为啥要把gamma和beta改叫weight、bias啊,很有迷惑性的好不好,,,

扯了这么多,干脆捋一遍pytorch里BN层的具体实现过程,帮自己理清思路,也可以给大家提供参考。再吐槽一下,在网上搜“pytorch bn层”出来的全是关于这一层怎么用的、初始化时要输入哪些参数,没找到一个pytorch中BN层是怎么实现的,,,

众所周知,BN层的输出Y与输入X之间的关系是:Y = (X - running_mean) / sqrt(running_var + eps) * gamma + beta,此不赘言。其中gamma、beta为可学习参数(在pytorch中分别改叫weight和bias),训练时通过反向传播更新;而running_mean、running_var则是在前向时先由X计算出mean和var,再由mean和var以动量momentum来更新running_mean和running_var。所以在训练阶段,running_mean和running_var在每次前向时更新一次;在测试阶段,则通过net.eval()固定该BN层的running_mean和running_var,此时这两个值即为训练阶段最后一次前向时确定的值,并在整个测试阶段保持不变。

以上这篇浅谈pytorch中的BN层的注意事项就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • Pytorch之卷积层的使用详解

    1.简介(torch.nn下的) 卷积层主要使用的有3类,用于处理不同维度的数据 参数 Parameters: in_channels(int) – 输入信号的通道 out_channels(int) – 卷积产生的通道 kerner_size(int or tuple) - 卷积核的尺寸 stride(int or tuple, optional) - 卷积步长 padding (int or tuple, optional)- 输入的每一条边补充0的层数 dilation(int or tu

  • pytorch之添加BN的实现

    pytorch之添加BN层 批标准化 模型训练并不容易,特别是一些非常复杂的模型,并不能非常好的训练得到收敛的结果,所以对数据增加一些预处理,同时使用批标准化能够得到非常好的收敛结果,这也是卷积网络能够训练到非常深的层的一个重要原因. 数据预处理 目前数据预处理最常见的方法就是中心化和标准化,中心化相当于修正数据的中心位置,实现方法非常简单,就是在每个特征维度上减去对应的均值,最后得到 0 均值的特征.标准化也非常简单,在数据变成 0 均值之后,为了使得不同的特征维度有着相同的规模,可以除以标准

  • pytorch 预训练层的使用方法

    pytorch 预训练层的使用方法 将其他地方训练好的网络,用到新的网络里面 加载预训练网络 1.原先已经训练好一个网络 AutoEncoder_FC() 2.首先加载该网络,读取其存储的参数 3.设置一个参数集 cnnpre = AutoEncoder_FC() cnnpre.load_state_dict(torch.load('autoencoder_FC.pkl')['state_dict']) cnnpre_dict =cnnpre.state_dict() 加载新网络 1.设置新的网

  • 浅谈pytorch中的BN层的注意事项

    最近修改一个代码的时候,当使用网络进行推理的时候,发现每次更改测试集的batch size大小竟然会导致推理结果不同,甚至产生错误结果,后来发现在网络中定义了BN层,BN层在训练过程中,会将一个Batch的中的数据转变成正太分布,在推理过程中使用训练过程中的参数对数据进行处理,然而网络并不知道你是在训练还是测试阶段,因此,需要手动的加上,需要在测试和训练阶段使用如下函数. model.train() or model.eval() BN类的定义见pytorch中文参考文档 补充知识:关于pyto

  • 浅谈Pytorch中的torch.gather函数的含义

    pytorch中的gather函数 pytorch比tensorflow更加编程友好,所以准备用pytorch试着做最近要做的一些实验. 立个flag开始学习pytorch,新开一个分类整理学习pytorch中的一些踩到的泥坑. 今天刚开始接触,读了一下documentation,写一个一开始每太搞懂的函数gather b = torch.Tensor([[1,2,3],[4,5,6]]) print b index_1 = torch.LongTensor([[0,1],[2,0]]) ind

  • 浅谈Pytorch中的自动求导函数backward()所需参数的含义

    正常来说backward( )函数是要传入参数的,一直没弄明白backward需要传入的参数具体含义,但是没关系,生命在与折腾,咱们来折腾一下,嘿嘿. 对标量自动求导 首先,如果out.backward()中的out是一个标量的话(相当于一个神经网络有一个样本,这个样本有两个属性,神经网络有一个输出)那么此时我的backward函数是不需要输入任何参数的. import torch from torch.autograd import Variable a = Variable(torch.Te

  • 浅谈keras中的Merge层(实现层的相加、相减、相乘实例)

    [题目]keras中的Merge层(实现层的相加.相减.相乘) 详情请参考: Merge层 一.层相加 keras.layers.Add() 添加输入列表的图层. 该层接收一个相同shape列表张量,并返回它们的和,shape不变. Example import keras input1 = keras.layers.Input(shape=(16,)) x1 = keras.layers.Dense(8, activation='relu')(input1) input2 = keras.la

  • 浅谈PyTorch中in-place operation的含义

    in-place operation在pytorch中是指改变一个tensor的值的时候,不经过复制操作,而是直接在原来的内存上改变它的值.可以把它成为原地操作符. 在pytorch中经常加后缀"_"来代表原地in-place operation,比如说.add_() 或者.scatter().python里面的+=,*=也是in-place operation. 下面是正常的加操作,执行结束加操作之后x的值没有发生变化: import torch x=torch.rand(2) #t

  • 浅谈pytorch中torch.max和F.softmax函数的维度解释

    在利用torch.max函数和F.Ssoftmax函数时,对应该设置什么维度,总是有点懵,遂总结一下: 首先看看二维tensor的函数的例子: import torch import torch.nn.functional as F input = torch.randn(3,4) print(input) tensor([[-0.5526, -0.0194, 2.1469, -0.2567], [-0.3337, -0.9229, 0.0376, -0.0801], [ 1.4721, 0.1

  • 浅谈Java中GuavaCache返回Null的注意事项

    Guava在实际的Java后端项目中应用的场景还是比较多的,比如限流,缓存,容器操作之类的,有挺多实用的工具类,这里记录一下,在使用GuavaCache,返回null的一个问题 I. 常见使用姿势 @Test public void testGuava() { LoadingCache<String, String> cache = CacheBuilder.newBuilder().build(new CacheLoader<String, String>() { @Overri

  • 浅谈pytorch中的nn.Sequential(*net[3: 5])是啥意思

    看到代码里面有这个 1 class ResNeXt101(nn.Module): 2 def __init__(self): 3 super(ResNeXt101, self).__init__() 4 net = resnext101() # print(os.getcwd(), net) 5 net = list(net.children()) # net.children()得到resneXt 的表层网络 # for i, value in enumerate(net): # print(

  • 浅谈Pytorch中autograd的若干(踩坑)总结

    关于Variable和Tensor 旧版本的Pytorch中,Variable是对Tensor的一个封装:在Pytorch大于v0.4的版本后,Varible和Tensor合并了,意味着Tensor可以像旧版本的Variable那样运行,当然新版本中Variable封装仍旧可以用,但是对Varieble操作返回的将是一个Tensor. import torch as t from torch.autograd import Variable a = t.ones(3,requires_grad=

  • 浅谈pytorch中stack和cat的及to_tensor的坑

    初入计算机视觉遇到的一些坑 1.pytorch中转tensor x=np.random.randint(10,100,(10,10,10)) x=TF.to_tensor(x) print(x) 这个函数会对输入数据进行自动归一化,比如有时候我们需要将0-255的图片转为numpy类型的数据,则会自动转为0-1之间 2.stack和cat之间的差别 stack x=torch.randn((1,2,3)) y=torch.randn((1,2,3)) z=torch.stack((x,y))#默

随机推荐