tensorflow 加载部分变量的实例讲解

tensorflow模型保存为saver = tf.train.Saver()函数,saver.save()保存模型,代码如下:

import tensorflow as tf

v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")
saver = tf.train.Saver()
with tf.Session() as sess:
 init_op = tf.global_variables_initializer()
 sess.run(init_op)
 saver.save(sess,"checkpoint/model_test",global_step=1)

当我们保存模型后,我们可以通过saver.restore()来加载模型,初始化变量:

import tensorflow as tf

v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")
saver = tf.train.Saver()
with tf.Session() as sess:
 # init_op = tf.global_variables_initializer()
 # sess.run(init_op)
 saver.restore(sess, "checkpoint/model_test-1")
 # saver.save(sess,"checkpoint/model_test",global_step=1)

神经网络训练时,有时候我们需要从预训练的模型中加载部分参数,初始化当前模型,例如加入CNN有6层,我们需要从已有的模型初始化CNN前5层参数.这可以通过saver.restore()实现.

之前我们已经介绍可以通过tf.train.Saver()的保存部分变量的方法,即需要保存的变量列表,同样的,在变量初始化的时候,我们可以对需要单独初始化的变量分别定义一个tf.train.Saver()函数,这样就可以单独对该部分变量初始化,例如下面代码,saver1用于初始化变量v1,saver2用于初始化变量v2,v3:

import tensorflow as tf

v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")
v3= tf.Variable(tf.zeros([100]), name="v3")
#saver = tf.train.Saver()
saver1 = tf.train.Saver([v1])
saver2 = tf.train.Saver([v2]+[v3])
with tf.Session() as sess:
 # init_op = tf.global_variables_initializer()
 # sess.run(init_op)
 saver1.restore(sess, "checkpoint/model_test-1")
 saver2.restore(sess, "checkpoint/model_test-1")
 # saver.save(sess,"checkpoint/model_test",global_step=1)

以上这篇tensorflow 加载部分变量的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • TensorFlow saver指定变量的存取

    今天和大家分享一下用TensorFlow的saver存取训练好的模型那点事. 1. 用saver存取变量: 2. 用saver存取指定变量. 用saver存取变量. 话不多说,先上代码 # coding=utf-8 import os import tensorflow as tf import numpy os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #有些指令集没有装,加这个不显示那些警告 w = tf.Variable([[1,2,3],[2,3,4],

  • 基于tensorflow加载部分层的方法

    一般使用 saver.restore(sess, modeldir + "model.ckpt") 即可加载已经训练好的网络,可是有时候想值使用部分层的参数,这时候可以选择在加载网络之后重新初始化剩下的层 var_list = [weights['wd1'], weights['out'], biases['bd1'], biases['out'], global_step] initfc = tf.variables_initializer(var_list, name='init'

  • TensorFlow变量管理详解

    一.TensorFlow变量管理 1. TensorFLow还提供了tf.get_variable函数来创建或者获取变量,tf.variable用于创建变量时,其功能和tf.Variable基本是等价的.tf.get_variable中的初始化方法(initializer)的参数和tf.Variable的初始化过程也类似,initializer函数和tf.Variable的初始化方法是一一对应的,详见下表. tf.get_variable和tf.Variable最大的区别就在于指定变量名称的参数

  • 详解TensorFlow查看ckpt中变量的几种方法

    查看TensorFlow中checkpoint内变量的几种方法 查看ckpt中变量的方法有三种: 在有model的情况下,使用tf.train.Saver进行restore 使用tf.train.NewCheckpointReader直接读取ckpt文件,这种方法不需要model. 使用tools里的freeze_graph来读取ckpt 注意: 如果模型保存为.ckpt的文件,则使用该文件就可以查看.ckpt文件里的变量.ckpt路径为 model.ckpt 如果模型保存为.ckpt-xxx-

  • tensorflow 加载部分变量的实例讲解

    tensorflow模型保存为saver = tf.train.Saver()函数,saver.save()保存模型,代码如下: import tensorflow as tf v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1") v2= tf.Variable(tf.zeros([200]), name="v2") saver = tf.train.Saver() with tf

  • JQuery Ajax动态加载Table数据的实例讲解

    我们在jsp定义一个select和一个table,要求实现根据select的选值,动态加载table数据. <select id="type" name="type" onchange="reloadTable(this)"></select> <table id="import-table" class="table table-striped table-bordered table

  • Android:下拉刷新+加载更多+滑动删除实例讲解

    小伙伴们在逛淘宝或者是各种app上,都可以看到这样的功能,下拉刷新和加载更多以及滑动删除,刷新,指刷洗之后使之变新,比喻突破旧的而创造出新的,比如在手机上浏览新闻的时候,使用下拉刷新的功能,我们可以第一时间掌握最新消息,加载更多是什么nie,简单来说就是在网页上逛淘宝的时候,我们可以点击下一页来满足我们更多的需求,但是在手机端就不一样了,没有上下页,怎么办nie,方法总比困难多,细心的小伙伴可能会发现,在手机端中,有加载更多来满足我们的要求,其实加载更多也是分页的一种体现.小伙伴在使用手机版QQ

  • jquery zTree异步加载、模糊搜索简单实例分享

    本文实例为大家讲解了jquery zTree树插件的基本使用方法,具体内容如下 一.节点模糊搜索功能:搜索成功后,自动高亮显示并定位.展开搜索到的节点. 二.节点异步加载:1.点击展开时加载数据:2.选中节点时加载数据. 前台代码如下: <script type="text/javascript"> //ztree设置 var setting = { view: { fontCss: getFontCss }, check: { enable: true }, data:

  • tensorflow模型保存、加载之变量重命名实例

    话不多说,干就完了. 变量重命名的用处? 简单定义:简单来说就是将模型A中的参数parameter_A赋给模型B中的parameter_B 使用场景:当需要使用已经训练好的模型参数,尤其是使用别人训练好的模型参数时,往往别人模型中的参数命名方式与自己当前的命名方式不同,所以在加载模型参数时需要对参数进行重命名,使得代码更简洁易懂. 实现方法: 1).模型保存 import os import tensorflow as tf weights = tf.Variable(initial_value

  • Tensorflow加载预训练模型和保存模型的实例

    使用tensorflow过程中,训练结束后我们需要用到模型文件.有时候,我们可能也需要用到别人训练好的模型,并在这个基础上再次训练.这时候我们需要掌握如何操作这些模型数据.看完本文,相信你一定会有收获! 1 Tensorflow模型文件 我们在checkpoint_dir目录下保存的文件结构如下: |--checkpoint_dir | |--checkpoint | |--MyModel.meta | |--MyModel.data-00000-of-00001 | |--MyModel.in

  • Laravel框架模板加载,分配变量及简单路由功能示例

    本文实例讲述了Laravel框架模板加载,分配变量及简单路由功能.分享给大家供大家参考,具体如下: 作为世界上第一的PHP框架,学习Laraver势在必行,虽然国内盛行ThinkPHP,但是多会一个框架总是对自己有好处的. 通过前面的文章Laravel框架在本地虚拟机快速安装的方法,我们已经可以顺利安装Laravel 安装之后,在目录laravel\app\Http下,有一个routes.php文件,重点了,这个就是控制全站的路由文件. Route::get('/', function () {

  • 对tensorflow 的模型保存和调用实例讲解

    我们通常采用tensorflow来训练,训练完之后应当保存模型,即保存模型的记忆(权重和偏置),这样就可以来进行人脸识别或语音识别了. 1.模型的保存 # 声明两个变量 v1 = tf.Variable(tf.random_normal([1, 2]), name="v1") v2 = tf.Variable(tf.random_normal([2, 3]), name="v2") init_op = tf.global_variables_initializer(

  • Python 在字符串中加入变量的实例讲解

    有时候,我们需要在字符串中加入相应的变量,以下提供了几种字符串加入变量的方法: 1.+ 连字符 name = 'zhangsan' print('my name is '+name) #结果为 my name is zhangsan 2.% 字符 name = 'zhangsan' age = 25 price = 4500.225 print('my name is %s'%(name)) print('i am %d'%(age)+' years old') print('my price

随机推荐