tensorflow 只恢复部分模型参数的实例

我就废话不多说了,直接上代码吧!

import tensorflow as tf

def model_1():
  with tf.variable_scope("var_a"):
    a = tf.Variable(initial_value=[1, 2, 3], name="a")

  vars = [var for var in tf.trainable_variables() if var.name.startswith("var_a")]
  print(len(vars))
  return vars

def model_2():

  vars1 = model_1()

  with tf.variable_scope("var_b"):
    a = tf.Variable(initial_value=[1, 2, 3], name="a")

  vars2 = [var for var in tf.trainable_variables() if var.name.startswith("var")]
  print(len(vars2))
  return vars1

def pretrain_model1():
  print("-------- model 1 ------")
  vars = model_1()

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.save(sess, "./model.ckpt")

def train_model2():
  print("-------- model 2 ------")

  model1_vars = model_2()

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(var_list=model1_vars)
    saver.restore(sess, "./model.ckpt")
    vars = sess.run([model1_vars])
    for var in vars:
      print(var)

step = 2
if step == 1:
  pretrain_model1()
else:
  train_model2()

以上这篇tensorflow 只恢复部分模型参数的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • tensorflow 获取模型所有参数总和数量的方法

    实例如下所示: from functools import reduce from operator import mul def get_num_params(): num_params = 0 for variable in tf.trainable_variables(): shape = variable.get_shape() num_params += reduce(mul, [dim.value for dim in shape], 1) return num_params 以上这

  • 关于tensorflow的几种参数初始化方法小结

    在tensorflow中,经常会遇到参数初始化问题,比如在训练自己的词向量时,需要对原始的embeddigs矩阵进行初始化,更一般的,在全连接神经网络中,每层的权值w也需要进行初始化. tensorlfow中应该有一下几种初始化方法 1. tf.constant_initializer() 常数初始化 2. tf.ones_initializer() 全1初始化 3. tf.zeros_initializer() 全0初始化 4. tf.random_uniform_initializer()

  • 解决tensorflow模型参数保存和加载的问题

    终于找到bug原因!记一下:还是不熟悉平台的原因造成的! Q:为什么会出现两个模型对象在同一个文件中一起运行,当直接读取他们分开运行时训练出来的模型会出错,而且总是有一个正确,一个读取错误? 而 直接在同一个文件又训练又重新加载模型预测不出错,而且更诡异的是此时用分文件里的对象加载模型不会出错? model.py,里面含有 ModelV 和 ModelP,另外还有 modelP.py 和 modelV.py 分别只含有 ModelP 和 ModeV 这两个对象,先使用 modelP.py 和 m

  • tensorflow 只恢复部分模型参数的实例

    我就废话不多说了,直接上代码吧! import tensorflow as tf def model_1(): with tf.variable_scope("var_a"): a = tf.Variable(initial_value=[1, 2, 3], name="a") vars = [var for var in tf.trainable_variables() if var.name.startswith("var_a")] prin

  • TensorFlow Saver:保存和读取模型参数.ckpt实例

    在使用TensorFlow的过程中,保存模型参数变量是很重要的一个环节,既可以保证训练过程信息不丢失,也可以帮助我们在需要快速恢复或使用一个模型的时候,利用之前保存好的参数之间导入,可以节省大量的训练时间.本文通过最简单的例程教大家如何保存和读取.ckpt文件. 一.保存到文件 首先是导入必要的东西: import tensorflow as tf import numpy as np 随便写几个变量: # Save to file # remember to define the same d

  • Pytorch中实现只导入部分模型参数的方式

    我们在做迁移学习,或者在分割,检测等任务想使用预训练好的模型,同时又有自己修改之后的结构,使得模型文件保存的参数,有一部分是不需要的(don't expected).我们搭建的网络对保存文件来说,有一部分参数也是没有的(missed).如果依旧使用torch.load(model.state_dict())的办法,就会出现 xxx expected,xxx missed类似的错误.那么在这种情况下,该如何导入模型呢? 好在Pytorch中的模型参数使用字典保存的,键是参数的名称,值是参数的具体数

  • tensorflow 固定部分参数训练,只训练部分参数的实例

    在使用tensorflow来训练一个模型的时候,有时候需要依靠验证集来判断模型是否已经过拟合,是否需要停止训练. 1.首先想到的是用tf.placeholder()载入不同的数据来进行计算,比如 def inference(input_): """ this is where you put your graph. the following is just an example. """ conv1 = tf.layers.conv2d(inp

  • TensorFlow利用saver保存和提取参数的实例

    在训练循环中,定期调用 saver.save() 方法,向文件夹中写入包含了当前模型中所有可训练变量的 checkpoint 文件. saver.save(sess, FLAGS.train_dir, global_step=step) global_step是训练的第几步 保存参数: import tensorflow as tf W = tf.Variable([[1, 2, 3]], dtype=tf.float32) b = tf.Variable([[1]], dtype=tf.flo

  • tensorflow 模型权重导出实例

    tensorflow在保存权重模型时多使用tf.train.Saver().save 函数进行权重保存,保存的ckpt文件无法直接打开,不利于将模型权重导入到其他框架使用(如Caffe.Keras等). 好在tensorflow提供了相关函数 tf.train.NewCheckpointReader 可以对ckpt文件进行权重查看,因此可以通过该函数进行数据导出. import tensorflow as tf import h5py cpktLogFileName = r'./checkpoi

  • TensorFlow 实战之实现卷积神经网络的实例讲解

    本文根据最近学习TensorFlow书籍网络文章的情况,特将一些学习心得做了总结,详情如下.如有不当之处,请各位大拿多多指点,在此谢过. 一.相关性概念 1.卷积神经网络(ConvolutionNeural Network,CNN) 19世纪60年代科学家最早提出感受野(ReceptiveField).当时通过对猫视觉皮层细胞研究,科学家发现每一个视觉神经元只会处理一小块区域的视觉图像,即感受野.20世纪80年代,日本科学家提出神经认知机(Neocognitron)的概念,被视为卷积神经网络最初

  • Pytorch模型参数的保存和加载

    目录 一.前言 二.参数保存 三.参数的加载 四.保存和加载整个模型 五.总结 一.前言 在模型训练完成后,我们需要保存模型参数值用于后续的测试过程.由于保存整个模型将耗费大量的存储,故推荐的做法是只保存参数,使用时只需在建好模型的基础上加载. 通常来说,保存的对象包括网络参数值.优化器参数值.epoch值等.本文将简单介绍保存和加载模型参数的方法,同时也给出保存整个模型的方法供大家参考. 二.参数保存 在这里我们使用 torch.save() 函数保存模型参数: import torch pa

  • TensorFlow实现Softmax回归模型

    一.概述及完整代码 对MNIST(MixedNational Institute of Standard and Technology database)这个非常简单的机器视觉数据集,Tensorflow为我们进行了方便的封装,可以直接加载MNIST数据成我们期望的格式.本程序使用Softmax Regression训练手写数字识别的分类模型. 先看完整代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist imp

随机推荐