Tensorflow 使用pb文件保存(恢复)模型计算图和参数实例详解

一、保存:

graph_util.convert_variables_to_constants 可以把当前session的计算图串行化成一个字节流(二进制),这个函数包含三个参数:参数1:当前活动的session,它含有各变量

参数2:GraphDef 对象,它描述了计算网络

参数3:Graph图中需要输出的节点的名称的列表

返回值:精简版的GraphDef 对象,包含了原始输入GraphDef和session的网络和变量信息,它的成员函数SerializeToString()可以把这些信息串行化为字节流,然后写入文件里:

constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] )
with open( pbName, mode='wb') as f:
f.write(constant_graph.SerializeToString())

需要指出的是,如果原始张量(包含在参数1和参数2中的组成部分)不参与参数3指定的输出节点列表所指定的张量计算的话,这些张量将不会存在返回的GraphDef对象里,也不会被串行化写入pb文件。

二、恢复:

恢复时,创建一个GraphDef,然后从上述的文件里加载进来,接着输入到当前的session:

    graph0 = tf.GraphDef()
    with open( pbName, mode='rb') as f:
      graph0.ParseFromString( f.read() )
      tf.import_graph_def( graph0 , name = '' )

三、代码:


import tensorflow as tf
from tensorflow.python.framework import graph_util

pbName = 'graphA.pb'
def graphCreate() :
  with tf.Session() as sess :
    var1 = tf.placeholder ( tf.int32 , name='var1' )
    var2 = tf.Variable( 20 , name='var2' )#实参name='var2'指定了操作名,该操作返回的张量名是在
                       #'var2'后面:0 ,即var2:0 是返回的张量名,也就是说变量
                       # var2的名称是'var2:0'
    var3 = tf.Variable( 30 , name='var3' )
    var4 = tf.Variable( 40 , name='var4' )
    var4op = tf.assign( var4 , 1000 , name = 'var4op1' )
    sum = tf.Variable( 4, name='sum' )
    sum = tf.add ( var1 , var2, name = 'var1_var2' )
    sum = tf.add( sum , var3 , name='sum_var3' )
    sumOps = tf.add( sum , var4 , name='sum_operation' )
    oper = tf.get_default_graph().get_operations()
    with open( 'operation.csv','wt' ) as f:
      s = 'name,type,output\n'
      f.write( s )
      for o in oper:
        s = o.name
        s += ','+ o.type
        inp = o.inputs
        oup = o.outputs
        for iip in inp :
          s #s += ','+ str(iip)
        for iop in oup :
          s += ',' + str(iop)
        s += '\n'
        f.write( s ) 

      for var in tf.global_variables():
        print('variable=> ' , var.name) #张量是tf.Variable/tf.Add之类操作的结果,
                        #张量的名字使用操作名加:0来表示
    init = tf.global_variables_initializer()
    sess.run( init )
    sess.run( var4op )
    print('sum_operation result is Tensor ' , sess.run( sumOps , feed_dict={var1:1}) )

    constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] )
    with open( pbName, mode='wb') as f:
      f.write(constant_graph.SerializeToString())

def graphGet() :
  print("start get:" )
  with tf.Graph().as_default():
    graph0 = tf.GraphDef()
    with open( pbName, mode='rb') as f:
      graph0.ParseFromString( f.read() )
      tf.import_graph_def( graph0 , name = '' )
    with tf.Session() as sess :
      init = tf.global_variables_initializer()
      sess.run(init)
      v1 = sess.graph.get_tensor_by_name('var1:0' )
      v2 = sess.graph.get_tensor_by_name('var2:0' )
      v3 = sess.graph.get_tensor_by_name('var3:0' )
      v4 = sess.graph.get_tensor_by_name('var4:0' )

      sumTensor = sess.graph.get_tensor_by_name("sum_operation:0")
      print('sumTensor is : ' , sumTensor )
      print( sess.run( sumTensor , feed_dict={v1:1} ) ) 

graphCreate()
graphGet()

四、保存pb函数代码里的操作名称/类型/返回的张量:

operation name operation type output
var1 Placeholder Tensor("var1:0" dtype=int32)
var2/initial_value Const Tensor("var2/initial_value:0" shape=() dtype=int32)
var2 VariableV2 Tensor("var2:0" shape=() dtype=int32_ref)
var2/Assign Assign Tensor("var2/Assign:0" shape=() dtype=int32_ref)
var2/read Identity Tensor("var2/read:0" shape=() dtype=int32)
var3/initial_value Const Tensor("var3/initial_value:0" shape=() dtype=int32)
var3 VariableV2 Tensor("var3:0" shape=() dtype=int32_ref)
var3/Assign Assign Tensor("var3/Assign:0" shape=() dtype=int32_ref)
var3/read Identity Tensor("var3/read:0" shape=() dtype=int32)
var4/initial_value Const Tensor("var4/initial_value:0" shape=() dtype=int32)
var4 VariableV2 Tensor("var4:0" shape=() dtype=int32_ref)
var4/Assign Assign Tensor("var4/Assign:0" shape=() dtype=int32_ref)
var4/read Identity Tensor("var4/read:0" shape=() dtype=int32)
var4op1/value Const Tensor("var4op1/value:0" shape=() dtype=int32)
var4op1 Assign Tensor("var4op1:0" shape=() dtype=int32_ref)
sum/initial_value Const Tensor("sum/initial_value:0" shape=() dtype=int32)
sum VariableV2 Tensor("sum:0" shape=() dtype=int32_ref)
sum/Assign Assign Tensor("sum/Assign:0" shape=() dtype=int32_ref)
sum/read Identity Tensor("sum/read:0" shape=() dtype=int32)
var1_var2 Add Tensor("var1_var2:0" dtype=int32)
sum_var3 Add Tensor("sum_var3:0" dtype=int32)
sum_operation Add Tensor("sum_operation:0" dtype=int32)

以上这篇Tensorflow 使用pb文件保存(恢复)模型计算图和参数实例详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • TensorFlow实现checkpoint文件转换为pb文件

    由于项目需要,需要将TensorFlow保存的模型从ckpt文件转换为pb文件. import os from tensorflow.python import pywrap_tensorflow from net2use import inception_resnet_v2_small#这里使用自己定义的模型函数即可 import tensorflow as tf if __name__=='__main__': pb_file = "./model/output.pb" ckpt_

  • TensorFlow实现保存训练模型为pd文件并恢复

    TensorFlow保存模型代码 import tensorflow as tf from tensorflow.python.framework import graph_util var1 = tf.Variable(1.0, dtype=tf.float32, name='v1') var2 = tf.Variable(2.0, dtype=tf.float32, name='v2') var3 = tf.Variable(2.0, dtype=tf.float32, name='v3')

  • 解决tensorflow打印tensor有省略号的问题

    先上代码: import tensorflow as tf x = tf.ones(shape=[100, 200], dtype=tf.int32, name='x') y = tf.zeros(shape=[2, 3], dtype=tf.float32, name='y') with tf.Session() as sess: print(sess.run([x, y])) 输出结果如下: 如果我调试的时候想查看省略号代表的值是什么 只需要改成如下代码就行: import tensorfl

  • Tensorflow 使用pb文件保存(恢复)模型计算图和参数实例详解

    一.保存: graph_util.convert_variables_to_constants 可以把当前session的计算图串行化成一个字节流(二进制),这个函数包含三个参数:参数1:当前活动的session,它含有各变量 参数2:GraphDef 对象,它描述了计算网络 参数3:Graph图中需要输出的节点的名称的列表 返回值:精简版的GraphDef 对象,包含了原始输入GraphDef和session的网络和变量信息,它的成员函数SerializeToString()可以把这些信息串行

  • Java 将文件转为字节数组知识总结及实例详解

    Java将文件转为字节数组 关键字:文件,文件流,字节流,字节数组,二进制 摘要:最近工作中碰到的需求是,利用http传输二进制数据到服务器对应接口,需要传输userId, file(加密后)等一系列混合后的二进制数据.本文旨在记录自己在使用Java将文件转为字节数组的一些知识理解与汇总. FileInputStream 利用FileInputStream读取文件 FileInputStream是InputStream的子类,用于从文件中读取信息,构造器接收一个File类型或表示文件路径的Str

  • Android OkHttp Post上传文件并且携带参数实例详解

    Android OkHttp Post上传文件并且携带参数 这里整理一下 OkHttp 的 post 在上传文件的同时,也要携带请求参数的方法. 使用 OkHttp 版本如下: compile 'com.squareup.okhttp3:okhttp:3.4.1' 代码如下: protected void post_file(final String url, final Map<String, Object> map, File file) { OkHttpClient client = n

  • 生产者消费者模型ThreadLocal原理及实例详解

    1.生产者消费者模型作用和示例如下: 1)通过平衡生产者的生产能力和消费者的消费能力来提升整个系统的运行效率 ,这是生产者消费者模型最重要的作用 2)解耦,这是生产者消费者模型附带的作用,解耦意味着生产者和消费者之间的联系少,联系越少越可以独自发展而不需要收到相互的制约 备注:对于生产者消费者模型的理解将在并发队列BlockingQueue章节进行说明,本章不做详细介绍. package threadLearning.productCustomerModel; /* wait/notify 机制

  • tensorflow的ckpt及pb模型持久化方式及转化详解

    使用tensorflow训练模型的时候,模型持久化对我们来说非常重要. 如果我们的模型比较复杂,需要的数据比较多,那么在模型的训练时间会耗时很长.如果在训练过程中出现了模型不可预期的错误,导致训练意外终止,那么我们将会前功尽弃.为了解决这一问题,我们可以使用模型持久化(保存为ckpt文件格式)来保存我们在训练过程中的临时数据.. 如果我们训练出的模型需要提供给用户做离线预测,那么我们只需要完成前向传播过程.这个时候我们就可以使用模型持久化(保存为pb文件格式)来只保存前向传播过程中的变量并将变量

  • Tensorflow加载模型实现图像分类识别流程详解

    目录 前言 正文 VGG19网络介绍 总结 前言 深度学习框架在市面上有很多.比如Theano.Caffe.CNTK.MXnet .Tensorflow等.今天讲解的就是主角Tensorflow.Tensorflow的前身是Google大脑项目的一个分布式机器学习训练框架,它是一个十分基础且集成度很高的系统,它的目标就是为研究超大型规模的视觉项目,后面延申到各个领域.Tensorflow 在2015年正式开源,开源的一个月内就收获到1w多的starts,这足以说明Tensorflow的优越性以及

  • 对Django中内置的User模型实例详解

    User模型 User模型是这个框架的核心部分.他的完整的路径是在django.contrib.auth.models.User. 字段 内置的User模型拥有以下的字段: 1.username: 用户名.150个字符以内.可以包含数字和英文字符,以及_.@.+..和-字符.不能为空,且必须唯一! 2.first_name:歪果仁的first_name,在30个字符以内.可以为空. 3.last_name:歪果仁的last_name,在150个字符以内.可以为空. 4.email:邮箱.可以为空

  • python 实现tar文件压缩解压的实例详解

    python 实现tar文件压缩解压的实例详解 压缩文件: import tarfile import os def tar(fname): t = tarfile.open(fname + ".tar.gz", "w:gz") for root, dir, files in os.walk(fname): print root, dir, files for file in files: fullpath = os.path.join(root, file) t.

  • jsp按格式导出doc文件实例详解

    jsp按格式导出doc文件实例详解 原理:doc文件其实可以保存为xml文件,该xml文件用字符串表示了doc文件的表现形式,我们只需要用Java将那些要填的内容替换掉然后下载给客户就行了. 1.首先是按照你的文档填写好数据. 2.将文档另存为xml文件,然后编辑该xml文件,将填好的内容用某种格式替换,如:将名字张三替换成${name} 3.读取文件,将文件中的${name}替换成真正的名字. 4.下载. 接下来看代码: 首先是那个转换类 package com.my.util; import

  • JavaWeb实现文件上传与下载实例详解

    在Web应用程序开发中,文件上传与下载功能是非常常用的功能,下面通过本文给大家介绍JavaWeb实现文件上传与下载实例详解. 对于文件上传,浏览器在上传的过程中是将文件以流的形式提交到服务器端的,如果直接使用Servlet获取上传文件的输入流然后再解析里面的请求参数是比较麻烦,所以一般选择采用apache的开源工具common-fileupload这个文件上传组件.这个common-fileupload上传组件的jar包可以去apache官网上面下载,common-fileupload是依赖于c

随机推荐