Tensorflow之Saver的用法详解

Saver的用法

1. Saver的背景介绍

我们经常在训练完一个模型之后希望保存训练的结果,这些结果指的是模型的参数,以便下次迭代的训练或者用作测试。Tensorflow针对这一需求提供了Saver类。

Saver类提供了向checkpoints文件保存和从checkpoints文件中恢复变量的相关方法。Checkpoints文件是一个二进制文件,它把变量名映射到对应的tensor值 。

只要提供一个计数器,当计数器触发时,Saver类可以自动的生成checkpoint文件。这让我们可以在训练过程中保存多个中间结果。例如,我们可以保存每一步训练的结果。

为了避免填满整个磁盘,Saver可以自动的管理Checkpoints文件。例如,我们可以指定保存最近的N个Checkpoints文件。

2. Saver的实例

下面以一个例子来讲述如何使用Saver类

import tensorflow as tf
import numpy as np
x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4
w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b
loss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)
isTrain = False
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = ''
saver = tf.train.Saver() # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))
with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  if isTrain:
    for i in xrange(train_steps):
      sess.run(train, feed_dict={x: x_data})
      if (i + 1) % checkpoint_steps == 0:
        saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1)
  else:
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
      saver.restore(sess, ckpt.model_checkpoint_path)
    else:
      pass
    print(sess.run(w))
    print(sess.run(b)) 
  1. isTrain:用来区分训练阶段和测试阶段,True表示训练,False表示测试
  2. train_steps:表示训练的次数,例子中使用100
  3. checkpoint_steps:表示训练多少次保存一下checkpoints,例子中使用50
  4. checkpoint_dir:表示checkpoints文件的保存路径,例子中使用当前路径

2.1 训练阶段

使用Saver.save()方法保存模型:

  1. sess:表示当前会话,当前会话记录了当前的变量值
  2. checkpoint_dir + 'model.ckpt':表示存储的文件名
  3. global_step:表示当前是第几步

训练完成后,当前目录底下会多出5个文件。

打开名为“checkpoint”的文件,可以看到保存记录,和最新的模型存储位置。

2.1测试阶段

测试阶段使用saver.restore()方法恢复变量:

sess:表示当前会话,之前保存的结果将被加载入这个会话

ckpt.model_checkpoint_path:表示模型存储的位置,不需要提供模型的名字,它会去查看checkpoint文件,看看最新的是谁,叫做什么。

运行结果如下图所示,加载了之前训练的参数w和b的结果

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持我们。

您可能感兴趣的文章:

  • TensorFlow saver指定变量的存取
  • TensorFLow用Saver保存和恢复变量
(0)

相关推荐

  • TensorFLow用Saver保存和恢复变量

    本文为大家分享了TensorFLow用Saver保存和恢复变量的具体代码,供大家参考,具体内容如下 建立文件tensor_save.py, 保存变量v1,v2的tensor到checkpoint files中,名称分别设置为v3,v4. import tensorflow as tf # Create some variables. v1 = tf.Variable(3, name="v1") v2 = tf.Variable(4, name="v2") # Cre

  • TensorFlow saver指定变量的存取

    今天和大家分享一下用TensorFlow的saver存取训练好的模型那点事. 1. 用saver存取变量: 2. 用saver存取指定变量. 用saver存取变量. 话不多说,先上代码 # coding=utf-8 import os import tensorflow as tf import numpy os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #有些指令集没有装,加这个不显示那些警告 w = tf.Variable([[1,2,3],[2,3,4],

  • Tensorflow之Saver的用法详解

    Saver的用法 1. Saver的背景介绍 我们经常在训练完一个模型之后希望保存训练的结果,这些结果指的是模型的参数,以便下次迭代的训练或者用作测试.Tensorflow针对这一需求提供了Saver类. Saver类提供了向checkpoints文件保存和从checkpoints文件中恢复变量的相关方法.Checkpoints文件是一个二进制文件,它把变量名映射到对应的tensor值 . 只要提供一个计数器,当计数器触发时,Saver类可以自动的生成checkpoint文件.这让我们可以在训练

  • 对TensorFlow的assign赋值用法详解

    TensorFlow修改变量值后,需要重新赋值,assign用起来有点小技巧,就是需要需要弄个操作子,运行一下. 下面这么用是不行的 import tensorflow as tf import numpy as np x = tf.Variable(0) init = tf.initialize_all_variables() sess = tf.InteractiveSession() sess.run(init) print(x.eval()) x.assign(1) print(x.ev

  • Tensorflow中tf.ConfigProto()的用法详解

    参考Tensorflow Machine Leanrning Cookbook tf.ConfigProto()主要的作用是配置tf.Session的运算方式,比如gpu运算或者cpu运算 具体代码如下: import tensorflow as tf session_config = tf.ConfigProto( log_device_placement=True, inter_op_parallelism_threads=0, intra_op_parallelism_threads=0,

  • 基于Tensorflow一维卷积用法详解

    我就废话不多说了,大家还是直接看代码吧! import tensorflow as tf import numpy as np input = tf.constant(1,shape=(64,10,1),dtype=tf.float32,name='input')#shape=(batch,in_width,in_channels) w = tf.constant(3,shape=(3,1,32),dtype=tf.float32,name='w')#shape=(filter_width,in

  • 对TensorFlow中的variables_to_restore函数详解

    variables_to_restore函数,是TensorFlow为滑动平均值提供.之前,也介绍过通过使用滑动平均值可以让神经网络模型更加的健壮.我们也知道,其实在TensorFlow中,变量的滑动平均值都是由影子变量所维护的,如果你想要获取变量的滑动平均值需要获取的是影子变量而不是变量本身. 1.滑动平均值模型文件的保存 import tensorflow as tf if __name__ == "__main__": v = tf.Variable(0.,name="

  • Pytorch中膨胀卷积的用法详解

    卷积和膨胀卷积 在深度学习中,我们会碰到卷积的概念,我们知道卷积简单来理解就是累乘和累加,普通的卷积我们在此不做赘述,大家可以翻看相关书籍很好的理解. 最近在做项目过程中,碰到Pytorch中使用膨胀卷积的情况,想要的输入输出是图像经过四层膨胀卷积后图像的宽高尺寸不发生变化. 开始我的思路是padding='SAME'结合strides=1来实现输入输出尺寸不变,试列好多次还是有问题,报了张量错误的提示,想了好久也没找到解决方法,上网搜了下,有些人的博客说经过膨胀卷积之后图像的尺寸不发生变化,有

  • pytorch1.0中torch.nn.Conv2d用法详解

    Conv2d的简单使用 torch 包 nn 中 Conv2d 的用法与 tensorflow 中类似,但不完全一样. 在 torch 中,Conv2d 有几个基本的参数,分别是 in_channels 输入图像的深度 out_channels 输出图像的深度 kernel_size 卷积核大小,正方形卷积只为单个数字 stride 卷积步长,默认为1 padding 卷积是否造成尺寸丢失,1为不丢失 与tensorflow不一样的是,pytorch中的使用更加清晰化,我们可以使用这种方法定义输

  • Oracle中游标Cursor基本用法详解

    查询 SELECT语句用于从数据库中查询数据,当在PL/SQL中使用SELECT语句时,要与INTO子句一起使用,查询的 返回值被赋予INTO子句中的变量,变量的声明是在DELCARE中.SELECT INTO语法如下: SELECT [DISTICT|ALL]{*|column[,column,...]} INTO (variable[,variable,...] |record) FROM {table|(sub-query)}[alias] WHERE............ PL/SQL

  • JSP 中request与response的用法详解

    JSP 中request与response的用法详解 概要: 在学习这两个对象之前,我们应该已经有了http协议的基本了解了,如果不清楚http协议的可以看我的关于http协议的介绍.因为其实request和response的使用大部分都是对http协议的操作. request对象的介绍 我们先从request对象进行介绍: 我们知道http协议定义了请求服务器的格式: 请求行 请求头 空格 请求体(get请求没有请求体) 好了,这里我们就不详细介绍了,我们只看几个应用就可以了,没什么难度: 应

  • 基于C++中setiosflags()的用法详解

    cout<<setiosflags(ios::fixed)<<setiosflags(ios::right)<<setprecision(2); setiosflags 是包含在命名空间iomanip 中的C++ 操作符,该操作符的作用是执行由有参数指定区域内的动作:   iso::fixed 是操作符setiosflags 的参数之一,该参数指定的动作是以带小数点的形式表示浮点数,并且在允许的精度范围内尽可能的把数字移向小数点右侧:   iso::right 也是se

随机推荐