Tensorflow 如何从checkpoint文件中加载变量名和变量值

假设你已经经过上千次的迭代,并且得到了以下模型:

则从这些checkpoint文件中加载变量名和变量值代码如下:

model_dir = './ckpt-182802'
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
reader = pywrap_tensorflow.NewCheckpointReader(model_dir)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
     print("tensor_name: ", key)
     print(reader.get_tensor(key)) # Remove this is you want to print only variable names

Mnist

下面将给出一个基于卷积神经网络的手写数字识别样例:

# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.framework import graph_util
log_dir = './tensorboard'
mnist = input_data.read_data_sets(train_dir="./mnist_data",one_hot=True)
if tf.gfile.Exists(log_dir):
        tf.gfile.DeleteRecursively(log_dir)
tf.gfile.MakeDirs(log_dir)

#定义输入数据mnist图片大小28*28*1=784,None表示batch_size
x = tf.placeholder(dtype=tf.float32,shape=[None,28*28],name="input")
#定义标签数据,mnist共10类
y_ = tf.placeholder(dtype=tf.float32,shape=[None,10],name="y_")
#将数据调整为二维数据,w*H*c---> 28*28*1,-1表示N张
image = tf.reshape(x,shape=[-1,28,28,1])

#第一层,卷积核={5*5*1*32},池化核={2*2*1,1*2*2*1}
w1 = tf.Variable(initial_value=tf.random_normal(shape=[5,5,1,32],stddev=0.1,dtype=tf.float32,name="w1"))
b1= tf.Variable(initial_value=tf.zeros(shape=[32]))
conv1 = tf.nn.conv2d(input=image,filter=w1,strides=[1,1,1,1],padding="SAME",name="conv1")
relu1 = tf.nn.relu(tf.nn.bias_add(conv1,b1),name="relu1")
pool1 = tf.nn.max_pool(value=relu1,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME")
#shape={None,14,14,32}
#第二层,卷积核={5*5*32*64},池化核={2*2*1,1*2*2*1}
w2 = tf.Variable(initial_value=tf.random_normal(shape=[5,5,32,64],stddev=0.1,dtype=tf.float32,name="w2"))
b2 = tf.Variable(initial_value=tf.zeros(shape=[64]))
conv2 = tf.nn.conv2d(input=pool1,filter=w2,strides=[1,1,1,1],padding="SAME")
relu2 = tf.nn.relu(tf.nn.bias_add(conv2,b2),name="relu2")
pool2 = tf.nn.max_pool(value=relu2,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME",name="pool2")
#shape={None,7,7,64}
#FC1
w3 = tf.Variable(initial_value=tf.random_normal(shape=[7*7*64,1024],stddev=0.1,dtype=tf.float32,name="w3"))
b3 = tf.Variable(initial_value=tf.zeros(shape=[1024]))
#关键,进行reshape
input3 = tf.reshape(pool2,shape=[-1,7*7*64],name="input3")
fc1 = tf.nn.relu(tf.nn.bias_add(value=tf.matmul(input3,w3),bias=b3),name="fc1")
#shape={None,1024}
#FC2
w4 = tf.Variable(initial_value=tf.random_normal(shape=[1024,10],stddev=0.1,dtype=tf.float32,name="w4"))
b4 = tf.Variable(initial_value=tf.zeros(shape=[10]))
fc2 = tf.nn.bias_add(value=tf.matmul(fc1,w4),bias=b4,name="logit")
#shape={None,10}
#定义交叉熵损失
# 使用softmax将NN计算输出值表示为概率
y = tf.nn.softmax(fc2,name="out")

# 定义交叉熵损失函数
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=fc2,labels=y_)
loss = tf.reduce_mean(cross_entropy)
tf.summary.scalar('Cross_Entropy',loss)
#定义solver
train = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss=loss)
for var in tf.trainable_variables():
	print var
#train = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss=loss)

#定义正确值,判断二者下标index是否相等
correct_predict = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
#定义如何计算准确率
accuracy = tf.reduce_mean(tf.cast(correct_predict,dtype=tf.float32),name="accuracy")
tf.summary.scalar('Training_ACC',accuracy)
#定义初始化op
merged = tf.summary.merge_all()
init = tf.global_variables_initializer()
saver = tf.train.Saver()
#训练NN
with tf.Session() as session:
    session.run(fetches=init)
    writer = tf.summary.FileWriter(log_dir,session.graph) #定义记录日志的位置
    for i in range(0,500):
        xs, ys = mnist.train.next_batch(100)
        session.run(fetches=train,feed_dict={x:xs,y_:ys})
        if i%10 == 0:
            train_accuracy,summary = session.run(fetches=[accuracy,merged],feed_dict={x:xs,y_:ys})
            writer.add_summary(summary,i)
            print(i,"accuracy=",train_accuracy)
    '''
    #训练完成后,将网络中的权值转化为常量,形成常量graph,注意:需要x与label
    constant_graph = graph_util.convert_variables_to_constants(sess=session,
                                                            input_graph_def=session.graph_def,
                                                            output_node_names=['out','y_','input'])
    #将带权值的graph序列化,写成pb文件存储起来
    with tf.gfile.FastGFile("lenet.pb", mode='wb') as f:
        f.write(constant_graph.SerializeToString())
    '''
    saver.save(session,'./ckpt')

补充:查看tensorflow产生的checkpoint文件内容的方法

tensorflow在保存权重模型时多使用tf.train.Saver().save 函数进行权重保存,保存的ckpt文件无法直接打开,但tensorflow提供了相关函数 tf.train.NewCheckpointReader 可以对ckpt文件进行权重查看。

import os
from tensorflow.python import pywrap_tensorflow

checkpoint_path = os.path.join('modelckpt', "fc_nn_model")
# Read data from checkpoint file
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
# Print tensor name and values
for key in var_to_shape_map:
    print("tensor_name: ", key)
    print(reader.get_tensor(key))

其中‘modelckpt'是存放.ckpt文件的文件夹,"fc_nn_model"是文件名,如下图所示。

var_to_shape_map是一个字典,其中的键值是变量名,对应的值是该变量的形状,如{‘LSTM_input/bias_LSTM/Adam_1': [128]}。

想要查看某变量值时,需要调用get_tensor函数,即输入以下代码:

reader.get_tensor('LSTM_input/bias_LSTM/Adam_1')

以上为个人经验,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • TensorFlow实现模型断点训练,checkpoint模型载入方式

    深度学习中,模型训练一般都需要很长的时间,由于很多原因,导致模型中断训练,下面介绍继续断点训练的方法. 方法一:载入模型时,不必指定迭代次数,一般默认最新 # 保存模型 saver = tf.train.Saver(max_to_keep=1) # 最多保留最新的模型 # 开启会话 with tf.Session() as sess: # saver.restore(sess, './log/' + "model_savemodel.cpkt-" + str(20000)) sess.

  • tensorflow 实现从checkpoint中获取graph信息

    代码: import tensorflow as tf sess = tf.Session() check_point_path = 'variables' saver = tf.train.import_meta_graph('variables/save_variables.ckpt.meta') saver.restore(sess, tf.train.latest_checkpoint(check_point_path)) graph = tf.get_default_graph() #

  • TensorFlow实现checkpoint文件转换为pb文件

    由于项目需要,需要将TensorFlow保存的模型从ckpt文件转换为pb文件. import os from tensorflow.python import pywrap_tensorflow from net2use import inception_resnet_v2_small#这里使用自己定义的模型函数即可 import tensorflow as tf if __name__=='__main__': pb_file = "./model/output.pb" ckpt_

  • tensorflow 获取checkpoint中的变量列表实例

    方式1:静态获取,通过直接解析checkpoint文件获取变量名及变量值 通过 reader = tf.train.NewCheckpointReader(model_path) 或者通过: from tensorflow.python import pywrap_tensorflow reader = pywrap_tensorflow.NewCheckpointReader(model_path) 代码: model_path = "./checkpoints/model.ckpt-7500

  • tensorflow模型的save与restore,及checkpoint中读取变量方式

    创建一个NN import tensorflow as tf import numpy as np #fake data x = np.linspace(-1, 1, 100)[:, np.newaxis] #shape(100,1) noise = np.random.normal(0, 0.1, size=x.shape) y = np.power(x, 2) + noise #shape(100,1) + noise tf_x = tf.placeholder(tf.float32, x.

  • TensorFlow 输出checkpoint 中的变量名与变量值方式

    废话不多说,直接看代码吧! import os from tensorflow.python import pywrap_tensorflow model_dir="/xxxxxxxxx/model.ckpt" #checkpoint的文件位置 # Read data from checkpoint file reader = pywrap_tensorflow.NewCheckpointReader(model_dir) var_to_shape_map = reader.get_v

  • Tensorflow 如何从checkpoint文件中加载变量名和变量值

    假设你已经经过上千次的迭代,并且得到了以下模型: 则从这些checkpoint文件中加载变量名和变量值代码如下: model_dir = './ckpt-182802' import tensorflow as tf from tensorflow.python import pywrap_tensorflow reader = pywrap_tensorflow.NewCheckpointReader(model_dir) var_to_shape_map = reader.get_varia

  • 利用反射获取Java类中的静态变量名及变量值的简单实例

    JAVA可以通过反射获取成员变量和静态变量的名称,局部变量就不太可能拿到了. public class Test { public static void main(String[] args) throws Exception { // TODO Auto-generated method stub //获取所有变量的值 Class clazz = Class.forName("com.qianmingxs.ScoreTable"); Field[] fields = clazz.g

  • JS文件中加载jquery.js的实例代码

    本文表述了JS文件中加载jquery.js的方法,具有很好的参考价值,希望对大家有所帮助. 最近有一个需求: 1.在一个html中只能引入一个JS文件 不能有JS代码和其他JS文件的引入: 2.这个JS文件中 还要引入其他的JS文件: 3.所有JS功能都写在这个JS文件中 这些代码用到了jquery相关的东东 所以这里第一个需要解决的就是怎么引入jquery.js. 在网上搜索了很多方法都不太实用,由于我自己离开WEB多年 最后向朋友询问得到以下代码: 1.js // by firefoxmmx

  • Python实现从文件中加载数据的方法详解

    前几篇都是手动录入或随机函数产生的数据.实际有许多类型的文件,以及许多方法,用它们从文件中提取数据来图形化. 比如之前python基础(12)介绍打开文件的方式,可直接读取文件中的数据,扩大了我们的数据来源.下面,将展示几种方法. 我们将使用内置的 csv 模块加载CSV文件 CSV文件是一种特殊的文本文件,文件中的数据以逗号作为分隔符,很适合进行数据的解析.先用excle建立如下表格和数据,另存为csv格式文件,放到代码目录下. 包含在Python标准库中自带CSV 模块,我们只需要impor

  • tensorflow实现从.ckpt文件中读取任意变量

    思路有些混乱,希望大家能理解我的意思. 看了faster rcnn的tensorflow代码,关于fix_variables的作用我不是很明白,所以写了以下代码,读取了预训练模型vgg16得fc6和fc7的参数,以及faster rcnn中heat_to_tail中的fc6和fc7,将它们做了对比,发现结果不一样,说明vgg16的fc6和fc7只是初始化了faster rcnn中heat_to_tail中的fc6和fc7,之后后者被训练. 具体读取任意变量的代码如下: import tensor

  • tensorflow实现打印ckpt模型保存下的变量名称及变量值

    有时候会需要通过从保存下来的ckpt文件来观察其保存下来的训练完成的变量值. ckpt文件名列表:(一般是三个文件) xxxxx.ckpt.data-00000-of-00001 xxxxx.ckpt.index xxxxx.ckpt.meta import os from tensorflow.python import pywrap_tensorflow checkpoint_path = os.path.join("文件夹路径", "xxxxx.ckpt")

  • 解析Java和Eclipse中加载本地库(.dll文件)的详细说明

    最近在做的工作要用到本地方法,需要在Java中加载不少动态链接库(以下为方便延用Windows平台下的简写dll,但并不局限于Windows).刚刚把程序跑通,赶紧把一些心得写出来,mark.也希望对大家的类似工作有所帮助首先,应当明确,dll有两类:(1)Java所依赖的dll和,(2)dll所依赖的dll.正是由于第(2)种dll的存在,才导致了java中加载dll的复杂性大大增加,许多说法都是这样的,但我实验的结果却表明似乎没有那么复杂,后面会予以详细阐述.其次,Java中加载dll的方式

  • JavaWeb项目中dll文件动态加载方法解析(详细步骤)

    相信很多做Java的朋友都有过用Java调用JNI实现调用C或C++方法的经历,那么Java Web中又如何实现DLL/SO文件的动态加载方法呢.今天就给大家带来一篇JAVA Web项目中DLL/SO文件动态加载方法的文章. 在Java Web项目中,我们经常会用到通过JNI调用dll动态库文件来实现一些JAVA不能实现的功能,或者是一些第三方dll插件.通常的做法是将这些dll文件复制到 %JAVA_HOME%\jre\bin\ 文件夹或者 应用中间件(Tomcat|Weblogic)的bin

  • PHP中的use关键字及文件的加载详解

    前言 可能在大家经常使用框架,写一个Controller或者Model的时候,写了好多use,但是并没有写文件加载的代码,就以为use可以进行文件的自动加载了. 详细介绍 其实,现在流行的php框架,都是基于MVC模式的,大量的使用了命名空间,以提高程序的灵活性.那么框架是怎么实现将use关键字所声明的类库对应的脚本文件进行加载的那? (1):在通过use关键字进行声明类库的声明的时候,并不会进行脚本的加载,而是在脚本文件真正使用到所对应的类库的时候才会进行加载(这就是所谓延迟加载). (2):

随机推荐