python使用tensorflow保存、加载和使用模型的方法

使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用。介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍:

http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

我对这篇文章进行了整理和汇总。

首先是模型的保存。直接上代码:

#!/usr/bin/env python
#-*- coding:utf-8 -*-
############################
#File Name: tut1_save.py
#Author: Wang
#Mail: wang19920419@hotmail.com
#Created Time:2017-08-30 11:04:25
############################ 

import tensorflow as tf 

# prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.Variable(tf.random_normal(shape = [2]), name = 'w1') # name is very important in restoration
w2 = tf.Variable(tf.random_normal(shape = [2]), name = 'w2')
b1 = tf.Variable(2.0, name = 'bias1')
feed_dict = {w1:[10,3], w2:[5,5]} 

# define a test operation that will be restored
w3 = tf.add(w1, w2) # without name, w3 will not be stored
w4 = tf.multiply(w3, b1, name = "op_to_restore") 

#saver = tf.train.Saver()
saver = tf.train.Saver(max_to_keep = 4, keep_checkpoint_every_n_hours = 1)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print sess.run(w4, feed_dict)
#saver.save(sess, 'my_test_model', global_step = 100)
saver.save(sess, 'my_test_model')
#saver.save(sess, 'my_test_model', global_step = 100, write_meta_graph = False)

需要说明的有以下几点:

1. 创建saver的时候可以指明要存储的tensor,如果不指明,就会全部存下来。在这里也可以指明最大存储数量和checkpoint的记录时间。具体细节看英文博客。

2. saver.save()函数里面可以设定global_step和write_meta_graph,meta存储的是网络结构,只在开始运行程序的时候存储一次即可,后续可以通过设置write_meta_graph = False加以限制。

3. 这个程序执行结束后,会在程序目录下生成四个文件,分别是.meta(存储网络结构)、.data和.index(存储训练好的参数)、checkpoint(记录最新的模型)。

下面是如何加载已经保存的网络模型。这里有两种方法,第一种是saver.restore(sess, 'aaaa.ckpt'),这种方法的本质是读取全部参数,并加载到已经定义好的网络结构上,因此相当于给网络的weights和biases赋值并执行tf.global_variables_initializer()。这种方法的缺点是使用前必须重写网络结构,而且网络结构要和保存的参数完全对上。第二种就比较高端了,直接把网络结构加载进来(.meta),上代码:

#!/usr/bin/env python
#-*- coding:utf-8 -*-
############################
#File Name: tut2_import.py
#Author: Wang
#Mail: wang19920419@hotmail.com
#Created Time:2017-08-30 14:16:38
############################
import tensorflow as tf
sess = tf.Session()
new_saver = tf.train.import_meta_graph('my_test_model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
print sess.run('w1:0') 

使用加载的模型,输入新数据,计算输出,还是直接上代码:

#!/usr/bin/env python
#-*- coding:utf-8 -*-
############################
#File Name: tut3_reuse.py
#Author: Wang
#Mail: wang19920419@hotmail.com
#Created Time:2017-08-30 14:33:35
############################ 

import tensorflow as tf 

sess = tf.Session() 

# First, load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model.meta')
saver.restore(sess, tf.train.latest_checkpoint('./')) 

# Second, access and create placeholders variables and create feed_dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name('w1:0')
w2 = graph.get_tensor_by_name('w2:0')
feed_dict = {w1:[-1,1], w2:[4,6]} 

# Access the op that want to run
op_to_restore = graph.get_tensor_by_name('op_to_restore:0') 

print sess.run(op_to_restore, feed_dict)   # ouotput: [6. 14.]

在已经加载的网络后继续加入新的网络层:

import tensorflow as tf
sess=tf.Session()
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./')) 

# Now, let's access and create placeholders variables and
# create feed-dict to feed new data 

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0} 

#Now, access the op that you want to run.
op_to_restore = graph.get_tensor_by_name("op_to_restore:0") 

#Add more to the current graph
add_on_op = tf.multiply(op_to_restore,2) 

print sess.run(add_on_op,feed_dict)
#This will print 120.

对加载的网络进行局部修改和处理(这个最麻烦,我还没搞太明白,后续会继续补充):

......
......
saver = tf.train.import_meta_graph('vgg.meta')
# Access the graph
graph = tf.get_default_graph()
## Prepare the feed_dict for feeding data for fine-tuning  

#Access the appropriate output for fine-tuning
fc7= graph.get_tensor_by_name('fc7:0') 

#use this if you only want to change gradients of the last layer
fc7 = tf.stop_gradient(fc7) # It's an identity function
fc7_shape= fc7.get_shape().as_list() 

new_outputs=2
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7, weights) + biases
pred = tf.nn.softmax(output) 

# Now, you run this with fine-tuning data in sess.run() 

有了这样的方法,无论是自行训练、加载模型继续训练、使用经典模型还是finetune经典模型抑或是加载网络跑前项,效果都是杠杠的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持我们。

您可能感兴趣的文章:

  • Win7下Python与Tensorflow-CPU版开发环境的安装与配置过程
  • python tensorflow基于cnn实现手写数字识别
(0)

相关推荐

  • Win7下Python与Tensorflow-CPU版开发环境的安装与配置过程

    以此文记录Python与Tensorflow及其开发环境的安装与配置过程,以备以后参考. 1 硬件与系统条件 Win7 64位系统,显卡为NVIDIA GeforeGT 635M 2 安装策略 a.由于以上原因,选择在win7下安装cpu版的tensorflow,使用anconda安装,总结下来,这么做是代价最小的. b. 首先,不要急于下载Python,因为最新的版本可能会与Anaconda中的Python版本发生冲突.以目前(截止2017-06-17日)的情况,Anaconda选择Anaco

  • python tensorflow基于cnn实现手写数字识别

    一份基于cnn的手写数字自识别的代码,供大家参考,具体内容如下 # -*- coding: utf-8 -*- import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data # 加载数据集 mnist = input_data.read_data_sets('MNIST_data', one_hot=True) # 以交互式方式启动session # 如果不使用交互式session,则在启动s

  • python使用tensorflow保存、加载和使用模型的方法

    使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用.介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍: http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/ 我对这篇文章进行了整理和汇总. 首先是模型的保存.直接上代码: #!/usr/bin/env python #-*- c

  • Python 保存加载mat格式文件的示例代码

    mat为matlab常用存储数据的文件格式,python的scipy.io模块中包含保存和加载mat格式文件的API,使用极其简单,不再赘述:另附简易示例如下: # -*- coding: utf-8 -*- import numpy as np import scipy.io as scio # data data = np.array([1,2,3]) data2 = np.array([4,5,6]) # save mat (data format: dict) scio.savemat(

  • python用pandas数据加载、存储与文件格式的实例

    数据加载.存储与文件格式 pandas提供了一些用于将表格型数据读取为DataFrame对象的函数.其中read_csv和read_talbe用得最多 pandas中的解析函数: 函数 说明 read_csv 从文件.URL.文件型对象中加载带分隔符的数据,默认分隔符为逗号 read_table 从文件.URL.文件型对象中加载带分隔符的数据.默认分隔符为制表符("\t") read_fwf 读取定宽列格式数据(也就是说,没有分隔符) read_clipboard 读取剪贴板中的数据,

  • python爬虫中PhantomJS加载页面的实例方法

    PhantomJS作为常用获取页面的工具之一,我们已经讲过页面测试.代码评估和捕获屏幕这几种使用的方式.当然最厉害的还是网页方面的捕捉,这里就不再讲述了.今天我们要讲的是它加载页面的新方法,这个可能很多人不知道.其实经常会用到,感兴趣的小伙伴一起进入今天的学习之中吧~ 可以利用 phantom 来实现页面的加载,下面的例子实现了页面的加载并将页面保存为一张图片. var page = require('webpage').create();page.open('http://cuiqingcai

  • pycharm远程连接服务器调试tensorflow无法加载问题

    最近打算在win系统下使用pycharm开发程序,并远程连接服务器调试程序,其中在import tensorflow时报错如图所示(在远程服务器中执行程序正常): 直观错误为: ImportError: libcusolver.so.8.0: cannot open shared object file: No such file or directory Failed to load the native TensorFlow runtime. 原因为无法加载libcusolver.so等,查

  • 基于Python实现配置热加载的方法详解

    目录 背景 如何实现 使用多进程实现配置热加载 使用signal信号量来实现热加载 采用multiprocessing.Event 来实现配置热加载 结语 背景 由于最近工作需求,需要在已有项目添加一个新功能,实现配置热加载的功能.所谓的配置热加载,也就是说当服务收到配置更新消息之后,我们不用重启服务就可以使用最新的配置去执行任务. 如何实现 下面我分别采用多进程.多线程.协程的方式去实现配置热加载. 使用多进程实现配置热加载 如果我们代码实现上使用多进程, 主进程1来更新配置并发送指令,任务的

  • python web基础之加载静态文件实例

    在web运行中很重要的一个功能就是加载静态文件,在django中可能已经给我们设置好了,我们只要直接把模板文件 放在templates就好了,但是你知道在基础中,像图片是怎么加载以及找到相应位置的吗? 下面我们来看看. 在上篇文章中我把,静态文件的路径单独出来在这里说说了,正好说说全局变量request的作用. 首先,我们写前端图片的路径: <img src="/static?file=1.gif"/> 看到这里,可能已经有人看出来了,对的,我们把图片路径看成url路径和参

  • Python pycharm 同时加载多个项目的方法

    在pycharm中只能一个项目存在,想打开另一个项目只能建一个新窗口或者把当前窗口覆盖掉. 在pycharm中其实可以同时打开多个项目: 1.file->setting->project 2.选择project structure,在窗口右侧的add content root 中添加要显示的项目 以上这篇Python pycharm 同时加载多个项目的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们.

  • TensorFlow获取加载模型中的全部张量名称代码

    核心代码如下: [tensor.name for tensor in tf.get_default_graph().as_graph_def().node] 实例代码:(加载了Inceptino_v3的模型,并获取该模型所有节点的名称) # -*- coding: utf-8 -*- import tensorflow as tf import os model_dir = 'C:/Inception_v3' model_name = 'output_graph.pb' # 读取并创建一个图gr

  • Python pkg_resources模块动态加载插件实例分析

    使用标准库importlib的import_module()函数.django的import_string(),它们都可以动态加载指定的 Python 模块. 举两个动态加载例子: 举例一: 在你项目中有个test函数,位于your_project/demo/test.py中,那么你可以使用import_module来动态加载并调用这个函数而不需要在使用的地方通过import导入. module_path = 'your_project/demo' module = import_module(

随机推荐