tensorflow实现从.ckpt文件中读取任意变量

思路有些混乱,希望大家能理解我的意思。

看了faster rcnn的tensorflow代码,关于fix_variables的作用我不是很明白,所以写了以下代码,读取了预训练模型vgg16得fc6和fc7的参数,以及faster rcnn中heat_to_tail中的fc6和fc7,将它们做了对比,发现结果不一样,说明vgg16的fc6和fc7只是初始化了faster rcnn中heat_to_tail中的fc6和fc7,之后后者被训练。

具体读取任意变量的代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.python import pywrap_tensorflow

file_name = '/home/dl/projectBo/tf-faster-rcnn/data/imagenet_weights/vgg16.ckpt' #.ckpt的路径
name_variable_to_restore = 'vgg_16/fc7/weights' #要读取权重的变量名
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()
print('shape', var_to_shape_map[name_variable_to_restore]) #输出这个变量的尺寸
fc7_conv = tf.get_variable("fc7", var_to_shape_map[name_variable_to_restore], trainable=False) # 定义接收权重的变量名
restorer_fc = tf.train.Saver({name_variable_to_restore: fc7_conv }) #定义恢复变量的对象
sess = tf.Session()
sess.run(tf.variables_initializer([fc7_conv], name='init')) #必须初始化
restorer_fc.restore(sess, file_name) #恢复变量
print(sess.run(fc7_conv)) #输出结果

用以上的代码分别读取两个网络的fc6 和 fc7 ,对应参数尺寸和权值都不同,但参数量相同。

再看lib/nets/vgg16.py中的:

(注意注释)

def fix_variables(self, sess, pretrained_model):
 print('Fix VGG16 layers..')
 with tf.variable_scope('Fix_VGG16') as scope:
  with tf.device("/cpu:0"):
   # fix the vgg16 issue from conv weights to fc weights
   # fix RGB to BGR
   fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)
   fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
   conv1_rgb = tf.get_variable("conv1_rgb", [3, 3, 3, 64], trainable=False)   #定义接收权重的变量,不可被训练
   restorer_fc = tf.train.Saver({self._scope + "/fc6/weights": fc6_conv,
                  self._scope + "/fc7/weights": fc7_conv,
                  self._scope + "/conv1/conv1_1/weights": conv1_rgb}) #定义恢复变量的对象
   restorer_fc.restore(sess, pretrained_model) #恢复这些变量

   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc6/weights:0'], tf.reshape(fc6_conv,
             self._variables_to_fix[self._scope + '/fc6/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc7/weights:0'], tf.reshape(fc7_conv,
             self._variables_to_fix[self._scope + '/fc7/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/conv1/conv1_1/weights:0'],
             tf.reverse(conv1_rgb, [2])))         #将vgg16中的fc6、fc7中的权重reshape赋给faster-rcnn中的fc6、fc7

我的理解:faster rcnn的网络继承了分类网络的特征提取权重和分类器的权重,让网络从一个比较好的起点开始被训练,有利于训练结果的快速收敛。

补充知识:TensorFlow:加载部分ckpt文件变量&不同命名空间中加载模型

TensorFlow中,在加载和保存模型时,一般会直接使用tf.train.Saver.restore()和tf.train.Saver.save()

然而,当需要选择性加载模型参数时,则需要利用pywrap_tensorflow读取模型,分析模型内的变量关系。

例子:Faster-RCNN中,模型加载vgg16.ckpt,需要利用pywrap_tensorflow读取ckpt文件中的参数

from tensorflow.python import pywrap_tensorflow

model=VGG16()#此处构建vgg16模型
variables = tf.global_variables()#获取模型中所有变量

file_name='vgg16.ckpt'#vgg16网络模型
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()#获取ckpt模型中的变量名
print(var_to_shape_map)

sess=tf.Session()

my_scope='my/'#外加的空间名
variables_to_restore={}#构建字典:需要的变量和对应的模型变量的映射
for v in variables:
  if my_scope in v.name and v.name.split(':')[0].split(my_scope)[1] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name.split(':0')[0][len(my_scope):]]=v
  elif v.name.split(':')[0] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name]=v

restorer=tf.train.Saver(variables_to_restore)#将需要加载的变量作为参数输入
restorer.restore(sess, file_name)

实际中,Faster RCNN中所构建的vgg16网络的fc6和fc7权重shape如下:

<tf.Variable 'my/vgg_16/fc6/weights:0' shape=(25088, 4096) dtype=float32_ref>,
<tf.Variable 'my/vgg_16/fc7/weights:0' shape=(4096, 4096) dtype=float32_ref>,

vgg16.ckpt的fc6,fc7权重shape如下:

'vgg_16/fc6/weights': [7, 7, 512, 4096],
'vgg_16/fc7/weights': [1, 1, 4096, 4096],

因此,有如下操作:

fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)
fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)

restorer_fc = tf.train.Saver({"vgg_16/fc6/weights": fc6_conv,
               "vgg_16/fc7/weights": fc7_conv,
               })
restorer_fc.restore(sess, pretrained_model)
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc6/weights:0'], tf.reshape(fc6_conv,self._variables_to_fix['my/vgg_16/fc6/weights:0'].get_shape())))
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc7/weights:0'], tf.reshape(fc7_conv,self._variables_to_fix['my/vgg_16/fc7/weights:0'].get_shape())))

以上这篇tensorflow实现从.ckpt文件中读取任意变量就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • tensorflow从ckpt和从.pb文件读取变量的值方式

    最近在学习tensorflow自带的量化工具的相关知识,其中遇到的一个问题是从tensorflow保存好的ckpt文件或者是保存后的.pb文件(这里的pb是把权重和模型保存在一起的pb文件)读取权重,查看量化后的权重是否变成整形. 因此将自己解决这个问题记录下来,为了下一次遇到时,可以有所参考,也希望给有需要的同学一个可能的参考. (1) 从保存的ckpt读取变量的值(以读取保存的第一个权重为例) from tensorflow.python import pywrap_tensorflow i

  • tensorflow之变量初始化(tf.Variable)使用详解

    默认本系列的的读者已经初步熟悉tensorflow. 我们通过tf.Variable构造一个variable添加进图中,Variable()构造函数需要变量的初始值(是一个任意类型.任意形状的tensor),这个初始值指定variable的类型和形状.通过Variable()构造函数后,此variable的类型和形状固定不能修改了,但值可以用assign方法修改. 如果想修改variable的shape,可以使用一个assign op,令validate_shape=False. 通过Varia

  • 对TensorFlow中的variables_to_restore函数详解

    variables_to_restore函数,是TensorFlow为滑动平均值提供.之前,也介绍过通过使用滑动平均值可以让神经网络模型更加的健壮.我们也知道,其实在TensorFlow中,变量的滑动平均值都是由影子变量所维护的,如果你想要获取变量的滑动平均值需要获取的是影子变量而不是变量本身. 1.滑动平均值模型文件的保存 import tensorflow as tf if __name__ == "__main__": v = tf.Variable(0.,name="

  • python 动态生成变量名以及动态获取变量的变量名方法

    前言 需求: 必须现在需要动态创建16个list,每个list的名字不一样,但是是有规律可循,比如第一个list的名字叫: arriage_list_0=[],第二个叫arriage_list_1=[]--..依次类推,但是我又不想手动的去写16个这样的名字,太累了,而且增加了代码的冗余性,灵活性也不强,所以有没有一种方法是能动态创建list名称的呢?答案是有的!而与之对应,既然要对上面的列表动态操作,肯定是少不了动态去解析list名称.所以下面开始介绍方法. python 动态生成变量名 lo

  • tensorflow实现从.ckpt文件中读取任意变量

    思路有些混乱,希望大家能理解我的意思. 看了faster rcnn的tensorflow代码,关于fix_variables的作用我不是很明白,所以写了以下代码,读取了预训练模型vgg16得fc6和fc7的参数,以及faster rcnn中heat_to_tail中的fc6和fc7,将它们做了对比,发现结果不一样,说明vgg16的fc6和fc7只是初始化了faster rcnn中heat_to_tail中的fc6和fc7,之后后者被训练. 具体读取任意变量的代码如下: import tensor

  • Tensorflow 如何从checkpoint文件中加载变量名和变量值

    假设你已经经过上千次的迭代,并且得到了以下模型: 则从这些checkpoint文件中加载变量名和变量值代码如下: model_dir = './ckpt-182802' import tensorflow as tf from tensorflow.python import pywrap_tensorflow reader = pywrap_tensorflow.NewCheckpointReader(model_dir) var_to_shape_map = reader.get_varia

  • Tensorflow: 从checkpoint文件中读取tensor方式

    在使用pre-train model时候,我们需要restore variables from checkpoint files. 经常出现在checkpoint 中找不到"Tensor name not found". 这时候需要查看一下ckpt中到底有哪些变量 import os from tensorflow.python import pywrap_tensorflow checkpoint_path = os.path.join(model_dir, "model.

  • PHP如何从txt文件中读取数据详解

    目录 一.打开/关闭文件 二.读写文件 1.读取整个文件 2.读取一行数据 3.读取一个字符 4.读取任意长度的字符串 总结 一.打开/关闭文件 1.对文件操作时首先要打开文件,打开文件用 fopen()函数,语法是: fopen(filename,mode,include_path,context); 2.对文件操作结束后应该关闭这个文件,使用函数 fclose(); 例如: 二.读写文件 1.读取整个文件 有三个函数可以使用,分别是:readfile()函数.file()函数.file_ge

  • Perl从文件中读取字符串的两种实现方法

    1. 一次性将文件中的所有内容读入一个数组中(该方法适合小文件): 复制代码 代码如下: open(FILE,"filename")||die"can not open the file: $!";@filelist=<FILE>; foreach $eachline (@filelist) {        chomp $eachline;}close FILE;@filelist=<FILE>; 当文件很大时,可能会出现"out

  • Python3实现从文件中读取指定行的方法

    本文实例讲述了Python3实现从文件中读取指定行的方法.分享给大家供大家参考.具体实现方法如下: # Python的标准库linecache模块非常适合这个任务 import linecache the_line = linecache.getline('d:/FreakOut.cpp', 222) print (the_line) # linecache读取并缓存文件中所有的文本, # 若文件很大,而只读一行,则效率低下. # 可显示使用循环, 注意enumerate从0开始计数,而line

  • Python3实现将文件归档到zip文件及从zip文件中读取数据的方法

    本文实例讲述了Python3实现将文件归档到zip文件及从zip文件中读取数据的方法.分享给大家供大家参考.具体实现方法如下: ''''' Created on Dec 24, 2012 将文件归档到zip文件,并从zip文件中读取数据 @author: liury_lab ''' # 压缩成zip文件 from zipfile import * #@UnusedWildImport import os my_dir = 'd:/中华十大名帖/' myzip = ZipFile('d:/中华十大

  • android从资源文件中读取文件流并显示的方法

    本文实例讲述了android从资源文件中读取文件流并显示的方法.分享给大家供大家参考.具体如下: 在android中,假如有的文本文件,比如TXT放在raw下,要直接读取出来,放到屏幕中显示,可以这样: private void doRaw(){ InputStream is = this.getResources().openRawResource(R.raw.ziliao); try{ doRead(is); }catch(IOException e){ e.printStackTrace(

  • Java将对象保存到文件中/从文件中读取对象的方法

    1.保存对象到文件中 Java语言只能将实现了Serializable接口的类的对象保存到文件中,利用如下方法即可: public static void writeObjectToFile(Object obj) { File file =new File("test.dat"); FileOutputStream out; try { out = new FileOutputStream(file); ObjectOutputStream objOut=new ObjectOutp

  • 从Java的jar文件中读取数据的方法

    本文实例讲述了从Java的jar文件中读取数据的方法.分享给大家供大家参考.具体如下: Java 档案 (Java Archive, JAR) 文件是基于 Java 技术的打包方案.它们允许开发人员把所有相关的内容 (.class.图片.声音和支持文件等) 打包到一个单一的文件中.JAR 文件格式支持压缩.身份验证和版本,以及许多其它特性. 从 JAR 文件中得到它所包含的文件内容是件棘手的事情,但也不是不可以做到.这篇技巧就将告诉你如何从 JAR 文件中取得一个文件.我们会先取得这个 JAR

随机推荐