tensorflow实现训练变量checkpoint的保存与读取
1.保存变量
先创建(在tf.Session()之前)saver
saver = tf.train.Saver(tf.global_variables(),max_to_keep=1) #max_to_keep这个保证只保存最后一次training的训练数据
然后在训练的循环里面
checkpoint_path = os.path.join(Path, 'model.ckpt') saver.save(session, checkpoint_path, global_step=step) #这里的step是循环训练的次数,也就是第几次迭代
以下保存的变量文件
2.变量读取
1.若要直接恢复所有变量可以
saver = tf.train.Saver(tf.global_variables()) moudke_file=tf.train.latest_checkpoint('PATH') saver.restore(sess,moudke_file)
PATH是存放保存变量的路径,会自动找到最近保存的变量文件
2 若想读取其中一部分变量值
def read_checkpoint(): w = [] checkpoint_path = '/home/ximao/models/resnet3/variable_logs/model.ckpt-17000' reader = tf.train.NewCheckpointReader(checkpoint_path) var = reader.get_variable_to_shape_map() for key in var: if 'weights' in key and 'conv' in key and 'Mo' not in key: print('tensorname:', key) # # print(reader.get_tensor(key))
3. 若想恢复其中一部分变量值到新网络
(1)首先你要先获取你想要赋值新网络变量的变量名,这里变量名不是一个字符串,而是<name,shape,dtype>这样的一个结构,
然后把你要赋值的元素转为张量,最后把值赋给你得到变量名 如下:
var=[v for v in weight_pruned if v.op.name=='WRN/conv1/weights'] conv1_temp=tf.convert_to_tensor(conv1,dtype=tf.float32) sess.run(tf.assign(var[0],conv1_temp))
weight_pruned 存放的是你新网络中所有的变量
以上这篇tensorflow实现训练变量checkpoint的保存与读取就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。
相关推荐
-
Tensorflow中使用tfrecord方式读取数据的方法
前言 本博客默认读者对神经网络与Tensorflow有一定了解,对其中的一些术语不再做具体解释.并且本博客主要以图片数据为例进行介绍,如有错误,敬请斧正. 使用Tensorflow训练神经网络时,我们可以用多种方式来读取自己的数据.如果数据集比较小,而且内存足够大,可以选择直接将所有数据读进内存,然后每次取一个batch的数据出来.如果数据较多,可以每次直接从硬盘中进行读取,不过这种方式的读取效率就比较低了.此篇博客就主要讲一下Tensorflow官方推荐的一种较为高效的数据读取方式--tfre
-
详解Tensorflow数据读取有三种方式(next_batch)
Tensorflow数据读取有三种方式: Preloaded data: 预加载数据 Feeding: Python产生数据,再把数据喂给后端. Reading from file: 从文件中直接读取 这三种有读取方式有什么区别呢? 我们首先要知道TensorFlow(TF)是怎么样工作的. TF的核心是用C++写的,这样的好处是运行快,缺点是调用不灵活.而Python恰好相反,所以结合两种语言的优势.涉及计算的核心算子和运行框架是用C++写的,并提供API给Python.Python调用这些A
-
TensorFlow实现从txt文件读取数据
TensorFlow从txt文件中读取数据的方法很多有种,我比较常用的是下面两种: [1]np.loadtxt import numpy as np data=np.loadtxt('ex1data1.txt',dtype='float',delimiter=',') X_train=data[:,0] y_train=data[:,1] [2]pd.read_csv import pandas as pd data=pd.read_csv("ex2data2.txt",names=[
-
利用Tensorflow的队列多线程读取数据方式
在tensorflow中,有三种方式输入数据 1. 利用feed_dict送入numpy数组 2. 利用队列从文件中直接读取数据 3. 预加载数据 其中第一种方式很常用,在tensorflow的MNIST训练源码中可以看到,通过feed_dict={},可以将任意数据送入tensor中. 第二种方式相比于第一种,速度更快,可以利用多线程的优势把数据送入队列,再以batch的方式出队,并且在这个过程中可以很方便地对图像进行随机裁剪.翻转.改变对比度等预处理,同时可以选择是否对数据随机打乱,可以说是
-
TensorFlow入门使用 tf.train.Saver()保存模型
关于模型保存的一点心得 saver = tf.train.Saver(max_to_keep=3) 在定义 saver 的时候一般会定义最多保存模型的数量,一般来说,如果模型本身很大,我们需要考虑到硬盘大小.如果你需要在当前训练好的模型的基础上进行 fine-tune,那么尽可能多的保存模型,后继 fine-tune 不一定从最好的 ckpt 进行,因为有可能一下子就过拟合了.但是如果保存太多,硬盘也有压力呀.如果只想保留最好的模型,方法就是每次迭代到一定步数就在验证集上计算一次 accurac
-
tensorflow实现读取模型中保存的值 tf.train.NewCheckpointReader
使用tf.trian.NewCheckpointReader(model_dir) 一个标准的模型文件有一下文件, model_dir就是MyModel(没有后缀) checkpoint Model.meta Model.data-00000-of-00001 Model.index import tensorflow as tf import pprint # 使用pprint 提高打印的可读性 NewCheck =tf.train.NewCheckpointReader("model&quo
-
tensorflow实现训练变量checkpoint的保存与读取
1.保存变量 先创建(在tf.Session()之前)saver saver = tf.train.Saver(tf.global_variables(),max_to_keep=1) #max_to_keep这个保证只保存最后一次training的训练数据 然后在训练的循环里面 checkpoint_path = os.path.join(Path, 'model.ckpt') saver.save(session, checkpoint_path, global_step=step) #这里
-
从训练好的tensorflow模型中打印训练变量实例
从tensorflow 训练后保存的模型中打印训变量:使用tf.train.NewCheckpointReader() import tensorflow as tf reader = tf.train.NewCheckpointReader('path/alexnet/model-330000') dic = reader.get_variable_to_shape_map() print dic 打印变量 w = reader.get_tensor("fc1/W") print t
-
Tensorflow加载预训练模型和保存模型的实例
使用tensorflow过程中,训练结束后我们需要用到模型文件.有时候,我们可能也需要用到别人训练好的模型,并在这个基础上再次训练.这时候我们需要掌握如何操作这些模型数据.看完本文,相信你一定会有收获! 1 Tensorflow模型文件 我们在checkpoint_dir目录下保存的文件结构如下: |--checkpoint_dir | |--checkpoint | |--MyModel.meta | |--MyModel.data-00000-of-00001 | |--MyModel.in
-
python深度学习TensorFlow神经网络模型的保存和读取
目录 之前的笔记里实现了softmax回归分类.简单的含有一个隐层的神经网络.卷积神经网络等等,但是这些代码在训练完成之后就直接退出了,并没有将训练得到的模型保存下来方便下次直接使用.为了让训练结果可以复用,需要将训练好的神经网络模型持久化,这就是这篇笔记里要写的东西. TensorFlow提供了一个非常简单的API,即tf.train.Saver类来保存和还原一个神经网络模型. 下面代码给出了保存TensorFlow模型的方法: import tensorflow as tf # 声明两个变量
-
TensorFlow Saver:保存和读取模型参数.ckpt实例
在使用TensorFlow的过程中,保存模型参数变量是很重要的一个环节,既可以保证训练过程信息不丢失,也可以帮助我们在需要快速恢复或使用一个模型的时候,利用之前保存好的参数之间导入,可以节省大量的训练时间.本文通过最简单的例程教大家如何保存和读取.ckpt文件. 一.保存到文件 首先是导入必要的东西: import tensorflow as tf import numpy as np 随便写几个变量: # Save to file # remember to define the same d
-
C#使用TensorFlow.NET训练自己的数据集的方法
今天,我结合代码来详细介绍如何使用 SciSharp STACK 的 TensorFlow.NET 来训练CNN模型,该模型主要实现 图像的分类 ,可以直接移植该代码在 CPU 或 GPU 下使用,并针对你们自己本地的图像数据集进行训练和推理.TensorFlow.NET是基于 .NET Standard 框架的完整实现的TensorFlow,可以支持 .NET Framework 或 .NET CORE , TensorFlow.NET 为广大.NET开发者提供了完美的机器学习框架选择. Sc
-
python神经网络tensorflow利用训练好的模型进行预测
目录 学习前言 载入模型思路 实现代码 学习前言 在神经网络学习中slim常用函数与如何训练.保存模型文章里已经讲述了如何使用slim训练出来一个模型,这篇文章将会讲述如何预测. 载入模型思路 载入模型的过程主要分为以下四步: 1.建立会话Session: 2.将img_input的placeholder传入网络,建立网络结构: 3.初始化所有变量: 4.利用saver对象restore载入所有参数. 这里要注意的重点是,在利用saver对象restore载入所有参数之前,必须要建立网络结构,因
-
对Tensorflow中的变量初始化函数详解
Tensorflow 提供了7种不同的初始化函数: tf.constant_initializer(value) #将变量初始化为给定的常量,初始化一切所提供的值. 假设在卷积层中,设置偏执项b为0,则写法为: 1. bias_initializer=tf.constant_initializer(0) 2. bias_initializer=tf.zeros_initializer(0) tf.random_normal_initializer(mean,stddev) #功能是将变量初始化为
-
Tensorflow之MNIST CNN实现并保存、加载模型
本文实例为大家分享了Tensorflow之MNIST CNN实现并保存.加载模型的具体代码,供大家参考,具体内容如下 废话不说,直接上代码 # TensorFlow and tf.keras import tensorflow as tf from tensorflow import keras # Helper libraries import numpy as np import matplotlib.pyplot as plt import os #download the data mn
-
SpringCloud使用Nacos保存和读取变量的配置方法
目录 前提条件 启动配置管理 注入配置 同步配置 注意: 在使用SpringCloud开发微服务时,经常会遇到一些比较小的后台参数配置,这些配置不足以单独开一张表去存储,而且其他服务会读取该参数.比如IP白名单.这时,使用Nacos去保存和读取就比较方便. 前提条件 使用SpringCloud的项目 启动Nacos 启动配置管理 添加依赖: <dependency> <groupId>com.alibaba.cloud</groupId> <artifactId&
随机推荐
- IE奥秘——添加新菜单项(推荐)
- PHP 二维数组根据某个字段排序的具体实现
- 如何阻止网站被恶意反向代理访问(防网站镜像)
- RecyclerView上拉加载封装代码
- bootstrap jquery dataTable 异步ajax刷新表格数据的实现方法
- MySQL 触发器详解及简单实例
- js实现首屏延迟加载实现方法 js实现多屏单张图片延迟加载效果
- 让低版本浏览器支持input的placeholder属性(js方法)
- 无刷新预览所选择的图片示例代码
- JavaScript中双向数据绑定详解
- Java采用setAsciiStream方法检索数据库指定内容实例解析
- Java多线程回调方法实例解析
- python+matplotlib演示电偶极子实例代码
- Java常用工具类 UUID、Map工具类
- ZooKeeper 实现分布式锁的方法示例
- oracle数据库实现获取时间戳的无参函数
- vue-cli 3.0 自定义vue.config.js文件,多页构建的方法
- python pandas中对Series数据进行轴向连接的实例
- java实现图片滑动验证(包含前端代码)
- Python爬虫之Spider类用法简单介绍