tensorflow模型的save与restore,及checkpoint中读取变量方式

创建一个NN

import tensorflow as tf
import numpy as np

#fake data
x = np.linspace(-1, 1, 100)[:, np.newaxis] #shape(100,1)
noise = np.random.normal(0, 0.1, size=x.shape)
y = np.power(x, 2) + noise  #shape(100,1) + noise
tf_x = tf.placeholder(tf.float32, x.shape) #input x
tf_y = tf.placeholder(tf.float32, y.shape) #output y
l = tf.layers.dense(tf_x, 10, tf.nn.relu) #hidden layer
o = tf.layers.dense(l, 1)     #output layer
loss = tf.losses.mean_squared_error(tf_y, o ) #compute loss
train_op = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(loss)

1.使用save对模型进行保存

sess= tf.Session()
sess.run(tf.global_variables_initializer())  #initialize var in graph
saver = tf.train.Saver() # define a saver for saving and restoring
for step in range(100):   #train
 sess.run(train_op,{tf_x:x, tf_y:y})
saver.save(sess, 'params/params.ckpt', write_meta_graph=False) # mate_graph is not recommend

生成三个文件,分别是checkpoint,.ckpt.data-00000-of-00001,.ckpt.index

2.使用restore对提取模型

在提取模型时,需要将模型结构再定义一遍,再将各参数加载出来

#bulid entire net again and restore
tf_x = tf.placeholder(tf.float32, x.shape)
tf_y = tf.placeholder(tf.float32, y.shape)
l_ = tf.layers.dense(tf_x, 10, tf.nn.relu)
o_ = tf.layers.dense(l_, 1)
loss_ = tf.losses.mean_squared_error(tf_y, o_)

sess = tf.Session()
# don't need to initialize variables, just restoring trained variables
saver = tf.train.Saver() # define a saver for saving and restoring
saver.restore(sess, './params/params.ckpt')

3.有时会报错Not found:b1 not found in checkpoint

这时我们想知道我在文件中到底保存了什么内容,即需要读取出checkpoint中的tensor

import os
from tensorflow.python import pywrap_tensorflow
checkpoint_path = os.path.join('params','params.ckpt')
# Read data from checkpoint file
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
# Print tensor name and value
f = open('params.txt','w')
for key in var_to_shape_map: # write tensors' names and values in file
 print(key,file=f)
 print(reader.get_tensor(key),file=f)
f.close()

运行后生成一个params.txt文件,在其中可以看到模型的参数。

补充知识:TensorFlow按时间保存检查点

一 实例

介绍一种更简便地保存检查点功能的方法——tf.train.MonitoredTrainingSession函数,该函数可以直接实现保存及载入检查点模型的文件。

演示使用MonitoredTrainingSession函数来自动管理检查点文件。

二 代码

import tensorflow as tf
tf.reset_default_graph()
global_step = tf.train.get_or_create_global_step()
step = tf.assign_add(global_step, 1)
with tf.train.MonitoredTrainingSession(checkpoint_dir='log/checkpoints',save_checkpoint_secs = 2) as sess:
 print(sess.run([global_step]))
 while not sess.should_stop():
  i = sess.run( step)
  print( i)

三 运行结果

1 第一次运行后,会发现log文件夹下产生如下文件

2 第二次运行后,结果如下:

INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Restoring parameters from log/checkpoints\model.ckpt-15147
INFO:tensorflow:Saving checkpoints for 15147 into log/checkpoints\model.ckpt.
[15147]
15148
15149
15150
15151
15152
15153
15154
15155
15156
15157
15158
15159

四 说明

本例是按照训练时间来保存的。通过指定save_checkpoint_secs参数的具体秒数,来设置每训练多久保存一次检查点。

可见程序自动载入检查点是从第15147次开始运行的。

五 注意

1 如果不设置save_checkpoint_secs参数,默认的保存时间是10分钟,这种按照时间保存的模式更适合用于使用大型数据集来训练复杂模型的情况。

2 使用该方法,必须要定义global_step变量,否则会报错误。

以上这篇tensorflow模型的save与restore,及checkpoint中读取变量方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • 终端命令查看TensorFlow版本号及路径的方法

    如图,简单易懂,先激活tensorflow,然后进入python,输入python语句执行查询: 需要注意的是一定要在激活tensorflow环境后再输入python命令,否则会识别不到tensorflow,可以看到在使用python前后命令前面都是有"(tensorflow)"的. 以上这篇终端命令查看TensorFlow版本号及路径的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们.

  • 浅谈Tensorflow由于版本问题出现的几种错误及解决方法

    1.AttributeError: 'module' object has no attribute 'rnn_cell' S:将tf.nn.rnn_cell替换为tf.contrib.rnn 2.TypeError: Expected int32, got list containing Tensors of type '_Message' instead. S:由于tf.concat的问题,将tf.concat(1, [conv1, conv2]) 的格式替换为tf.concat( [con

  • tensorflow2.0保存和恢复模型3种方法

    方法1:只保存模型的权重和偏置 这种方法不会保存整个网络的结构,只是保存模型的权重和偏置,所以在后期恢复模型之前,必须手动创建和之前模型一模一样的模型,以保证权重和偏置的维度和保存之前的相同. tf.keras.model类中的save_weights方法和load_weights方法,参数解释我就直接搬运官网的内容了. save_weights( filepath, overwrite=True, save_format=None ) Arguments: filepath: String,

  • TensorFlow实现打印每一层的输出

    在test.py中可以通过如下代码直接生成带weight的pb文件,也可以通过tf官方的freeze_graph.py将ckpt转为pb文件. constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def,['net_loss/inference/encode/conv_output/conv_output']) with tf.gfile.FastGFile('net_model.pb', mod

  • tensorflow模型的save与restore,及checkpoint中读取变量方式

    创建一个NN import tensorflow as tf import numpy as np #fake data x = np.linspace(-1, 1, 100)[:, np.newaxis] #shape(100,1) noise = np.random.normal(0, 0.1, size=x.shape) y = np.power(x, 2) + noise #shape(100,1) + noise tf_x = tf.placeholder(tf.float32, x.

  • 将tensorflow模型打包成PB文件及PB文件读取方式

    1. tensorflow模型文件打包成PB文件 import tensorflow as tf from tensorflow.python.tools import freeze_graph with tf.Graph().as_default(): with tf.device("/cpu:0"): config = tf.ConfigProto(allow_soft_placement=True) with tf.Session(config=config).as_defaul

  • tensorflow 获取checkpoint中的变量列表实例

    方式1:静态获取,通过直接解析checkpoint文件获取变量名及变量值 通过 reader = tf.train.NewCheckpointReader(model_path) 或者通过: from tensorflow.python import pywrap_tensorflow reader = pywrap_tensorflow.NewCheckpointReader(model_path) 代码: model_path = "./checkpoints/model.ckpt-7500

  • Tensorflow: 从checkpoint文件中读取tensor方式

    在使用pre-train model时候,我们需要restore variables from checkpoint files. 经常出现在checkpoint 中找不到"Tensor name not found". 这时候需要查看一下ckpt中到底有哪些变量 import os from tensorflow.python import pywrap_tensorflow checkpoint_path = os.path.join(model_dir, "model.

  • TensorFlow 输出checkpoint 中的变量名与变量值方式

    废话不多说,直接看代码吧! import os from tensorflow.python import pywrap_tensorflow model_dir="/xxxxxxxxx/model.ckpt" #checkpoint的文件位置 # Read data from checkpoint file reader = pywrap_tensorflow.NewCheckpointReader(model_dir) var_to_shape_map = reader.get_v

  • TensorFlow模型保存/载入的两种方法

    TensorFlow 模型保存/载入 我们在上线使用一个算法模型的时候,首先必须将已经训练好的模型保存下来.tensorflow保存模型的方式与sklearn不太一样,sklearn很直接,一个sklearn.externals.joblib的dump与load方法就可以保存与载入使用.而tensorflow由于有graph, operation 这些概念,保存与载入模型稍显麻烦. 一.基本方法 网上搜索tensorflow模型保存,搜到的大多是基本的方法.即 保存 定义变量 使用saver.s

  • TensorFlow模型保存和提取的方法

    一.TensorFlow模型保存和提取方法 1. TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提取.tf.train.Saver对象saver的save方法将TensorFlow模型保存到指定路径中,saver.save(sess,"Model/model.ckpt") ,实际在这个文件目录下会生成4个人文件: checkpoint文件保存了一个录下多有的模型文件列表,model.ckpt.meta保存了TensorFlow计算图的结构信息,model

  • 浅谈Tensorflow模型的保存与恢复加载

    近期做了一些反垃圾的工作,除了使用常用的规则匹配过滤等手段,也采用了一些机器学习方法进行分类预测.我们使用TensorFlow进行模型的训练,训练好的模型需要保存,预测阶段我们需要将模型进行加载还原使用,这就涉及TensorFlow模型的保存与恢复加载. 总结一下Tensorflow常用的模型保存方式. 保存checkpoint模型文件(.ckpt) 首先,TensorFlow提供了一个非常方便的api,tf.train.Saver()来保存和还原一个机器学习模型. 模型保存 使用tf.trai

  • 解决tensorflow模型参数保存和加载的问题

    终于找到bug原因!记一下:还是不熟悉平台的原因造成的! Q:为什么会出现两个模型对象在同一个文件中一起运行,当直接读取他们分开运行时训练出来的模型会出错,而且总是有一个正确,一个读取错误? 而 直接在同一个文件又训练又重新加载模型预测不出错,而且更诡异的是此时用分文件里的对象加载模型不会出错? model.py,里面含有 ModelV 和 ModelP,另外还有 modelP.py 和 modelV.py 分别只含有 ModelP 和 ModeV 这两个对象,先使用 modelP.py 和 m

  • TensorFlow 模型载入方法汇总(小结)

    一.TensorFlow常规模型加载方法 保存模型 tf.train.Saver()类,.save(sess, ckpt文件目录)方法 参数名称 功能说明 默认值 var_list Saver中存储变量集合 全局变量集合 reshape 加载时是否恢复变量形状 True sharded 是否将变量轮循放在所有设备上 True max_to_keep 保留最近检查点个数 5 restore_sequentially 是否按顺序恢复变量,模型较大时顺序恢复内存消耗小 True var_list是字典

随机推荐