对比分析BN和dropout在预测和训练时区别

目录
  • Batch Normalization
  • Dropout

Batch Normalization和Dropout是深度学习模型中常用的结构。

但BN和dropout在训练和测试时使用却不相同。

Batch Normalization

BN在训练时是在每个batch上计算均值和方差来进行归一化,每个batch的样本量都不大,所以每次计算出来的均值和方差就存在差异。预测时一般传入一个样本,所以不存在归一化,其次哪怕是预测一个batch,但batch计算出来的均值和方差是偏离总体样本的,所以通常是通过滑动平均结合训练时所有batch的均值和方差来得到一个总体均值和方差。

以tensorflow代码实现为例:

def bn_layer(self, inputs, training, name='bn', moving_decay=0.9, eps=1e-5):
        # 获取输入维度并判断是否匹配卷积层(4)或者全连接层(2)
        shape = inputs.shape
        param_shape = shape[-1]
        with tf.variable_scope(name):
            # 声明BN中唯一需要学习的两个参数,y=gamma*x+beta
            gamma = tf.get_variable('gamma', param_shape, initializer=tf.constant_initializer(1))
            beta  = tf.get_variable('beat', param_shape, initializer=tf.constant_initializer(0))
            # 计算当前整个batch的均值与方差
            axes = list(range(len(shape)-1))
            batch_mean, batch_var = tf.nn.moments(inputs , axes, name='moments')
            # 采用滑动平均更新均值与方差
            ema = tf.train.ExponentialMovingAverage(moving_decay, name="ema")
            def mean_var_with_update():
                ema_apply_op = ema.apply([batch_mean, batch_var])
                with tf.control_dependencies([ema_apply_op]):
                    return tf.identity(batch_mean), tf.identity(batch_var)
            # 训练时,更新均值与方差,测试时使用之前最后一次保存的均值与方差
            mean, var = tf.cond(tf.equal(training,True), mean_var_with_update,
                    lambda:(ema.average(batch_mean), ema.average(batch_var)))
            # 最后执行batch normalization
            return tf.nn.batch_normalization(inputs ,mean, var, beta, gamma, eps)

training参数可以通过tf.placeholder传入,这样就可以控制训练和预测时training的值。

self.training = tf.placeholder(tf.bool, name="training")

Dropout

Dropout在训练时会随机丢弃一些神经元,这样会导致输出的结果变小。而预测时往往关闭dropout,保证预测结果的一致性(不关闭dropout可能同一个输入会得到不同的输出,不过输出会服从某一分布。另外有些情况下可以不关闭dropout,比如文本生成下,不关闭会增大输出的多样性)。

为了对齐Dropout训练和预测的结果,通常有两种做法,假设dropout rate = 0.2。一种是训练时不做处理,预测时输出乘以(1 - dropout rate)。另一种是训练时留下的神经元除以(1 - dropout rate),预测时不做处理。以tensorflow为例。

x = tf.nn.dropout(x, self.keep_prob)
self.keep_prob = tf.placeholder(tf.float32, name="keep_prob")

tf.nn.dropout就是采用了第二种做法,训练时除以(1 - dropout rate),源码如下:

binary_tensor = math_ops.floor(random_tensor)
 ret = math_ops.div(x, keep_prob) * binary_tensor
 if not context.executing_eagerly():
   ret.set_shape(x.get_shape())
 return ret

binary_tensor就是一个mask tensor,即里面的值由0或1组成。keep_prob = 1 - dropout rate。

以上就是对比分析BN和dropout在预测和训练时区别的详细内容,更多关于BN与dropout预测训练对比的资料请关注我们其它相关文章!

(0)

相关推荐

  • PyTorch dropout设置训练和测试模式的实现

    看代码吧~ class Net(nn.Module): - model = Net() - model.train() # 把module设成训练模式,对Dropout和BatchNorm有影响 model.eval() # 把module设置为预测模式,对Dropout和BatchNorm模块有影响 补充:Pytorch遇到的坑--训练模式和测试模式切换 由于训练的时候Dropout和BN层起作用,每个batch BN层的参数不一样,dropout在训练时随机失效点具有随机性,所以训练和测试要

  • Pytorch之如何dropout避免过拟合

    一.做数据 二.搭建神经网络 三.训练 四.对比测试结果 注意:测试过程中,一定要注意模式切换 Pytorch的学习--过拟合 过拟合 过拟合是当数据量较小时或者输出结果过于依赖某些特定的神经元,训练神经网络训练会发生一种现象.出现这种现象的神经网络预测的结果并不具有普遍意义,其预测结果极不准确. 解决方法 1.增加数据量 2.L1,L2,L3-正规化,即在计算误差值的时候加上要学习的参数值,当参数改变过大时,误差也会变大,通过这种惩罚机制来控制过拟合现象 3.dropout正规化,在训练过程中

  • 解决BN和Dropout共同使用时会出现的问题

    BN与Dropout共同使用出现的问题 BN和Dropout单独使用都能减少过拟合并加速训练速度,但如果一起使用的话并不会产生1+1>2的效果,相反可能会得到比单独使用更差的效果. 相关的研究参考论文:Understanding the Disharmony between Dropout and Batch Normalization by Variance Shift 本论文作者发现理解 Dropout 与 BN 之间冲突的关键是网络状态切换过程中存在神经方差的(neural varianc

  • pytorch中的model.eval()和BN层的使用

    看代码吧~ class ConvNet(nn.module): def __init__(self, num_class=10): super(ConvNet, self).__init__() self.layer1 = nn.Sequential(nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2), nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2))

  • 对比分析BN和dropout在预测和训练时区别

    目录 Batch Normalization Dropout Batch Normalization和Dropout是深度学习模型中常用的结构. 但BN和dropout在训练和测试时使用却不相同. Batch Normalization BN在训练时是在每个batch上计算均值和方差来进行归一化,每个batch的样本量都不大,所以每次计算出来的均值和方差就存在差异.预测时一般传入一个样本,所以不存在归一化,其次哪怕是预测一个batch,但batch计算出来的均值和方差是偏离总体样本的,所以通常是

  • AngularJS下对数组的对比分析

    Javascript不能直接用==或者===来判断两个数组是否相等,无论是相等还是全等都不行,以下两行JS代码都会返回false <script type="text/javascript"> alert([]==[]); alert([]===[]); </script> 要判断JS中的两个数组是否相同,需要先将数组转换为字符串,再作比较.以下两行代码将返回true <script type="text/javascript">

  • Perl与JS的对比分析(数组、哈希)

    上一篇列出了Perl中定义数组,对象的方式与JS的异同.这里继续补充数组,哈希的相关操作. 一.数组 可以对数组进行增删,插入.与JS不同的是这些函数都是全局的,JS则是挂在Array.prototype上. 1,对数组尾部的操作pop(删除最后的元素).push(在尾部添加) @goods = qw/pen pencil/; pop(@goods); # @goods 变成 (pen) push(@goods, 'brush'); # @goods 变为 (pen, brush) 在Perl中

  • 浅谈MySQL和Lucene索引的对比分析

    MySQL和Lucene都可以对数据构建索引并通过索引查询数据,一个是关系型数据库,一个是构建搜索引擎(Solr.ElasticSearch)的核心类库.两者的索引(index)有什么区别呢?以前写过一篇<Solr与MySQL查询性能对比>,只是简单的对比了下查询性能,对于内部原理却没有解释,本文简单分析下两者的索引区别. MySQL索引实现 在MySQL中,索引属于存储引擎级别的概念,不同存储引擎对索引的实现方式是不同的,本文主要讨论MyISAM和InnoDB两个存储引擎的索引实现方式. M

  • Oracle不同数据库间对比分析脚本

    正在看的ORACLE教程是:Oracle不同数据库间对比分析脚本. Oracle数据库开发应用中经常对数据库管理员有这样的需求,对比两个不同实例间某模式下对象的差异或者对比两个不同实例某模式下表定义的差异性,这在涉及到数据库软件的开发应用中是经常遇到的.一般数据库软件的开发都是首先在开发数据库上进行,开发到一定程度后,系统投入运行,此时软件处于维护阶段.针对在系统运行中遇到的错误.bug等,还有应用系统的升级,经常需要调整后台程序,数据库开发人员经常遇到这样一种尴尬的事情,维护到一定时期,开发库

  • 对比分析Django的Q查询及AngularJS的Datatables分页插件

    使用Q查询,首先要导入Q模块: from django.db.models import Q 可以组合使用&,|操作符用于多个Q的对象,产生一个新的Q对象,Q对象也可以用~操作符放在前面表示否定,如下例所示: if search: keywords_list = search.split(' ') query_list = [Q(status__icontains=get_success_fail_status(keyword)) if get_success_fail_keyword_stat

  • hibernate和mybatis对比分析

    第一章     Hibernate与MyBatis Hibernate 是当前最流行的O/R mapping框架,它出身于sf.net,现在已经成为Jboss的一部分. Mybatis 是另外一种优秀的O/R mapping框架.目前属于apache的一个子项目. MyBatis 参考资料官网:http://www.mybatis.org/core/zh/index.html Hibernate参考资料: http://docs.jboss.org/hibernate/core/3.6/refe

  • java原生序列化和Kryo序列化性能实例对比分析

    简介 最近几年,各种新的高效序列化方式层出不穷,不断刷新序列化性能的上限,最典型的包括: 专门针对Java语言的:Kryo,FST等等 跨语言的:Protostuff,ProtoBuf,Thrift,Avro,MsgPack等等 这些序列化方式的性能多数都显著优于hessian2(甚至包括尚未成熟的dubbo序列化).有鉴于此,我们为dubbo引入Kryo和FST这 两种高效Java序列化实现,来逐步取代hessian2.其中,Kryo是一种非常成熟的序列化实现,已经在Twitter.Group

  • php中随机函数mt_rand()与rand()性能对比分析

    本文实例对比分析了php中随机函数mt_rand()与rand()性能问题.分享给大家供大家参考.具体分析如下: 在php中mt_rand()和rand()函数都是可以随机生成一个纯数字的,他们都是需要我们设置好种子数据然后生成,那么mt_rand()和rand()那个性能会好一些呢,下面我们带着疑问来测试一下. 例子1. mt_rand() 范例,代码如下: 复制代码 代码如下: <?php echo mt_rand() . "n"; echo mt_rand() . &quo

随机推荐