tensorflow实现测试时读取任意指定的check point的网络参数

tensorflow在训练时会保存三个文件,

model.ckpt-xxx.data-00000-of-00001
model.ckpt-xxx.index
model.ckpt-xxx.meta

其中第一个储存网络参数值,第二个储存每一层的名字,第三个储存图结构

随着训练的过程,每隔一段时间都会保存一组以上三个文件,而在训练之前我们并不知道什么时候可以达到最佳的拟合,训练时间过短会导致欠拟合,训练时间过长则会导致过拟合。

如果每次测试时,我们都自动调用最新一次的check point,那很可能不是最佳的一组参数,当我们训练了很多个epoch时,我们需要往回寻找最佳的check point,此时就需要指定的check point,下面有是具体方法:

修改checkpoint文件

一个checkpoint文件的内容如下

model_checkpoint_path: "model.ckpt-1623"
all_model_checkpoint_paths: "model.ckpt-1393"
all_model_checkpoint_paths: "model.ckpt-1451"
all_model_checkpoint_paths: "model.ckpt-1507"
all_model_checkpoint_paths: "model.ckpt-1565"
all_model_checkpoint_paths: "model.ckpt-1623"

这里面的后缀不同的数字就是不同的版本的参数,数字越小越早,系统会自动默认最新的训练出来的参数,而我们只需要在第一行把数字修改为我们想要调用的ckpt即可。

以上这篇tensorflow实现测试时读取任意指定的check point的网络参数就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • 在tensorflow中设置保存checkpoint的最大数量实例

    1.我就废话不多说了,直接上代码吧! # Set up a RunConfig to only save checkpoints once per training cycle. run_config = tf.estimator.RunConfig(save_checkpoints_secs=1e9,keep_checkpoint_max = 10) model = tf.estimator.Estimator( model_fn=deeplab_model_focal_class_imbal

  • tensorflow 固定部分参数训练,只训练部分参数的实例

    在使用tensorflow来训练一个模型的时候,有时候需要依靠验证集来判断模型是否已经过拟合,是否需要停止训练. 1.首先想到的是用tf.placeholder()载入不同的数据来进行计算,比如 def inference(input_): """ this is where you put your graph. the following is just an example. """ conv1 = tf.layers.conv2d(inp

  • 解决Pycharm的项目目录突然消失的问题

    今天在玩pycharm的时候不知道按了其中什么按钮,然后我们的项目目录全部都不见了(一开始还不知道这个叫做项目目录)然后自己捣鼓了好久各个窗口的打开关闭,终于最后被我发现了什么- 1. pycharm的项目全部不见了 自己作死不知道按了什么按钮,然后我们的项目目录变成这样了,对于有点强迫症的我们来说实在是太难受了.点击子文件还得一个一个找. 2. 问题的出现的原因 其实我们应该是按了project->mark directory as->exclude 然后就变成这样子的结果了. 3. 解决之

  • TensorFlow——Checkpoint为模型添加检查点的实例

    1.检查点 保存模型并不限于在训练模型后,在训练模型之中也需要保存,因为TensorFlow训练模型时难免会出现中断的情况,我们自然希望能够将训练得到的参数保存下来,否则下次又要重新训练. 这种在训练中保存模型,习惯上称之为保存检查点. 2.添加保存点 通过添加检查点,可以生成载入检查点文件,并能够指定生成检查文件的个数,例如使用saver的另一个参数--max_to_keep=1,表明最多只保存一个检查点文件,在保存时使用如下的代码传入迭代次数. import tensorflow as tf

  • tensorflow实现测试时读取任意指定的check point的网络参数

    tensorflow在训练时会保存三个文件, model.ckpt-xxx.data-00000-of-00001 model.ckpt-xxx.index model.ckpt-xxx.meta 其中第一个储存网络参数值,第二个储存每一层的名字,第三个储存图结构 随着训练的过程,每隔一段时间都会保存一组以上三个文件,而在训练之前我们并不知道什么时候可以达到最佳的拟合,训练时间过短会导致欠拟合,训练时间过长则会导致过拟合. 如果每次测试时,我们都自动调用最新一次的check point,那很可能

  • tensorflow实现从.ckpt文件中读取任意变量

    思路有些混乱,希望大家能理解我的意思. 看了faster rcnn的tensorflow代码,关于fix_variables的作用我不是很明白,所以写了以下代码,读取了预训练模型vgg16得fc6和fc7的参数,以及faster rcnn中heat_to_tail中的fc6和fc7,将它们做了对比,发现结果不一样,说明vgg16的fc6和fc7只是初始化了faster rcnn中heat_to_tail中的fc6和fc7,之后后者被训练. 具体读取任意变量的代码如下: import tensor

  • SpringBoot 在测试时如何指定包的扫描范围

    目录 在测试时如何指定包的扫描范围 以往-这么写 通过@SpringBootApplication注解配置容器的包扫描范围 配置扫描包范围 如何修改包扫描的位置? 方法一 方法二 在测试时如何指定包的扫描范围 @SpringBootTest注解,在SpringBoot在启动会根据主启动类上的@SpringBootApplication去扫描当前类及其子包下的类.当出现子包中相同类名时,容器失败. 可以通过为相同的类指定不同的ID解决,也可以通过在SpringBoot测试时指容器的包扫描范围解决.

  • 利用Tensorflow的队列多线程读取数据方式

    在tensorflow中,有三种方式输入数据 1. 利用feed_dict送入numpy数组 2. 利用队列从文件中直接读取数据 3. 预加载数据 其中第一种方式很常用,在tensorflow的MNIST训练源码中可以看到,通过feed_dict={},可以将任意数据送入tensor中. 第二种方式相比于第一种,速度更快,可以利用多线程的优势把数据送入队列,再以batch的方式出队,并且在这个过程中可以很方便地对图像进行随机裁剪.翻转.改变对比度等预处理,同时可以选择是否对数据随机打乱,可以说是

  • python3读取文件指定行的三种方法

    行遍历实现 在python中如果要将一个文件完全加载到内存中,通过file.readlines()即可,但是在文件占用较高时,我们是无法完整的将文件加载到内存中的,这时候就需要用到python的file.readline()进行迭代式的逐行读取: filename = 'hello.txt' with open(filename, 'r') as file: line = file.readline() counts = 1 while line: if counts >= 50000000:

  • SpringBoot测试时卡在Resolving Maven dependencies的问题

    目录 测试时卡在Resolving Maven dependencies Maven项目缺少Maven Dependencies问题 今天搭建了一个maven项目 网上其他解决Maven Dependencies文件缺失的方法 dependencyManagement与dependencies的区别 测试时卡在Resolving Maven dependencies 有没有遇到这个问题,在测试的时候 一直卡在Resolving Maven dependencies… 框内其实因为一直下载一个Ju

  • 如何设置Spring Boot测试时的日志级别

    1.概览 该教程中,我将向你展示:如何在测试时设置spring boot 日志级别.虽然我们可以在测试通过时忽略日志,但是如果需要诊断失败的测试,选择正确的日志级别是非常重要的. 2.日志级别的重要性 正确设置日志级别可以节省我们许多时间. 举例来说,如果测试在CI服务器上失败,但在开发服务器上时却通过了.我们将无法诊断失败的测试,除非有足够的日志输出. 为了获取正确数量的详细信息,我们可以微调应用程序的日志级别,如果发现某个java包对我们的测试更加重要,可以给它一个更低的日志级别,比如DEB

  • tensorflow使用range_input_producer多线程读取数据实例

    先放关键代码: i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue() inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE]) 原理解析: 第一行会产生一个队列,队列包含0到NUM_EXPOCHES-1的元素,如果num_epochs有指定,则每个元素只产生num_epochs次,否则循环产生.shuffle指定是否打乱顺

  • TensorFlow Saver:保存和读取模型参数.ckpt实例

    在使用TensorFlow的过程中,保存模型参数变量是很重要的一个环节,既可以保证训练过程信息不丢失,也可以帮助我们在需要快速恢复或使用一个模型的时候,利用之前保存好的参数之间导入,可以节省大量的训练时间.本文通过最简单的例程教大家如何保存和读取.ckpt文件. 一.保存到文件 首先是导入必要的东西: import tensorflow as tf import numpy as np 随便写几个变量: # Save to file # remember to define the same d

  • 浅谈tensorflow使用张量时的一些注意点tf.concat,tf.reshape,tf.stack

    有一段时间没用tensorflow了,现在跑实验还是存在一些坑了,主要是关于张量计算的问题.tensorflow升级1.0版本后与以前的版本并不兼容,可能出现各种奇奇怪怪的问题. 1 tf.concat函数 tensorflow1.0以前函数用法:tf.concat(concat_dim, values, name='concat'),第一个参数为连接的维度,可以将几个向量按指定维度连接起来. 如: t1 = [[1, 2, 3], [4, 5, 6]] t2 = [[7, 8, 9], [10

随机推荐