tensorflow 获取checkpoint中的变量列表实例

方式1:静态获取,通过直接解析checkpoint文件获取变量名及变量值

通过

reader = tf.train.NewCheckpointReader(model_path)

或者通过:

from tensorflow.python import pywrap_tensorflow
reader = pywrap_tensorflow.NewCheckpointReader(model_path)

代码:

model_path = "./checkpoints/model.ckpt-75000"
## 下面两个reader作用等价
#reader = pywrap_tensorflow.NewCheckpointReader(model_path)
reader = tf.train.NewCheckpointReader(model_path)

## 用reader获取变量字典,key是变量名,value是变量的shape
var_to_shape_map = reader.get_variable_to_shape_map()
for var_name in var_to_shape_map.keys():
 #用reader获取变量值
 var_value = reader.get_tensor(var_name)

 print("var_name",var_name)
 print("var_value",var_value)

方式2:动态获取,先加载checkpoint模型,然后用graph.get_tensor_by_name()获取变量值

代码 (注意:要先在脚本中构建model中对应的变量及scope):

 model_path = "./checkpoints/model.ckpt-75000"
 config = tf.ConfigProto()
 config.gpu_options.allow_growth = True
 with tf.Session(config=config) as sess:
  ## 获取待加载的变量列表
  trainable_vars = tf.trainable_variables()
  g_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope="generator")
  d_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope='discriminator')
  flow_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope='flow_net')
  var_restore = g_vars + d_vars

  ## 仅加载目标变量
  loader = tf.train.Saver(var_restore)
  loader.restore(sess,model_path)

  ## 显示加载的变量值
  graph = tf.get_default_graph()
  for var in var_restore:
   tensor = graph.get_tensor_by_name(var.name)
   print("=======变量名=======",tensor)
   print("-------变量值-------",sess.run(tensor))

以上这篇tensorflow 获取checkpoint中的变量列表实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • Tensorflow: 从checkpoint文件中读取tensor方式

    在使用pre-train model时候,我们需要restore variables from checkpoint files. 经常出现在checkpoint 中找不到"Tensor name not found". 这时候需要查看一下ckpt中到底有哪些变量 import os from tensorflow.python import pywrap_tensorflow checkpoint_path = os.path.join(model_dir, "model.

  • tensorflow实现训练变量checkpoint的保存与读取

    1.保存变量 先创建(在tf.Session()之前)saver saver = tf.train.Saver(tf.global_variables(),max_to_keep=1) #max_to_keep这个保证只保存最后一次training的训练数据 然后在训练的循环里面 checkpoint_path = os.path.join(Path, 'model.ckpt') saver.save(session, checkpoint_path, global_step=step) #这里

  • TensorFlow 输出checkpoint 中的变量名与变量值方式

    废话不多说,直接看代码吧! import os from tensorflow.python import pywrap_tensorflow model_dir="/xxxxxxxxx/model.ckpt" #checkpoint的文件位置 # Read data from checkpoint file reader = pywrap_tensorflow.NewCheckpointReader(model_dir) var_to_shape_map = reader.get_v

  • tensorflow 实现从checkpoint中获取graph信息

    代码: import tensorflow as tf sess = tf.Session() check_point_path = 'variables' saver = tf.train.import_meta_graph('variables/save_variables.ckpt.meta') saver.restore(sess, tf.train.latest_checkpoint(check_point_path)) graph = tf.get_default_graph() #

  • tensorflow 获取checkpoint中的变量列表实例

    方式1:静态获取,通过直接解析checkpoint文件获取变量名及变量值 通过 reader = tf.train.NewCheckpointReader(model_path) 或者通过: from tensorflow.python import pywrap_tensorflow reader = pywrap_tensorflow.NewCheckpointReader(model_path) 代码: model_path = "./checkpoints/model.ckpt-7500

  • python 获取url中的参数列表实例

    Python的urlparse有对url的解析,从而获得url中的参数列表 import urlparse urldata = "http://en.wikipedia.org/w/api.php?action=query&ctitle=FA" result = urlparse.urlparse(urldata) print result print urlparse.parse_qs(result.query) 输出: ParseResult(scheme='http',

  • tensorflow模型的save与restore,及checkpoint中读取变量方式

    创建一个NN import tensorflow as tf import numpy as np #fake data x = np.linspace(-1, 1, 100)[:, np.newaxis] #shape(100,1) noise = np.random.normal(0, 0.1, size=x.shape) y = np.power(x, 2) + noise #shape(100,1) + noise tf_x = tf.placeholder(tf.float32, x.

  • python 通过类中一个方法获取另一个方法变量的实例

    1.在进行接口自动化测试过程中,经常出现接口数据的互相调用,如一些操作需要调用登陆之后返回的session或者token,下面同个简单的方法进行讲解 class A(): def a_add_b(self): a=10 b=20 self.S=a+b print (self.S) return self.S def c_add_ab(self): c=30 s=c+self.S print (s) t=A() t.a_add_b() t.c_add_ab() 运行之后,打印的结果为 30 60

  • tensorflow 打印内存中的变量方法

    法一: 循环打印 模板 for (x, y) in zip(tf.global_variables(), sess.run(tf.global_variables())): print '\n', x, y 实例 # coding=utf-8 import tensorflow as tf def func(in_put, layer_name, is_training=True): with tf.variable_scope(layer_name, reuse=tf.AUTO_REUSE):

  • 打印tensorflow恢复模型中所有变量与操作节点方式

    我就废话不多说了,大家还是直接看代码吧! #参数恢复 self.sess=tf.Session() saver = tf.train.import_meta_graph(os.path.join(model_fullpath,'model.ckpt-7.meta')) module_file = tf.train.latest_checkpoint(model_fullpath) saver.restore(self.sess, module_file) variable_names = [v.

  • 使用Math.max,Math.min获取数组中的最值实例

    Math.min()和Math.max()用法相似. 两个方法用来获取给定的一组数值中的最大值或最小值,但是却不接受数组作为参数. 当然可以写个函数遍历比较之类的等等,此处不描述. 有两个快捷的方法可以接受数组类型参数: 1 . Math.min.apply(null, arr) >>>Math.min.apply(null, [2,1,3]) <<<1 唉?不是不能接收数组类型的参数吗?这是apply方法的特性,apply方法第二个参数为参数的数组,明白了吧,虽然我们

  • 解决vue2 在mounted函数无法获取prop中的变量问题

    如下所示: props: { example: { type: Object, default() { }, }, }, watch: { example: function(newVal,oldVal){ // newVal 为改变后的值 // 继续要处理的事件 }, }, 使用watch 替代 mounted. 通过watch属性来响应数据的变化,当数据改变时执行异步操作. 总结 以上所述是小编给大家介绍的解决vue2 在mounted函数无法获取prop中的变量问题,希望对大家有所帮助,如

  • Python 在字符串中加入变量的实例讲解

    有时候,我们需要在字符串中加入相应的变量,以下提供了几种字符串加入变量的方法: 1.+ 连字符 name = 'zhangsan' print('my name is '+name) #结果为 my name is zhangsan 2.% 字符 name = 'zhangsan' age = 25 price = 4500.225 print('my name is %s'%(name)) print('i am %d'%(age)+' years old') print('my price

随机推荐