Tensorflow实现部分参数梯度更新操作

在深度学习中,迁移学习经常被使用,在大数据集上预训练的模型迁移到特定的任务,往往需要保持模型参数不变,而微调与任务相关的模型层。

本文主要介绍,使用tensorflow部分更新模型参数的方法。

1. 根据Variable scope剔除需要固定参数的变量

def get_variable_via_scope(scope_lst):
  vars = []
  for sc in scope_lst:
    sc_variable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope=scope)
    vars.extend(sc_variable)
  return vars

trainable_vars = tf.trainable_variables()
no_change_scope = ['your_unchange_scope_name']

no_change_vars = get_variable_via_scope(no_change_scope)

for v in no_change_vars:
  trainable_vars.remove(v)

grads, _ = tf.gradients(loss, trainable_vars)

optimizer = tf.train.AdamOptimizer(lr)

train_op = optimizer.apply_gradient(zip(grads, trainable_vars), global_step=global_step)

2. 使用tf.stop_gradient()函数

在建立Graph过程中使用该函数,非常简洁地避免了使用scope获取参数

3. 一个矩阵中部分行或列参数更新

如果一个矩阵,只有部分行或列需要更新参数,其它保持不变,该场景很常见,例如word embedding中,一些预定义的领域相关词保持不变(使用领域相关word embedding初始化),而另一些通用词变化。

import tensorflow as tf
import numpy as np

def entry_stop_gradients(target, mask):
  mask_h = tf.abs(mask-1)
  return tf.stop_gradient(mask_h * target) + mask * target

mask = np.array([1., 0, 1, 1, 0, 0, 1, 1, 0, 1])
mask_h = np.abs(mask-1)

emb = tf.constant(np.ones([10, 5]))

matrix = entry_stop_gradients(emb, tf.expand_dims(mask,1))

parm = np.random.randn(5, 1)
t_parm = tf.constant(parm)

loss = tf.reduce_sum(tf.matmul(matrix, t_parm))
grad1 = tf.gradients(loss, emb)
grad2 = tf.gradients(loss, matrix)
print matrix
with tf.Session() as sess:
  print sess.run(loss)
  print sess.run([grad1, grad2])

以上这篇Tensorflow实现部分参数梯度更新操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • tensorflow 固定部分参数训练,只训练部分参数的实例

    在使用tensorflow来训练一个模型的时候,有时候需要依靠验证集来判断模型是否已经过拟合,是否需要停止训练. 1.首先想到的是用tf.placeholder()载入不同的数据来进行计算,比如 def inference(input_): """ this is where you put your graph. the following is just an example. """ conv1 = tf.layers.conv2d(inp

  • 解决Pycharm的项目目录突然消失的问题

    今天在玩pycharm的时候不知道按了其中什么按钮,然后我们的项目目录全部都不见了(一开始还不知道这个叫做项目目录)然后自己捣鼓了好久各个窗口的打开关闭,终于最后被我发现了什么- 1. pycharm的项目全部不见了 自己作死不知道按了什么按钮,然后我们的项目目录变成这样了,对于有点强迫症的我们来说实在是太难受了.点击子文件还得一个一个找. 2. 问题的出现的原因 其实我们应该是按了project->mark directory as->exclude 然后就变成这样子的结果了. 3. 解决之

  • Tensorflow实现部分参数梯度更新操作

    在深度学习中,迁移学习经常被使用,在大数据集上预训练的模型迁移到特定的任务,往往需要保持模型参数不变,而微调与任务相关的模型层. 本文主要介绍,使用tensorflow部分更新模型参数的方法. 1. 根据Variable scope剔除需要固定参数的变量 def get_variable_via_scope(scope_lst): vars = [] for sc in scope_lst: sc_variable = tf.get_collection(tf.GraphKeys.TRAINAB

  • Hibernate对数据库删除、查找、更新操作实例代码

    本节继续hibernate对数据库的其他操作,删除.查询.修改. Hibernate对数据删除操作 删除User表中个一条数据,是需要更具User表的主键id值来删除的.首先根据id值向数据库中查询出来对应的对象.可以采用两种方式一是session的get方法,一个是session的load方法. Session的Get方法:调用这个方法会返回一个Object对象.然后我们对其强制转换.Useruser = (User)session.get(User.class," 402881e5441c0

  • MongoDB中文档的更新操作示例详解

    前言 在MongoDB中,更新单个doc的操作是原子性的.默认情况下,如果一个update操作更新多个doc,那么对每个doc的更新是原子性的,但是对整个update 操作而言,不是原子性的,可能存在前面的doc更新成功,而后面的doc更新失败的情况.由于更新单个doc的操作是原子性的,如果两个更新同时发生,那么一个更新操作会阻塞另外一个,doc的最终结果值是由时间靠后的更新操作决定的. 我们在前面的文章中提到过文档的基本的增删改查操作,MongoDB中提供的增删改查的语法非常丰富,不清楚的朋友

  • pytorch 自定义参数不更新方式

    nn.Module中定义参数:不需要加cuda,可以求导,反向传播 class BiFPN(nn.Module): def __init__(self, fpn_sizes): self.w1 = nn.Parameter(torch.rand(1)) print("no---------------------------------------------------",self.w1.data, self.w1.grad) 下面这个例子说明中间变量可能没有梯度,但是最终变量有梯度

  • mybatis-plus update更新操作的三种方式(小结)

    目录 1.@ 根据id更新 2.@ 条件构造器作为参数进行更新 3.@ lambda构造器 mybatisplus update语句为null时没有拼接上去 1.@ 根据id更新 User user = new User(); user.setUserId(1); user.setAge(29); userMapper.updateById(user); 2.@ 条件构造器作为参数进行更新 //把名字为rhb的用户年龄更新为18,其他属性不变 UpdateWrapper<User> updat

  • MySQL中实现插入或更新操作(类似Oracle的merge语句)

    如果需要在MySQL中实现记录不存在则insert,不存在则update操作.可以使用以下语句: 更新一个字段: INSERT INTO tbl (columnA,columnB,columnC) VALUES (1,2,3) ON DUPLICATE KEY UPDATE columnA=IF(columnB>0,1,columnA) 更新多个字段: INSERT INTO tbl (columnA,columnB,columnC) VALUES (1,2,3) ON DUPLICATE KE

  • MySql删除和更新操作对性能有影响吗

    删除和更新操作的开销往往比插入高,所以一个好的设计需要减少对数据库的更新和删除操作. 3.1更新操作 数据库的更新操作会带来一连串的"效应":更新操作需要记录日志(以便错误时回滚):更新可变长字段(如,varchar类型)会带来数据物理存储的变化(记录的移动):更新索引字段会导致索引重建:更新主键会导致数据重组等.这一切不但会造成更新操作本身效率低,而且由于磁片碎片的产生会造成以后查询性能的降低.为了应对这一情况,有两种策略:一.减少更新次数,把多个字段的更新写到同一个语句里:二.避免

  • 使用Java对数据库进行基本的查询和更新操作

    数据库查询 利用Connection对象的createStatement方法建立Statement对象,利用Statement对象的executeQuery()方法执行SQL查询语句进行查询,返回结果集,再形如getXXX()的方法从结果集中读取数据.经过这样的一系列步骤就能实现对数据库的查询. [例]Java应用程序访问数据库.应用程序打开考生信息表ksInfo,从中取出考生的各项信息.设考生信息数据库的结构如下: import java.awt.*; import java.awt.even

  • 微信小程序学习笔记之跳转页面、传递参数获得数据操作图文详解

    本文实例讲述了微信小程序学习笔记之跳转页面.传递参数获得数据操作.分享给大家供大家参考,具体如下: 前面一篇介绍了微信小程序表单提交与PHP后台数据交互处理.现在需要实现点击博客标题或缩略图,跳转到博客详情页面. 开始想研究一下微信小程序的web-view组件跳转传参,把网页嵌入到小程序,结果看到官方文档的一句话打消了念头,因为没有认证...... [方法一 使用navigator组件跳转传参] 前台博客列表页面data.wxml:(后台数据交互参考上一篇) <view wx:for="{

  • 对pytorch中的梯度更新方法详解

    背景 使用pytorch时,有一个yolov3的bug,我认为涉及到学习率的调整.收集到tencent yolov3和mxnet开源的yolov3,两个优化器中的学习率设置不一样,而且使用GPU数目和batch的更新也不太一样.据此,我简单的了解了下pytorch的权重梯度的更新策略,看看能否一窥究竟. 对代码说明 共三个实验,分布写在代码中的(一)(二)(三)三个地方.运行实验时注释掉其他两个 实验及其结果 实验(三): 不使用zero_grad()时,grad累加在一起,官网是使用accum

随机推荐