tensorflow 固定部分参数训练,只训练部分参数的实例

在使用tensorflow来训练一个模型的时候,有时候需要依靠验证集来判断模型是否已经过拟合,是否需要停止训练。

1.首先想到的是用tf.placeholder()载入不同的数据来进行计算,比如

def inference(input_):
  """
  this is where you put your graph.
  the following is just an example.
  """

  conv1 = tf.layers.conv2d(input_)

  conv2 = tf.layers.conv2d(conv1)

  return conv2

input_ = tf.placeholder()
output = inference(input_)
...
calculate_loss_op = ...
train_op = ...
...

with tf.Session() as sess:
  sess.run([loss, train_op], feed_dict={input_: train_data})

  if validation == True:
    sess.run([loss], feed_dict={input_: validate_date})

这种方式很简单,也很直接了然。

2.但是,如果处理的数据量很大的时候,使用 tf.placeholder() 来载入数据会严重地拖慢训练的进度,因此,常用tfrecords文件来读取数据。

此时,很容易想到,将不同的值传入inference()函数中进行计算。

train_batch, label_batch = decode_train()
val_train_batch, val_label_batch = decode_validation()

train_result = inference(train_batch)
...
loss = ..
train_op = ...
...

if validation == True:
  val_result = inference(val_train_batch)
  val_loss = ..

with tf.Session() as sess:
  sess.run([loss, train_op])

  if validation == True:
    sess.run([val_result, val_loss])

这种方式看似能够直接调用inference()来对验证数据进行前向传播计算,但是,实则会在原图上添加上许多新的结点,这些结点的参数都是需要重新初始化的,也是就是说,验证的时候并不是使用训练的权重。

3.用一个tf.placeholder来控制是否训练、验证。

def inference(input_):
  ...
  ...
  ...

  return inference_result

train_batch, label_batch = decode_train()
val_batch, val_label = decode_validation()

is_training = tf.placeholder(tf.bool, shape=())

x = tf.cond(is_training, lambda: train_batch, lambda: val_batch)
y = tf.cond(is_training, lambda: train_label, lambda: val_label)

logits = inference(x)
loss = cal_loss(logits, y)
train_op = optimize(loss)

with tf.Session() as sess:

  loss, _ = sess.run([loss, train_op], feed_dict={is_training: True})

  if validation == True:
    loss = sess.run(loss, feed_dict={is_training: False})

使用这种方式就可以在一个大图里创建一个分支条件,从而通过控制placeholder来控制是否进行验证。

以上这篇tensorflow 固定部分参数训练,只训练部分参数的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • 从训练好的tensorflow模型中打印训练变量实例

    从tensorflow 训练后保存的模型中打印训变量:使用tf.train.NewCheckpointReader() import tensorflow as tf reader = tf.train.NewCheckpointReader('path/alexnet/model-330000') dic = reader.get_variable_to_shape_map() print dic 打印变量 w = reader.get_tensor("fc1/W") print t

  • tensorflow实现打印ckpt模型保存下的变量名称及变量值

    有时候会需要通过从保存下来的ckpt文件来观察其保存下来的训练完成的变量值. ckpt文件名列表:(一般是三个文件) xxxxx.ckpt.data-00000-of-00001 xxxxx.ckpt.index xxxxx.ckpt.meta import os from tensorflow.python import pywrap_tensorflow checkpoint_path = os.path.join("文件夹路径", "xxxxx.ckpt")

  • 解决Pycharm的项目目录突然消失的问题

    今天在玩pycharm的时候不知道按了其中什么按钮,然后我们的项目目录全部都不见了(一开始还不知道这个叫做项目目录)然后自己捣鼓了好久各个窗口的打开关闭,终于最后被我发现了什么- 1. pycharm的项目全部不见了 自己作死不知道按了什么按钮,然后我们的项目目录变成这样了,对于有点强迫症的我们来说实在是太难受了.点击子文件还得一个一个找. 2. 问题的出现的原因 其实我们应该是按了project->mark directory as->exclude 然后就变成这样子的结果了. 3. 解决之

  • tensorflow 只恢复部分模型参数的实例

    我就废话不多说了,直接上代码吧! import tensorflow as tf def model_1(): with tf.variable_scope("var_a"): a = tf.Variable(initial_value=[1, 2, 3], name="a") vars = [var for var in tf.trainable_variables() if var.name.startswith("var_a")] prin

  • 详解tensorflow训练自己的数据集实现CNN图像分类

    利用卷积神经网络训练图像数据分为以下几个步骤 1.读取图片文件 2.产生用于训练的批次 3.定义训练的模型(包括初始化参数,卷积.池化层等参数.网络) 4.训练 1 读取图片文件 def get_files(filename): class_train = [] label_train = [] for train_class in os.listdir(filename): for pic in os.listdir(filename+train_class): class_train.app

  • tensorflow 固定部分参数训练,只训练部分参数的实例

    在使用tensorflow来训练一个模型的时候,有时候需要依靠验证集来判断模型是否已经过拟合,是否需要停止训练. 1.首先想到的是用tf.placeholder()载入不同的数据来进行计算,比如 def inference(input_): """ this is where you put your graph. the following is just an example. """ conv1 = tf.layers.conv2d(inp

  • pytorch 在网络中添加可训练参数,修改预训练权重文件的方法

    实践中,针对不同的任务需求,我们经常会在现成的网络结构上做一定的修改来实现特定的目的. 假如我们现在有一个简单的两层感知机网络: # -*- coding: utf-8 -*- import torch from torch.autograd import Variable import torch.optim as optim x = Variable(torch.FloatTensor([1, 2, 3])).cuda() y = Variable(torch.FloatTensor([4,

  • keras读取训练好的模型参数并把参数赋值给其它模型详解

    介绍 本博文中的代码,实现的是加载训练好的模型model_halcon_resenet.h5,并把该模型的参数赋值给两个不同的新的model. 函数式模型 官网上给出的调用一个训练好模型,并输出任意层的feature. model = Model(inputs=base_model.input, outputs=base_model.get_layer('block4_pool').output) 但是这有一个问题,就是新的model,如果输入inputs和训练好的model的inputs大小不

  • C#中#define后面只加一个参数的解释

    #define只加一个参数 的解释 <stdio.h> 里有: #ifndef __STDIO_H #define __STDIO_H 这个__STDIO_H代表什么?而define的用法不是后面加两个字符串吗,它这里却只加一个字符串,是什么意思? 还有很多头文件里都有如下语句 #if __STDC__ #define _Cdecl #else #define _Cdecl cdecl #endif __stdc__,cdecl代表什么? 比方说你#include进来一个stdio.h,再#i

  • tensorflow 2.0模式下训练的模型转成 tf1.x 版本的pb模型实例

    升级到tf 2.0后, 训练的模型想转成1.x版本的.pb模型, 但之前提供的通过ckpt转pb模型的方法都不可用(因为保存的ckpt不再有.meta)文件, 尝试了好久, 终于找到了一个方法可以迂回转到1.x版本的pb模型. Note: 本方法首先有些要求需要满足: 可以拿的到模型的网络结构定义源码 网络结构里面的所有操作都是通过tf.keras完成的, 不能出现类似tf.nn 的tensorflow自己的操作符 tf2.0下保存的模型是.h5格式的,并且仅保存了weights, 即通过mod

  • 深入讲解Python函数中参数的使用及默认参数的陷阱

    C++里函数可以设置缺省参数,Java不可以,只能通过重载的方式来实现,python里也可以设置默认参数,最大的好处就是降低函数难度,函数的定义只有一个,并且python是动态语言,在同一名称空间里不能有想多名称的函数,如果出现了,那么后出现的会覆盖前面的函数. def power(x, n=2): s = 1 while n > 0: n = n - 1 s = s * x return s 看看结果: >>> power(5) 25 >>> power(5,3

  • javascript replace()第二个参数为函数时的参数用法

    javascript的replace()第二个参数为函数时的参数: replace()函数具有替换功能,它可以具有两个参数,第一个参数可以是要被替换的字符串或者匹配要被替换字符串的正则表达式,第二个参数可以是替换文本或者一个函数,下面看一下关于replace()函数的几个代码实例. 代码实例: 实例一: <script> var str="I love jb51 and you?"; console.log(str.replace("jb","

  • webview添加参数与修改请求头的user-agent实例

    前言 最近公司项目需求,在项目中嵌入h5页面,一般原生,看着感觉跟往常一样,一个地址就完全ok了,如果是这样那就没有这个博文的必要了! 项目的登录使用的token登录,在移动端的登录是原生的,但是h5也是有登录页面,这就需要控制token的过期时间了,但是想达到的网页访问使用网页的cookie,app登录使用的是app原生的登录token,在网页的cookie登录过期的时候,网页是可以正常退回登录页面,而在app嵌入的h5也需要根据token是否过期,决定是否返回登录页. 那么,问题就是在此产生

  • Spring/Spring Boot 中优雅地做参数校验拒绝 if/else 参数校验

    数据的校验的重要性就不用说了,即使在前端对数据进行校验的情况下,我们还是要对传入后端的数据再进行一遍校验,避免用户绕过浏览器直接通过一些 HTTP 工具直接向后端请求一些违法数据. 最普通的做法就像下面这样.我们通过 if/else 语句对请求的每一个参数一一校验. @RestController @RequestMapping("/api/person") public class PersonController { @PostMapping public ResponseEnti

  • Python详解argparse参数模块之命令行参数

    目录 前言 示例一:最简参数对象 示例二:整数求和 示例三:文件是否被篡改 自定义类型 choices选项限定 required必选参数 子命令 前言 help(argparse)查看说明文档,“argparse - Command-line parsing library”我们可以知道是一个命令行解析库,是关于参数解析相关的一个模块. 示例一:最简参数对象 先来一段简单的代码,快速熟知下这个参数是个啥.保存为t.py这样一个文件 import argparse parser = argpars

随机推荐