对tensorflow 的模型保存和调用实例讲解

我们通常采用tensorflow来训练,训练完之后应当保存模型,即保存模型的记忆(权重和偏置),这样就可以来进行人脸识别或语音识别了。

1.模型的保存

# 声明两个变量
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
init_op = tf.global_variables_initializer() # 初始化全部变量
saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:
 sess.run(init_op)
 print("v1:", sess.run(v1)) # 打印v1、v2的值一会读取之后对比
 print("v2:", sess.run(v2))
  #定义保存路径,一定要是绝对路径,且用‘/ '分隔父目录与子目录
 saver_path = saver.save(sess, "C:/Users/Administrator/Desktop/tt/model.ckpt") # 将模型保存到save/model.ckpt文件
 print("Model saved in file:", saver_path)

2.模型的读取

直接读取模型时,可能会报错,我是用Spyder编译的,可以把Spyder关掉,再重新打开,就可以读取数据了。原因可能是:在模型保存时将变量初始化了。

import tensorflow as tf

# 使用和保存模型代码中一样的方式来声明变量
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:
 saver.restore(sess, "C:/Users/Administrator/Desktop/tt/model.ckpt") # 即将固化到硬盘中的Session从保存路径再读取出来
 print("v1:", sess.run(v1)) # 打印v1、v2的值和之前的进行对比
 print("v2:", sess.run(v2))
 print("Model Restored")

以上这篇对tensorflow 的模型保存和调用实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • 解决Django中调用keras的模型出现的问题

    笔者小白在用Django写一个表格单据图片的识别应用的时候,遇到了调用基于Tensorflow的keras模型出错的问题. 出现的错误信息类似于以下: ValueError: Tensor Tensor("Placeholder:0", shape=(3, 3, 1, 32), dtype=float32) 通过查询相关的资料,对解决的方式做一个记录. 方法1.通过导入 import Keras 然后在构建模型前面加一句 keras.backend.clear_session() 方法

  • python 用opencv调用训练好的模型进行识别的方法

    此程序为先调用opencv自带的人脸检测模型,检测到人脸后,再调用我自己训练好的模型去识别人脸,使用时更改模型地址即可 #!usr/bin/env python import cv2 font=cv2.FONT_HERSHEY_SIMPLEX cascade1 = cv2.CascadeClassifier("D:\\opencv249\\opencv\\sources\\data\\haarcascades\\haarcascade_frontalface_alt_tree.xml"

  • Python时间序列处理之ARIMA模型的使用讲解

    ARIMA模型 ARIMA模型的全称是自回归移动平均模型,是用来预测时间序列的一种常用的统计模型,一般记作ARIMA(p,d,q). ARIMA的适应情况 ARIMA模型相对来说比较简单易用.在应用ARIMA模型时,要保证以下几点: 时间序列数据是相对稳定的,总体基本不存在一定的上升或者下降趋势,如果不稳定可以通过差分的方式来使其变稳定. 非线性关系处理不好,只能处理线性关系 判断时序数据稳定 基本判断方法:稳定的数据,总体上是没有上升和下降的趋势的,是没有周期性的,方差趋向于一个稳定的值. A

  • 对YOLOv3模型调用时候的python接口详解

    需要注意的是:更改完源程序.c文件,需要对整个项目重新编译.make install,对已经生成的文件进行更新,类似于之前VS中在一个类中增加新函数重新编译封装dll,而python接口的调用主要使用的是libdarknet.so文件,其余在配置文件中的修改不必重新进行编译安装. 之前训练好的模型,在模型调用的时候,总是在 lib = CDLL("/home/*****/*******/darknet/libdarknet.so", RTLD_GLOBAL)这里读不到darknet编译

  • python使用tensorflow保存、加载和使用模型的方法

    使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用.介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍: http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/ 我对这篇文章进行了整理和汇总. 首先是模型的保存.直接上代码: #!/usr/bin/env python #-*- c

  • 对tensorflow 的模型保存和调用实例讲解

    我们通常采用tensorflow来训练,训练完之后应当保存模型,即保存模型的记忆(权重和偏置),这样就可以来进行人脸识别或语音识别了. 1.模型的保存 # 声明两个变量 v1 = tf.Variable(tf.random_normal([1, 2]), name="v1") v2 = tf.Variable(tf.random_normal([2, 3]), name="v2") init_op = tf.global_variables_initializer(

  • tensorflow 加载部分变量的实例讲解

    tensorflow模型保存为saver = tf.train.Saver()函数,saver.save()保存模型,代码如下: import tensorflow as tf v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1") v2= tf.Variable(tf.zeros([200]), name="v2") saver = tf.train.Saver() with tf

  • spring和quartz整合,并简单调用(实例讲解)

    工作中会定时任务~简单学习一下. 第0步: 工欲善其事必先利其器,首先要做的自然是导包了. 在spring配置包扫描以及在 pom导入包 spring.xml: pom.xml 1.在spring-quartz.xml(和spring.xml同一个位置)配置相关属性 xml的头部每个人都可能不一样,这个自己要用的时候注意. quartz表达式根据自己需求去写,不列举了,这里的是1秒一次的. 2.Task包下配置类 我们这边将定时任务存放到一个包中,命名为task.用spring的自动注解serv

  • python Task在协程调用实例讲解

    1.说明 Tasks用于并发调度协程,通过asyncio.create_task(协程对象)创建Task对象,使协程能够加入事件循环,等待调度执行.除使用asyncio.create_task()函数外,还可使用低级loop.create_task()或ensure_future()函数.推荐使用手动实例Task对象. 2.使用注意 Python3.7中添加到asyncio.create_task函数.在Python3.7之前,可以使用低级asyncio.ensure_future函数. 3.实

  • PHP+MySQL+jQuery随意拖动层并即时保存拖动位置实例讲解

    想拖动页面上的层,完全可以用jQuery ui的Draggable方法来实现,那如何将拖动后层的位置保存下来呢?本文将给出答案.本文讲解了如何采用PHP+MySQL+jQuery,实现随意拖动层并即时保存拖动位置. 本文原理就是通过拖动将拖动后层的相对位置left,top和z-index三个参数更新到数据表中对应的记录,页面通过CSS解析每个层不同的位置.请看具体实现步骤. 准备MySQL数据表 首先需要准备一张表notes,用来记录层的内容,背景色和坐标等信息. CREATE TABLE IF

  • 浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式

    我们经常会看到后缀名为.pt, .pth, .pkl的pytorch模型文件,这几种模型文件在格式上有什么区别吗? 其实它们并不是在格式上有区别,只是后缀不同而已(仅此而已),在用torch.save()函数保存模型文件时,各人有不同的喜好,有些人喜欢用.pt后缀,有些人喜欢用.pth或.pkl.用相同的torch.save()语句保存出来的模型文件没有什么不同. 在pytorch官方的文档/代码里,有用.pt的,也有用.pth的.一般惯例是使用.pth,但是官方文档里貌似.pt更多,而且官方也

  • tensorflow模型保存、加载之变量重命名实例

    话不多说,干就完了. 变量重命名的用处? 简单定义:简单来说就是将模型A中的参数parameter_A赋给模型B中的parameter_B 使用场景:当需要使用已经训练好的模型参数,尤其是使用别人训练好的模型参数时,往往别人模型中的参数命名方式与自己当前的命名方式不同,所以在加载模型参数时需要对参数进行重命名,使得代码更简洁易懂. 实现方法: 1).模型保存 import os import tensorflow as tf weights = tf.Variable(initial_value

  • Tensorflow 使用pb文件保存(恢复)模型计算图和参数实例详解

    一.保存: graph_util.convert_variables_to_constants 可以把当前session的计算图串行化成一个字节流(二进制),这个函数包含三个参数:参数1:当前活动的session,它含有各变量 参数2:GraphDef 对象,它描述了计算网络 参数3:Graph图中需要输出的节点的名称的列表 返回值:精简版的GraphDef 对象,包含了原始输入GraphDef和session的网络和变量信息,它的成员函数SerializeToString()可以把这些信息串行

  • TensorFlow利用saver保存和提取参数的实例

    在训练循环中,定期调用 saver.save() 方法,向文件夹中写入包含了当前模型中所有可训练变量的 checkpoint 文件. saver.save(sess, FLAGS.train_dir, global_step=step) global_step是训练的第几步 保存参数: import tensorflow as tf W = tf.Variable([[1, 2, 3]], dtype=tf.float32) b = tf.Variable([[1]], dtype=tf.flo

  • TensorFlow模型保存和提取的方法

    一.TensorFlow模型保存和提取方法 1. TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提取.tf.train.Saver对象saver的save方法将TensorFlow模型保存到指定路径中,saver.save(sess,"Model/model.ckpt") ,实际在这个文件目录下会生成4个人文件: checkpoint文件保存了一个录下多有的模型文件列表,model.ckpt.meta保存了TensorFlow计算图的结构信息,model

随机推荐