详解TensorFlow查看ckpt中变量的几种方法

查看TensorFlow中checkpoint内变量的几种方法

查看ckpt中变量的方法有三种:

  1. 在有model的情况下,使用tf.train.Saver进行restore
  2. 使用tf.train.NewCheckpointReader直接读取ckpt文件,这种方法不需要model。
  3. 使用tools里的freeze_graph来读取ckpt

注意:

  1. 如果模型保存为.ckpt的文件,则使用该文件就可以查看.ckpt文件里的变量。ckpt路径为 model.ckpt
  2. 如果模型保存为.ckpt-xxx-data (图结构)、.ckpt-xxx.index (参数名)、.ckpt-xxx-meta (参数值)文件,则需要同时拥有这三个文件才行。并且ckpt的路径为 model.ckpt-xxx

1. 基于model来读取ckpt文件里的变量

1.首先建立model
2.从ckpt中恢复变量

with tf.Graph().as_default() as g:
  #建立model
  images, labels = cifar10.inputs(eval_data=eval_data)
  logits = cifar10.inference(images)
  top_k_op = tf.nn.in_top_k(logits, labels, 1)
  #从ckpt中恢复变量
  sess = tf.Session()
  saver = tf.train.Saver() #saver = tf.train.Saver(...variables...) # 恢复部分变量时,只需要在Saver里指定要恢复的变量
  save_path = 'ckpt的路径'
  saver.restore(sess, save_path) # 从ckpt中恢复变量

注意:基于model来读取ckpt中变量时,model和ckpt必须匹配。

2. 使用tf.train.NewCheckpointReader直接读取ckpt文件里的变量,使用tools.inspect_checkpoint里的print_tensors_in_checkpoint_file函数打印ckpt里的东西

#使用NewCheckpointReader来读取ckpt里的变量
from tensorflow.python import pywrap_tensorflow
checkpoint_path = os.path.join(model_dir, "model.ckpt")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path) #tf.train.NewCheckpointReader
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
  print("tensor_name: ", key)
  #print(reader.get_tensor(key))
#使用print_tensors_in_checkpoint_file打印ckpt里的内容
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

print_tensors_in_checkpoint_file(file_name, #ckpt文件名字
                 tensor_name, # 如果为None,则默认为ckpt里的所有变量
                 all_tensors, # bool 是否打印所有的tensor,这里打印出的是tensor的值,一般不推荐这里设置为False
                 all_tensor_names) # bool 是否打印所有的tensor的name
#上面的打印ckpt的内部使用的是pywrap_tensorflow.NewCheckpointReader所以,掌握NewCheckpointReader才是王道

3.使用tools里的freeze_graph来读取ckpt

from tensorflow.python.tools import freeze_graph

freeze_graph(input_graph, #=some_graph_def.pb
       input_saver,
       input_binary,
       input_checkpoint, #=model.ckpt
       output_node_names, #=softmax
       restore_op_name,
       filename_tensor_name,
       output_graph, #='./tmp/frozen_graph.pb'
       clear_devices,
       initializer_nodes,
       variable_names_whitelist='',
       variable_names_blacklist='',
       input_meta_graph=None,
       input_saved_model_dir=None,
       saved_model_tags='serve',
       checkpoint_version=2)
#freeze_graph_test.py讲述了怎么使用freeze_grapg。

使用freeze_graph可以将图和ckpt进行合并。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持我们。

(0)

相关推荐

  • 查看TensorFlow checkpoint文件中的变量名和对应值方法

    实例如下所示: from tensorflow.python import pywrap_tensorflow checkpoint_path = os.path.join(model_dir, "model.ckpt") reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path) var_to_shape_map = reader.get_variable_to_shape_map() for key in var_

  • tensorflow获取变量维度信息

    tensorflow版本1.4 获取变量维度是一个使用频繁的操作,在tensorflow中获取变量维度主要用到的操作有以下三种: Tensor.shape Tensor.get_shape() tf.shape(input,name=None,out_type=tf.int32) 对上面三种操作做一下简单分析:(这三种操作先记作A.B.C) A 和 B 基本一样,只不过前者是Tensor的属性变量,后者是Tensor的函数. A 和 B 均返回TensorShape类型,而 C 返回一个1D的o

  • tensorflow 获取变量&打印权值的实例讲解

    在使用tensorflow中,我们常常需要获取某个变量的值,比如:打印某一层的权重,通常我们可以直接利用变量的name属性来获取,但是当我们利用一些第三方的库来构造神经网络的layer时,存在一种情况:就是我们自己无法定义该层的变量,因为是自动进行定义的. 比如用tensorflow的slim库时: <span style="font-size:14px;">def resnet_stack(images, output_shape, hparams, scope=None

  • TensorFlow变量管理详解

    一.TensorFlow变量管理 1. TensorFLow还提供了tf.get_variable函数来创建或者获取变量,tf.variable用于创建变量时,其功能和tf.Variable基本是等价的.tf.get_variable中的初始化方法(initializer)的参数和tf.Variable的初始化过程也类似,initializer函数和tf.Variable的初始化方法是一一对应的,详见下表. tf.get_variable和tf.Variable最大的区别就在于指定变量名称的参数

  • tensorflow创建变量以及根据名称查找变量

    环境:Ubuntu14.04,tensorflow=1.4(bazel源码安装),Anaconda python=3.6 声明变量主要有两种方法:tf.Variable和 tf.get_variable,二者的最大区别是: (1) tf.Variable是一个类,自带很多属性函数:而 tf.get_variable是一个函数; (2) tf.Variable只能生成独一无二的变量,即如果给出的name已经存在,则会自动修改生成新的变量name; (3) tf.get_variable可以用于生成

  • TensorFlow saver指定变量的存取

    今天和大家分享一下用TensorFlow的saver存取训练好的模型那点事. 1. 用saver存取变量: 2. 用saver存取指定变量. 用saver存取变量. 话不多说,先上代码 # coding=utf-8 import os import tensorflow as tf import numpy os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #有些指令集没有装,加这个不显示那些警告 w = tf.Variable([[1,2,3],[2,3,4],

  • TensorFLow用Saver保存和恢复变量

    本文为大家分享了TensorFLow用Saver保存和恢复变量的具体代码,供大家参考,具体内容如下 建立文件tensor_save.py, 保存变量v1,v2的tensor到checkpoint files中,名称分别设置为v3,v4. import tensorflow as tf # Create some variables. v1 = tf.Variable(3, name="v1") v2 = tf.Variable(4, name="v2") # Cre

  • Tensorflow 查看变量的值方法

    定义一个变量,直接输出会输出变量的属性,并不能输出变量值.那么怎么输出变量值呢?请看下面得意 import tensorflow as tf biases=tf.Variable(tf.zeros([2,3]))#定义一个2x3的全0矩阵 sess=tf.InteractiveSession()#使用InteractiveSession函数 biases.initializer.run()#使用初始化器 initializer op 的 run() 方法初始化 'biases' print(se

  • 详解TensorFlow查看ckpt中变量的几种方法

    查看TensorFlow中checkpoint内变量的几种方法 查看ckpt中变量的方法有三种: 在有model的情况下,使用tf.train.Saver进行restore 使用tf.train.NewCheckpointReader直接读取ckpt文件,这种方法不需要model. 使用tools里的freeze_graph来读取ckpt 注意: 如果模型保存为.ckpt的文件,则使用该文件就可以查看.ckpt文件里的变量.ckpt路径为 model.ckpt 如果模型保存为.ckpt-xxx-

  • 详解Spring Data JPA中Repository的接口查询方法

    目录 1.查询方法定义详解 2.搜索查询策略 3.查询创建 4.属性表达式 5.特殊参数处理 6.限制查询结果 7. repository方法返回Collections or Iterables 8.repository方法处理Null 9.查询结果流 10.异步查询结果 1.查询方法定义详解 repository代理有两种方式从方法名中派生出特定存储查询. 通过直接从方法名派生查询. 通过使用一个手动定义的查询. 可用的选项取决于实际的商店.然而,必须有一个策略来决定创建什么实际的查询. 2.

  • 详解Selenium-webdriver绕开反爬虫机制的4种方法

    之前爬美团外卖后台的时候出现的问题,各种方式拖动验证码都无法成功,包括直接控制拉动,模拟人工轨迹的随机拖动都失败了,最后发现只要用chrome driver打开页面,哪怕手动登录也不可以,猜测driver肯定是直接被识别出来了.一开始尝试了改user agent等方式,仍然不行,由于其他项目就搁置了.今天爬淘宝生意参谋又出现这个问题,经百度才知道原来chrome driver的变量有一个特征码,网站可以直接根据特征码判断,经百度发现有4种方法可以解决,记录一下自己做的尝试. 1.mitproxy

  • 详解用Python进行时间序列预测的7种方法

    数据准备 数据集(JetRail高铁的乘客数量)下载. 假设要解决一个时序问题:根据过往两年的数据(2012 年 8 月至 2014 年 8月),需要用这些数据预测接下来 7 个月的乘客数量. import pandas as pd import numpy as np import matplotlib.pyplot as plt df = pd.read_csv('train.csv') df.head() df.shape 依照上面的代码,我们获得了 2012-2014 年两年每个小时的乘

  • 详解git的分支与合并的两种方法

    如何将两个分支合并到一起.就是说我们新建一个分支,在其上开发某个新功能,开发完成后再合并回主线. 1.   git merge 咱们先来看一下第一种方法 -- git merge 在 Git 中合并两个分支时会产生一个特殊的提交记录,它有两个父节点.翻译成自然语言相当于:"我要把这两个父节点本身及它们所有的祖先都包含进来."下面具体解释. # 创建新分支 bugFix git branch bugFix # 切换到该分支 git checkout bugFix # 提交一次 git c

  • 详解IOS判断当前网络状态的三种方法

    在项目中,为了好的用户体验,有些场景必须线判断网络状态,然后才能决定该干嘛.比如视频播放,需要线判断是Wifi还是4G,Wifi直接播放,4G先提示用户.获取网络状态的方法大概有三种: 1. Reachability 这是苹果的官方演示demo中使用到的方法,我们可以到苹果官方文档里下载Demo(点击左上角Download Sample Code 即可下载),然后把Demo里的Reachability.h和.m考到自己项目中,并在Build Phases 的 Link Binary 添加Syst

  • 详解Flutter和Dart取消Future的三种方法

    目录 使用异步包(推荐) 完整示例 使用 timeout() 方法 快速示例 将Future转换为流 快速示例 结论 使用异步包(推荐) async包由 Dart 编程语言的作者开发和发布.它提供了dart:async风格的实用程序来增强异步计算.可以帮助我们取消Future的是CancelableOperation类: var myCancelableFuture = CancelableOperation.fromFuture( Future<T> inner, { FutureOr on

  • 详解Android GLide图片加载常用几种方法

    目录 缓存浅析 GLide图片加载方法 图片加载周期 图片格式(Bitmap,Gif) 缓存 集成网络框架 权限 占位符 淡入效果 变换 启动页/广告页 banner 固定宽高 圆角 圆形 总结 缓存浅析 为啥要做缓存? android默认给每个应用只分配16M的内存,所以如果加载过多的图片,为了 防止内存溢出 ,应该将图片缓存起来. 图片的三级缓存分别是: 1.内存缓存 2.本地缓存 3.网络缓存 其中,内存缓存应优先加载,它速度最快:本地缓存次优先加载,它速度也快:网络缓存不应该优先加载,它

  • 详解pandas获取Dataframe元素值的几种方法

    可以通过遍历的方法: pandas按行按列遍历Dataframe的几种方式:https://www.jb51.net/article/172623.htm 选择列 使用类字典属性,返回的是Series类型 data['w'] 遍历Series for index in data['w'] .index: time_dis = data['w'] .get(index) pandas.DataFrame.at 根据行索引和列名,获取一个元素的值 >>> df = pd.DataFrame(

  • 详解设计模式在Spring中的应用(9种)

    设计模式作为工作学习中的枕边书,却时常处于勤说不用的尴尬境地,也不是我们时常忘记,只是一直没有记忆. 今天,在IT学习者网站就设计模式的内在价值做一番探讨,并以spring为例进行讲解,只有领略了其设计的思想理念,才能在工作学习中运用到"无形". Spring作为业界的经典框架,无论是在架构设计方面,还是在代码编写方面,都堪称行内典范.好了,话不多说,开始今天的内容. spring中常用的设计模式达到九种,我们一一举例: 第一种:简单工厂 又叫做静态工厂方法(StaticFactory

随机推荐