在tensorflow中设置保存checkpoint的最大数量实例

1、我就废话不多说了,直接上代码吧!

 # Set up a RunConfig to only save checkpoints once per training cycle.
 run_config = tf.estimator.RunConfig(save_checkpoints_secs=1e9,keep_checkpoint_max = 10)
 model = tf.estimator.Estimator(
   model_fn=deeplab_model_focal_class_imbalance_loss_adaptive.deeplabv3_plus_model_fn,
   model_dir=FLAGS.model_dir,
   config=run_config,
   params={
     'output_stride': FLAGS.output_stride,
     'batch_size': FLAGS.batch_size,
     'base_architecture': FLAGS.base_architecture,
     'pre_trained_model': FLAGS.pre_trained_model,
     'batch_norm_decay': _BATCH_NORM_DECAY,
     'num_classes': _NUM_CLASSES,
     'tensorboard_images_max_outputs': FLAGS.tensorboard_images_max_outputs,
     'weight_decay': FLAGS.weight_decay,
     'learning_rate_policy': FLAGS.learning_rate_policy,
     'num_train': _NUM_IMAGES['train'],
     'initial_learning_rate': FLAGS.initial_learning_rate,
     'max_iter': FLAGS.max_iter,
     'end_learning_rate': FLAGS.end_learning_rate,
     'power': _POWER,
     'momentum': _MOMENTUM,
     'freeze_batch_norm': FLAGS.freeze_batch_norm,
     'initial_global_step': FLAGS.initial_global_step
   })

以上这篇在tensorflow中设置保存checkpoint的最大数量实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • 解决Pycharm的项目目录突然消失的问题

    今天在玩pycharm的时候不知道按了其中什么按钮,然后我们的项目目录全部都不见了(一开始还不知道这个叫做项目目录)然后自己捣鼓了好久各个窗口的打开关闭,终于最后被我发现了什么- 1. pycharm的项目全部不见了 自己作死不知道按了什么按钮,然后我们的项目目录变成这样了,对于有点强迫症的我们来说实在是太难受了.点击子文件还得一个一个找. 2. 问题的出现的原因 其实我们应该是按了project->mark directory as->exclude 然后就变成这样子的结果了. 3. 解决之

  • tensorflow 固定部分参数训练,只训练部分参数的实例

    在使用tensorflow来训练一个模型的时候,有时候需要依靠验证集来判断模型是否已经过拟合,是否需要停止训练. 1.首先想到的是用tf.placeholder()载入不同的数据来进行计算,比如 def inference(input_): """ this is where you put your graph. the following is just an example. """ conv1 = tf.layers.conv2d(inp

  • TensorFlow——Checkpoint为模型添加检查点的实例

    1.检查点 保存模型并不限于在训练模型后,在训练模型之中也需要保存,因为TensorFlow训练模型时难免会出现中断的情况,我们自然希望能够将训练得到的参数保存下来,否则下次又要重新训练. 这种在训练中保存模型,习惯上称之为保存检查点. 2.添加保存点 通过添加检查点,可以生成载入检查点文件,并能够指定生成检查文件的个数,例如使用saver的另一个参数--max_to_keep=1,表明最多只保存一个检查点文件,在保存时使用如下的代码传入迭代次数. import tensorflow as tf

  • tensorflow实现测试时读取任意指定的check point的网络参数

    tensorflow在训练时会保存三个文件, model.ckpt-xxx.data-00000-of-00001 model.ckpt-xxx.index model.ckpt-xxx.meta 其中第一个储存网络参数值,第二个储存每一层的名字,第三个储存图结构 随着训练的过程,每隔一段时间都会保存一组以上三个文件,而在训练之前我们并不知道什么时候可以达到最佳的拟合,训练时间过短会导致欠拟合,训练时间过长则会导致过拟合. 如果每次测试时,我们都自动调用最新一次的check point,那很可能

  • 在tensorflow中设置保存checkpoint的最大数量实例

    1.我就废话不多说了,直接上代码吧! # Set up a RunConfig to only save checkpoints once per training cycle. run_config = tf.estimator.RunConfig(save_checkpoints_secs=1e9,keep_checkpoint_max = 10) model = tf.estimator.Estimator( model_fn=deeplab_model_focal_class_imbal

  • 在tensorflow中设置使用某一块GPU、多GPU、CPU的操作

    tensorflow下设置使用某一块GPU(从0开始编号): import os os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_VISIBLE_DEVICES"] = "1" 多GPU: num_gpus = 4 for i in range(num_gpus): with tf.device('/gpu:%d',%i): ... 只是用cpu的

  • TensorFlow利用saver保存和提取参数的实例

    在训练循环中,定期调用 saver.save() 方法,向文件夹中写入包含了当前模型中所有可训练变量的 checkpoint 文件. saver.save(sess, FLAGS.train_dir, global_step=step) global_step是训练的第几步 保存参数: import tensorflow as tf W = tf.Variable([[1, 2, 3]], dtype=tf.float32) b = tf.Variable([[1]], dtype=tf.flo

  • tensorflow实现训练变量checkpoint的保存与读取

    1.保存变量 先创建(在tf.Session()之前)saver saver = tf.train.Saver(tf.global_variables(),max_to_keep=1) #max_to_keep这个保证只保存最后一次training的训练数据 然后在训练的循环里面 checkpoint_path = os.path.join(Path, 'model.ckpt') saver.save(session, checkpoint_path, global_step=step) #这里

  • TensorFLow用Saver保存和恢复变量

    本文为大家分享了TensorFLow用Saver保存和恢复变量的具体代码,供大家参考,具体内容如下 建立文件tensor_save.py, 保存变量v1,v2的tensor到checkpoint files中,名称分别设置为v3,v4. import tensorflow as tf # Create some variables. v1 = tf.Variable(3, name="v1") v2 = tf.Variable(4, name="v2") # Cre

  • 浅谈Tensorflow模型的保存与恢复加载

    近期做了一些反垃圾的工作,除了使用常用的规则匹配过滤等手段,也采用了一些机器学习方法进行分类预测.我们使用TensorFlow进行模型的训练,训练好的模型需要保存,预测阶段我们需要将模型进行加载还原使用,这就涉及TensorFlow模型的保存与恢复加载. 总结一下Tensorflow常用的模型保存方式. 保存checkpoint模型文件(.ckpt) 首先,TensorFlow提供了一个非常方便的api,tf.train.Saver()来保存和还原一个机器学习模型. 模型保存 使用tf.trai

  • 对Tensorflow中权值和feature map的可视化详解

    前言 Tensorflow中可以使用tensorboard这个强大的工具对计算图.loss.网络参数等进行可视化.本文并不涉及对tensorboard使用的介绍,而是旨在说明如何通过代码对网络权值和feature map做更灵活的处理.显示和存储.本文的相关代码主要参考了github上的一个小项目,但是对其进行了改进. 原项目地址为(https://github.com/grishasergei/conviz). 本文将从以下两个方面进行介绍: 卷积知识补充 网络权值和feature map的可

  • 对TensorFlow中的variables_to_restore函数详解

    variables_to_restore函数,是TensorFlow为滑动平均值提供.之前,也介绍过通过使用滑动平均值可以让神经网络模型更加的健壮.我们也知道,其实在TensorFlow中,变量的滑动平均值都是由影子变量所维护的,如果你想要获取变量的滑动平均值需要获取的是影子变量而不是变量本身. 1.滑动平均值模型文件的保存 import tensorflow as tf if __name__ == "__main__": v = tf.Variable(0.,name="

  • 解决tensorflow模型参数保存和加载的问题

    终于找到bug原因!记一下:还是不熟悉平台的原因造成的! Q:为什么会出现两个模型对象在同一个文件中一起运行,当直接读取他们分开运行时训练出来的模型会出错,而且总是有一个正确,一个读取错误? 而 直接在同一个文件又训练又重新加载模型预测不出错,而且更诡异的是此时用分文件里的对象加载模型不会出错? model.py,里面含有 ModelV 和 ModelP,另外还有 modelP.py 和 modelV.py 分别只含有 ModelP 和 ModeV 这两个对象,先使用 modelP.py 和 m

  • tensorflow将图片保存为tfrecord和tfrecord的读取方式

    tensorflow官方提供了3种方法来读取数据: 预加载数据(preloaded data):在TensorFlow图中定义常量或变量来保存所有的数据,适用于数据量不太大的情况.填充数据(feeding):通过Python产生数据,然后再把数据填充到后端. 从文件读取数据(reading from file):从文件中直接读取,然后通过队列管理器从文件中读取数据. 本文主要介绍第三种方法,通过tfrecord文件来保存和读取数据,对于前两种读取数据的方式也会进行一个简单的介绍. 项目下载git

随机推荐