tensorflow获取预训练模型某层参数并赋值到当前网络指定层方式
已经有了一个预训练的模型,我需要从其中取出某一层,把该层的weights和biases赋值到新的网络结构中,可以使用tensorflow中的pywrap_tensorflow(用来读取预训练模型的参数值)结合Session.assign()进行操作。
这种需求即预训练模型可能为单分支网络,当前网络为多分支,我需要把单分支A复用到到多个分支去(B,C,D)。
先导入对应的工具包
from tensorflow.python import pywrap_tensorflow
接下来的操作在一个tf.Session中进行
reader = pywrap_tensorflow.NewCheckpointReader(pre_train_model_path) # 获取当前图可训练变量 trainable_variables = tf.trainable_variables() # 需要赋值的当前网络层变量,这里只是随便起的名字。 restore_v_target_name = "fc_target" # 需要的预训练模型中的某层的名字 restore_v_source_name = "fc_source" for v in trainable_variables: if restore_v_target_name == v.name: # 回复weights和biases sess.run( tf.assign(v, reader.get_tensor(restore_v_source_name + "/weights"))) if "weights" in v.name else sess.run( tf.assign(v, reader.get_tensor(restore_v_source_name + "/biases")))
以上这篇tensorflow获取预训练模型某层参数并赋值到当前网络指定层方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。
相关推荐
-
tensorflow 固定部分参数训练,只训练部分参数的实例
在使用tensorflow来训练一个模型的时候,有时候需要依靠验证集来判断模型是否已经过拟合,是否需要停止训练. 1.首先想到的是用tf.placeholder()载入不同的数据来进行计算,比如 def inference(input_): """ this is where you put your graph. the following is just an example. """ conv1 = tf.layers.conv2d(inp
-
解决Pycharm的项目目录突然消失的问题
今天在玩pycharm的时候不知道按了其中什么按钮,然后我们的项目目录全部都不见了(一开始还不知道这个叫做项目目录)然后自己捣鼓了好久各个窗口的打开关闭,终于最后被我发现了什么- 1. pycharm的项目全部不见了 自己作死不知道按了什么按钮,然后我们的项目目录变成这样了,对于有点强迫症的我们来说实在是太难受了.点击子文件还得一个一个找. 2. 问题的出现的原因 其实我们应该是按了project->mark directory as->exclude 然后就变成这样子的结果了. 3. 解决之
-
tensorflow模型保存、加载之变量重命名实例
话不多说,干就完了. 变量重命名的用处? 简单定义:简单来说就是将模型A中的参数parameter_A赋给模型B中的parameter_B 使用场景:当需要使用已经训练好的模型参数,尤其是使用别人训练好的模型参数时,往往别人模型中的参数命名方式与自己当前的命名方式不同,所以在加载模型参数时需要对参数进行重命名,使得代码更简洁易懂. 实现方法: 1).模型保存 import os import tensorflow as tf weights = tf.Variable(initial_value
-
浅谈tensorflow中张量的提取值和赋值
tf.gather和gather_nd从params中收集数值,tf.scatter_nd 和 tf.scatter_nd_update用updates更新某一张量.严格上说,tf.gather_nd和tf.scatter_nd_update互为逆操作. 已知数值的位置,从张量中提取数值:tf.gather, tf.gather_nd tf.gather indices每个元素(标量)是params某个axis的索引,tf.gather_nd 中indices最后一个阶对应于索引值. tf.ga
-
Tensorflow实现部分参数梯度更新操作
在深度学习中,迁移学习经常被使用,在大数据集上预训练的模型迁移到特定的任务,往往需要保持模型参数不变,而微调与任务相关的模型层. 本文主要介绍,使用tensorflow部分更新模型参数的方法. 1. 根据Variable scope剔除需要固定参数的变量 def get_variable_via_scope(scope_lst): vars = [] for sc in scope_lst: sc_variable = tf.get_collection(tf.GraphKeys.TRAINAB
-
tensorflow模型继续训练 fineturn实例
解决tensoflow如何在已训练模型上继续训练fineturn的问题. 训练代码 任务描述: x = 3.0, y = 100.0, 运算公式 x×W+b = y,求 W和b的最优解. # -*- coding: utf-8 -*-) import tensorflow as tf # 声明占位变量x.y x = tf.placeholder("float", shape=[None, 1]) y = tf.placeholder("float", [None,
-
tensorflow获取预训练模型某层参数并赋值到当前网络指定层方式
已经有了一个预训练的模型,我需要从其中取出某一层,把该层的weights和biases赋值到新的网络结构中,可以使用tensorflow中的pywrap_tensorflow(用来读取预训练模型的参数值)结合Session.assign()进行操作. 这种需求即预训练模型可能为单分支网络,当前网络为多分支,我需要把单分支A复用到到多个分支去(B,C,D). 先导入对应的工具包 from tensorflow.python import pywrap_tensorflow 接下来的操作在一个tf.
-
layer子层给父层页面元素赋值,以达到向父层页面传值的效果实例
父层: jsp中: //页面上添加一个隐藏的输入框待用于被子层设置value,从而将子层的数据传递到此页面 <input type="hidden" id="getValue" name="getValue" value="" /> js代码: //设置function,当执行时,弹出子窗口并传递当前窗口名称 //弹出子窗口(选择商家) function choseMerchant() { //获取当前窗口名称 v
-
Pytorch加载部分预训练模型的参数实例
前言 自从从深度学习框架caffe转到Pytorch之后,感觉Pytorch的优点妙不可言,各种设计简洁,方便研究网络结构修改,容易上手,比TensorFlow的臃肿好多了.对于深度学习的初学者,Pytorch值得推荐.今天主要主要谈谈Pytorch是如何加载预训练模型的参数以及代码的实现过程. 直接加载预选脸模型 如果我们使用的模型和预训练模型完全一样,那么我们就可以直接加载别人的模型,还有一种情况,我们在训练自己模型的过程中,突然中断了,但只要我们保存了之前的模型的参数也可以使用下面的代码直
-
使用Keras预训练模型ResNet50进行图像分类方式
Keras提供了一些用ImageNet训练过的模型:Xception,VGG16,VGG19,ResNet50,InceptionV3.在使用这些模型的时候,有一个参数include_top表示是否包含模型顶部的全连接层,如果包含,则可以将图像分为ImageNet中的1000类,如果不包含,则可以利用这些参数来做一些定制的事情. 在运行时自动下载有可能会失败,需要去网站中手动下载,放在"~/.keras/models/"中,使用WinPython则在"settings/.ke
-
Tensorflow加载预训练模型和保存模型的实例
使用tensorflow过程中,训练结束后我们需要用到模型文件.有时候,我们可能也需要用到别人训练好的模型,并在这个基础上再次训练.这时候我们需要掌握如何操作这些模型数据.看完本文,相信你一定会有收获! 1 Tensorflow模型文件 我们在checkpoint_dir目录下保存的文件结构如下: |--checkpoint_dir | |--checkpoint | |--MyModel.meta | |--MyModel.data-00000-of-00001 | |--MyModel.in
-
Tensorflow加载Vgg预训练模型操作
很多深度神经网络模型需要加载预训练过的Vgg参数,比如说:风格迁移.目标检测.图像标注等计算机视觉中常见的任务.那么到底如何加载Vgg模型呢?Vgg文件的参数到底有何意义呢?加载后的模型该如何使用呢? 本文将以Vgg19为例子,详细说明Tensorflow如何加载Vgg预训练模型. 实验环境 GTX1050-ti, cuda9.0 Window10, Tensorflow 1.12 展示Vgg19构造 import tensorflow as tf import numpy as np impo
-
浅谈Tensorflow加载Vgg预训练模型的几个注意事项
写这个博客的关键Bug: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16, bfloat16, float32, float64.本博客将围绕 加载图片 和 保存图片到本地 来详细解释和解决上述的Bug及其引出来的一系列Bug. 加载图片 首先,造成上述Bug的代码如下所示 image_path = "data/test.jpg" # 本地的测试图片
-
Keras 实现加载预训练模型并冻结网络的层
在解决一个任务时,我会选择加载预训练模型并逐步fine-tune.比如,分类任务中,优异的深度学习网络有很多. ResNet, VGG, Xception等等... 并且这些模型参数已经在imagenet数据集中训练的很好了,可以直接拿过来用. 根据自己的任务,训练一下最后的分类层即可得到比较好的结果.此时,就需要"冻结"预训练模型的所有层,即这些层的权重永不会更新. 以Xception为例: 加载预训练模型: from tensorflow.python.keras.applicat
-
pytorch载入预训练模型后,实现训练指定层
1.有了已经训练好的模型参数,对这个模型的某些层做了改变,如何利用这些训练好的模型参数继续训练: pretrained_params = torch.load('Pretrained_Model') model = The_New_Model(xxx) model.load_state_dict(pretrained_params.state_dict(), strict=False) strict=False 使得预训练模型参数中和新模型对应上的参数会被载入,对应不上或没有的参数被抛弃. 2.
-
pytorch 预训练模型读取修改相关参数的填坑问题
pytorch 预训练模型读取修改相关参数的填坑 修改部分层,仍然调用之前的模型参数. resnet = resnet50(pretrained=False) resnet.load_state_dict(torch.load(args.predir)) res_conv31 = Bottleneck_dilated(1024, 256,dilated_rate = 2) print("---------------------",res_conv31) print("---
随机推荐
- 将MongoDB作为Redis式的内存数据库的使用方法
- Nginx伪静态配置和常用Rewrite伪静态规则集锦
- Oracle数据库按时间进行分组统计数据的方法
- Python中规范定义命名空间的一些建议
- mysql jdbc连接步骤及常见参数
- js兼容IE6,IE7菜单高亮显示效果代码
- Android仿微信QQ设置图形头像裁剪功能
- Android Map新用法:MapFragment应用介绍
- Java过滤器filter_动力节点Java学院整理
- js操作数据库实现注册和登陆的简单实例
- C#中私有构造函数的特点和用途实例解析
- AndroidStudio升级到3.0的新特性和注意事项小结
- Java实现终止线程池中正在运行的定时任务
- python爬取内容存入Excel实例
- python实现数据分析与建模
- java使用多线程读取超大文件
- Spring Cloud入门教程之Zuul实现API网关与请求过滤
- C#统计字符串的方法
- Android studio设置文件头定制代码注释的方法
- Vue 菜单栏点击切换单个class(高亮)的方法