对TensorFlow中的variables_to_restore函数详解

variables_to_restore函数,是TensorFlow为滑动平均值提供。之前,也介绍过通过使用滑动平均值可以让神经网络模型更加的健壮。我们也知道,其实在TensorFlow中,变量的滑动平均值都是由影子变量所维护的,如果你想要获取变量的滑动平均值需要获取的是影子变量而不是变量本身。

1、滑动平均值模型文件的保存

import tensorflow as tf

if __name__ == "__main__":
 v = tf.Variable(0.,name="v")
 #设置滑动平均模型的系数
 ema = tf.train.ExponentialMovingAverage(0.99)
 #设置变量v使用滑动平均模型,tf.all_variables()设置所有变量
 op = ema.apply([v])
 #获取变量v的名字
 print(v.name)
 #v:0
 #创建一个保存模型的对象
 save = tf.train.Saver()
 sess = tf.Session()
 #初始化所有变量
 init = tf.initialize_all_variables()
 sess.run(init)
 #给变量v重新赋值
 sess.run(tf.assign(v,10))
 #应用平均滑动设置
 sess.run(op)
 #保存模型文件
 save.save(sess,"./model.ckpt")
 #输出变量v之前的值和使用滑动平均模型之后的值
 print(sess.run([v,ema.average(v)]))
 #[10.0, 0.099999905]

上面的代码,是如何来保存一个滑动平均值的模型文件,之前有介绍过滑动平均值和模型文件的保存,所以这里就不再重复了。

2、滑动平均值模型文件的读取

 v = tf.Variable(1.,name="v")
 #定义模型对象
 saver = tf.train.Saver({"v/ExponentialMovingAverage":v})
 sess = tf.Session()
 saver.restore(sess,"./model.ckpt")
 print(sess.run(v))
 #0.0999999

对于模型文件的读取,在上一篇博客中有介绍过,这里特别需要注意的一个地方就是,在使用tf.train.Saver函数中,所传递的模型参数是{"v/ExponentialMovingAverage":v}而不是{"v":v},如果你使用的是后面的参数,那么你得到的结果将是10而不是0.09,那是因为后者获取的是变量本身而不是影子变量。是不是感觉使用这种方式来读取模型文件的时候,还需要输入一大串的变量名称。

3、variables_to_restore函数的使用

 v = tf.Variable(1.,name="v")
 #滑动模型的参数的大小并不会影响v的值
 ema = tf.train.ExponentialMovingAverage(0.99)
 print(ema.variables_to_restore())
 #{'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}
 sess = tf.Session()
 saver = tf.train.Saver(ema.variables_to_restore())
 saver.restore(sess,"./model.ckpt")
 print(sess.run(v))
 #0.0999999

通过使用variables_to_restore函数,可以使在加载模型的时候将影子变量直接映射到变量的本身,所以我们在获取变量的滑动平均值的时候只需要获取到变量的本身值而不需要去获取影子变量。

以上这篇对TensorFlow中的variables_to_restore函数详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • TensorFLow用Saver保存和恢复变量

    本文为大家分享了TensorFLow用Saver保存和恢复变量的具体代码,供大家参考,具体内容如下 建立文件tensor_save.py, 保存变量v1,v2的tensor到checkpoint files中,名称分别设置为v3,v4. import tensorflow as tf # Create some variables. v1 = tf.Variable(3, name="v1") v2 = tf.Variable(4, name="v2") # Cre

  • TensorFlow模型保存和提取的方法

    一.TensorFlow模型保存和提取方法 1. TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提取.tf.train.Saver对象saver的save方法将TensorFlow模型保存到指定路径中,saver.save(sess,"Model/model.ckpt") ,实际在这个文件目录下会生成4个人文件: checkpoint文件保存了一个录下多有的模型文件列表,model.ckpt.meta保存了TensorFlow计算图的结构信息,model

  • TensorFlow变量管理详解

    一.TensorFlow变量管理 1. TensorFLow还提供了tf.get_variable函数来创建或者获取变量,tf.variable用于创建变量时,其功能和tf.Variable基本是等价的.tf.get_variable中的初始化方法(initializer)的参数和tf.Variable的初始化过程也类似,initializer函数和tf.Variable的初始化方法是一一对应的,详见下表. tf.get_variable和tf.Variable最大的区别就在于指定变量名称的参数

  • Tensorflow之Saver的用法详解

    Saver的用法 1. Saver的背景介绍 我们经常在训练完一个模型之后希望保存训练的结果,这些结果指的是模型的参数,以便下次迭代的训练或者用作测试.Tensorflow针对这一需求提供了Saver类. Saver类提供了向checkpoints文件保存和从checkpoints文件中恢复变量的相关方法.Checkpoints文件是一个二进制文件,它把变量名映射到对应的tensor值 . 只要提供一个计数器,当计数器触发时,Saver类可以自动的生成checkpoint文件.这让我们可以在训练

  • 对TensorFlow中的variables_to_restore函数详解

    variables_to_restore函数,是TensorFlow为滑动平均值提供.之前,也介绍过通过使用滑动平均值可以让神经网络模型更加的健壮.我们也知道,其实在TensorFlow中,变量的滑动平均值都是由影子变量所维护的,如果你想要获取变量的滑动平均值需要获取的是影子变量而不是变量本身. 1.滑动平均值模型文件的保存 import tensorflow as tf if __name__ == "__main__": v = tf.Variable(0.,name="

  • 对Tensorflow中的矩阵运算函数详解

    tf.diag(diagonal,name=None) #生成对角矩阵 import tensorflowas tf; diagonal=[1,1,1,1] with tf.Session() as sess: print(sess.run(tf.diag(diagonal))) #输出的结果为[[1 0 0 0] [0 1 0 0] [0 0 1 0] [0 0 0 1]] tf.diag_part(input,name=None) #功能与tf.diag函数相反,返回对角阵的对角元素 imp

  • 在Tensorflow中实现leakyRelu操作详解(高效)

    从github上转来,实在是厉害的想法,什么时候自己也能写出这种精妙的代码就好了 原地址:简易高效的LeakyReLu实现 代码如下: 我做了些改进,因为实在tensorflow中使用,就将原来的abs()函数替换成了tf.abs() import tensorflow as tf def LeakyRelu(x, leak=0.2, name="LeakyRelu"): with tf.variable_scope(name): f1 = 0.5 * (1 + leak) f2 =

  • javascript中Array()数组函数详解

    在程序语言中数组的重要性不言而喻,JavaScript中数组也是最常使用的对象之一,数组是值的有序集合,由于弱类型的原因,JavaScript中数组十分灵活.强大,不像是Java等强类型高级语言数组只能存放同一类型或其子类型元素,JavaScript在同一个数组中可以存放多种类型的元素,而且是长度也是可以动态调整的,可以随着数据增加或减少自动对数组长度做更改. Array()是一个用来构建数组的内建构造器函数.数组主要由如下三种创建方式: array = new Array() array =

  • COM组件中调用JavaScript函数详解及实例

    COM组件中调用JavaScript函数详解及实例 要求是很简单的,即有COM组件A在IE中运行,使用JavaScript(JS)调用A的方法longCalc(),该方法是一个耗时的操作,要求通知IE当前的进度.这就要求使用回调函数,设其名称为scriptCallbackFunc.实现这个技术很简单: 1 .组件方(C++) 组件A 的方法在IDL中定义: [id(2)] HRESULT longCalc([in] DOUBLE v1, [in] DOUBLE v2, [in, optional

  • 对Python3中的input函数详解

    下面介绍python3中的input函数及其在python2及pyhton3中的不同. python3中的ininput函数,首先利用help(input)函数查看函数信息: 以上信息说明input函数在python中是一个内建函数,其从标准输入中读入一个字符串,并自动忽略换行符. 也就是说所有形式的输入按字符串处理,如果想要得到其他类型的数据进行强制类型转化.默认情况下没有 提示字符串(prompt  string),在给定提示字符串下,会在读入标准输入前标准输出提示字符串.如果遇 文件结束符

  • python中的 zip函数详解及用法举例

    python中zip()函数用法举例 定义:zip([iterable, ...]) zip()是Python的一个内建函数,它接受一系列可迭代的对象作为参数,将对象中对应的元素打包成一个个tuple(元组),然后返回由这些tuples组成的list(列表).若传入参数的长度不等,则返回list的长度和参数中长度最短的对象相同.利用*号操作符,可以将list unzip(解压),看下面的例子就明白了: 示例1 x = [1, 2, 3] y = [4, 5, 6] z = [7, 8, 9] x

  • SQL中的开窗函数详解可代替聚合函数使用

    在没学习开窗函数之前,我们都知道,用了分组之后,查询字段就只能是分组字段和聚合的字段,这带来了极大的不方便,有时我们查询时需要分组,又需要查询不分组的字段,每次都要又到子查询,这样显得sql语句复杂难懂,给维护代码的人带来很大的痛苦,然而开窗函数出现了,曙光也来临了.如果要想更具体了解开窗函数,请看书<程序员的SQL金典>,开窗函数在mysql不能使用. 开窗函数与聚合函数一样,都是对行的集合组进行聚合计算.它用于为行定义一个窗口(这里的窗口是指运算将要操作的行的集合),它对一组值进行操作,不

  • C++中的Lambda函数详解

    目录 一 函数语法 二 函数应用 1.在普通函数中使用 2.在qt信号槽中使用 3.在std::sort排序函数中的使用 三 总结 一 函数语法 我们平时调用函数的时候,都是需要被调用函数的函数名,但是匿名函数就不需要函数名,而且直接写在需要调用的地方,对于以前没用过的小伙伴来说,第一眼看见了这语法可能很迷惑. C++11的基本语法格式为: [capture](parameters) -> return_type { /* ... */ } (1) [capture] :[]内为外部变量的传递方

  • python中lambda匿名函数详解

    在Python中,不通过def来声明函数名字,而是通过lambda关键字来定义的函数称为匿名函数 关键字lambda表示匿名函数 语法 lambda 参数:表达式 先写lambda关键字,然后依次写匿名函数的参数,多个参数中间用逗号连接,然后是一个冒号,冒号后面写返回的表达式 lambda函数比普通函数更简洁 匿名函数有个好处:函数没有名字,不必担心函数名冲突 匿名函数与普通函数的对比 : def sum_func(a, b, c): return a + b + c sum_lambda =

随机推荐