TensorFlow固化模型的实现操作

前言

TensorFlow目前在移动端是无法training的,只能跑已经训练好的模型,但一般的保存方式只有单一保存参数或者graph的,如何将参数、graph同时保存呢?

生成模型

主要有两种方法生成模型,一种是通过freeze_graph把tf.train.write_graph()生成的pb文件与tf.train.saver()生成的chkp文件固化之后重新生成一个pb文件,这一种现在不太建议使用。另一种是把变量转成常量之后写入PB文件中。我们简单的介绍下freeze_graph方法。

freeze_graph

这种方法我们需要先使用tf.train.write_graph()以及tf.train.saver()生成pb文件和ckpt文件,代码如下:

with tf.Session() as sess:
 saver = tf.train.Saver()
 saver.save(session, "model.ckpt")
 tf.train.write_graph(session.graph_def, '', 'graph.pb')

然后使用TensorFlow源码中的freeze_graph工具进行固化操作:

首先需要build freeze_graph 工具( 需要 bazel ):

bazel build tensorflow/python/tools:freeze_graph

然后使用这个工具进行固化(/path/to/表示文件路径):

bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=/path/to/graph.pb --input_checkpoint=/path/to/model.ckpt --output_node_names=output/predict --output_graph=/path/to/frozen.pb
convert_variables_to_constants

其实在TensorFlow中传统的保存模型方式是保存常量以及graph的,而我们的权重主要是变量,如果我们把训练好的权重变成常量之后再保存成PB文件,这样确实可以保存权重,就是方法有点繁琐,需要一个一个调用eval方法获取值之后赋值,再构建一个graph,把W和b赋值给新的graph。

牛逼的Google为了方便大家使用,编写了一个方法供我们快速的转换并保存。

首先我们需要引入这个方法

from tensorflow.python.framework.graph_util import convert_variables_to_constants

在想要保存的地方加入如下代码,把变量转换成常量

output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output/predict'])

这里参数第一个是当前的session,第二个为graph,第三个是输出节点名(如我的输出层代码是这样的:)

 with tf.name_scope('output'):
 w_out = tf.Variable(w_alpha * tf.random_normal([1024, MAX_CAPTCHA * CHAR_SET_LEN]))
 tf.summary.histogram('output/weight', w_out)
 b_out = tf.Variable(b_alpha * tf.random_normal([MAX_CAPTCHA * CHAR_SET_LEN]))
 tf.summary.histogram('output/biases', b_out)
 out = tf.add(tf.matmul(dense2, w_out), b_out)
 out = tf.nn.softmax(out)
 predict = tf.argmax(tf.reshape(out, [-1, 11, 36]), 2, name='predict')

由于我们采用了name_scope所以我们在predict之前需要加上output/

生成文件

with tf.gfile.FastGFile('model/CTNModel.pb', mode='wb') as f:
f.write(output_graph_def.SerializeToString())

第一个参数是文件路径,第二个是指文件操作的模式,这里指的是以二进制的方式写入文件。

运行代码,系统会生成一个PB文件,接下来我们要测试下这个模型是否能够正常的读取、运行。

测试模型

在Python环境下,我们首先需要加载这个模型,代码如下:

with open('./model/rounded_graph.pb', 'rb') as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 output = tf.import_graph_def(graph_def,
     input_map={'inputs/X:0': newInput_X},
     return_elements=['output/predict:0'])

由于我们原本的网络输入值是一个placeholder,这里为了方便输入我们也先定义一个新的placeholder:

newInput_X = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT * IMAGE_WIDTH], name="X")

在input_map的参数填入新的placeholder。

在调用我们的网络的时候直接用这个新的placeholder接收数据,如:

text_list = sesss.run(output, feed_dict={newInput_X: [captcha_image]})

然后就是运行我们的网络,看是否可以运行吧。

以上这篇TensorFlow固化模型的实现操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • 将keras的h5模型转换为tensorflow的pb模型操作

    背景:目前keras框架使用简单,很容易上手,深得广大算法工程师的喜爱,但是当部署到客户端时,可能会出现各种各样的bug,甚至不支持使用keras,本文来解决的是将keras的h5模型转换为客户端常用的tensorflow的pb模型并使用tensorflow加载pb模型. h5_to_pb.py from keras.models import load_model import tensorflow as tf import os import os.path as osp from kera

  • tensorflow2.0保存和恢复模型3种方法

    方法1:只保存模型的权重和偏置 这种方法不会保存整个网络的结构,只是保存模型的权重和偏置,所以在后期恢复模型之前,必须手动创建和之前模型一模一样的模型,以保证权重和偏置的维度和保存之前的相同. tf.keras.model类中的save_weights方法和load_weights方法,参数解释我就直接搬运官网的内容了. save_weights( filepath, overwrite=True, save_format=None ) Arguments: filepath: String,

  • Tensorflow中的图(tf.Graph)和会话(tf.Session)的实现

    Tensorflow编程系统 Tensorflow工具或者说深度学习本身就是一个连贯紧密的系统.一般的系统是一个自治独立的.能实现复杂功能的整体.系统的主要任务是对输入进行处理,以得到想要的输出结果.我们之前见过的很多系统都是线性的,就像汽车生产工厂的流水线一样,输入->系统处理->输出.系统内部由很多单一的基本部件构成,这些单一部件具有特定的功能,且需要稳定的特性:系统设计者通过特殊的连接方式,让这些简单部件进行连接,以使它们之间可以进行数据交流和信息互换,来达到相互配合而完成具体工作的目的

  • TensorFlow固化模型的实现操作

    前言 TensorFlow目前在移动端是无法training的,只能跑已经训练好的模型,但一般的保存方式只有单一保存参数或者graph的,如何将参数.graph同时保存呢? 生成模型 主要有两种方法生成模型,一种是通过freeze_graph把tf.train.write_graph()生成的pb文件与tf.train.saver()生成的chkp文件固化之后重新生成一个pb文件,这一种现在不太建议使用.另一种是把变量转成常量之后写入PB文件中.我们简单的介绍下freeze_graph方法. f

  • 打印tensorflow恢复模型中所有变量与操作节点方式

    我就废话不多说了,大家还是直接看代码吧! #参数恢复 self.sess=tf.Session() saver = tf.train.import_meta_graph(os.path.join(model_fullpath,'model.ckpt-7.meta')) module_file = tf.train.latest_checkpoint(model_fullpath) saver.restore(self.sess, module_file) variable_names = [v.

  • 对tensorflow 的模型保存和调用实例讲解

    我们通常采用tensorflow来训练,训练完之后应当保存模型,即保存模型的记忆(权重和偏置),这样就可以来进行人脸识别或语音识别了. 1.模型的保存 # 声明两个变量 v1 = tf.Variable(tf.random_normal([1, 2]), name="v1") v2 = tf.Variable(tf.random_normal([2, 3]), name="v2") init_op = tf.global_variables_initializer(

  • keras和tensorflow使用fit_generator 批次训练操作

    fit_generator 是 keras 提供的用来进行批次训练的函数,使用方法如下: model.fit_generator(generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=F

  • 将TensorFlow的模型网络导出为单个文件的方法

    有时候,我们需要将TensorFlow的模型导出为单个文件(同时包含模型架构定义与权重),方便在其他地方使用(如在c++中部署网络).利用tf.train.write_graph()默认情况下只导出了网络的定义(没有权重),而利用tf.train.Saver().save()导出的文件graph_def与权重是分离的,因此需要采用别的方法. 我们知道,graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果我们能把Variable

  • 解决tensorflow测试模型时NotFoundError错误的问题

    错误代码如下: NotFoundError (see above for traceback): Unsuccessful TensorSliceReader constructor: Failed to find any matching files for xxx -- 经查资料分析,错误原因可能出在加载模型时的路径问题.我采用的加载模型方法: with tf.Session() as sess: print("Reading checkpoints...") ckpt = tf.

  • tensorflow 获取模型所有参数总和数量的方法

    实例如下所示: from functools import reduce from operator import mul def get_num_params(): num_params = 0 for variable in tf.trainable_variables(): shape = variable.get_shape() num_params += reduce(mul, [dim.value for dim in shape], 1) return num_params 以上这

  • Laravel框架模型的创建及模型对数据操作示例

    本文实例讲述了Laravel框架模型的创建及模型对数据操作.分享给大家供大家参考,具体如下: 模型创建: <?php namespace App; use Illuminate\Database\Eloquent\Model; class Admin extends Model{ //指定表名 protected $table = 'wd_user'; //指定允许批量复制的字段 protected $fillable = ['username']; //指定id protected $prim

  • Laravel5.1 框架模型软删除操作实例分析

    本文实例讲述了Laravel5.1 框架模型软删除操作.分享给大家供大家参考,具体如下: 软删除是比较实用的一种删除手段,比如说 你有一本账 有一笔记录你觉得不对给删了 过了几天发现不应该删除,这时候软删除的目的就实现了 你可以找到已经被删除的数据进行操作 可以是还原也可以是真正的删除. 1 普通删除 在软删除之前咱先看看普通的删除方法: 1.1 直接通过主键删除 public function getDelete() { Article::destroy(1); Article::destro

  • tensorflow通过模型文件,使用tensorboard查看其模型图Graph方式

    Google提供了一个工具,TensorBoard,它能以图表的方式分析你在训练过程中汇总的各种数据,其中包括Graph结构. 所以我们可以简单的写几行Pyhton,加载Graph,只在logdir里,输出Graph结构数据,并可以查看其图结构. 执行下述代码,将数据流图保存为图片,在目录F:/tensorflow/graph下生成文件events.out.tfevents.1508420019.XM-PC import tensorflow as tf from tensorflow.pyth

随机推荐