tensorflow 恢复指定层与不同层指定不同学习率的方法

如下所示:

#tensorflow 中从ckpt文件中恢复指定的层或将指定的层不进行恢复:
#tensorflow 中不同的layer指定不同的学习率

with tf.Graph().as_default():
		#存放的是需要恢复的层参数
	 variables_to_restore = []
	 #存放的是需要训练的层参数名,这里是没恢复的需要进行重新训练,实际上恢复了的参数也可以训练
  variables_to_train = []
  for var in slim.get_model_variables():
   excluded = False
   for exclusion in fine_tune_layers:
   #比如fine tune layer中包含logits,bottleneck
    if var.op.name.startswith(exclusion):
     excluded = True
     break
   if not excluded:
    variables_to_restore.append(var)
    #print('var to restore :',var)
   else:
    variables_to_train.append(var)
    #print('var to train: ',var)

  #这里省略掉一些步骤,进入训练步骤:
  #将variables_to_train,需要训练的参数给optimizer 的compute_gradients函数
  grads = opt.compute_gradients(total_loss, variables_to_train)
  #这个函数将只计算variables_to_train中的梯度
  #然后将梯度进行应用:
  apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
  #也可以直接调用opt.minimize(total_loss,variables_to_train)
  #minimize只是将compute_gradients与apply_gradients封装成了一个函数,实际上还是调用的这两个函数
  #如果在梯度里面不同的参数需要不同的学习率,那么可以:

  capped_grads_and_vars = []#[(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]
  #update_gradient_vars是需要更新的参数,使用的是全局学习率
  #对于不是update_gradient_vars的参数,将其梯度更新乘以0.0001,使用基本上不动
 	for grad in grads:
 		for update_vars in update_gradient_vars:
 			if grad[1]==update_vars:
 				capped_grads_and_vars.append((grad[0],grad[1]))
 			else:
 				capped_grads_and_vars.append((0.0001*grad[0],grad[1]))

 	apply_gradient_op = opt.apply_gradients(capped_grads_and_vars, global_step=global_step)

 	#在恢复模型时:

  with sess.as_default():

   if pretrained_model:
    print('Restoring pretrained model: %s' % pretrained_model)
    init_fn = slim.assign_from_checkpoint_fn(
    pretrained_model,
    variables_to_restore)
    init_fn(sess)
   #这样就将指定的层参数没有恢复

以上这篇tensorflow 恢复指定层与不同层指定不同学习率的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • tensorflow saver 保存和恢复指定 tensor的实例讲解

    在实践中经常会遇到这样的情况: 1.用简单的模型预训练参数 2.把预训练的参数导入复杂的模型后训练复杂的模型 这时就产生一个问题: 如何加载预训练的参数. 下面就是我的总结. 为了方便说明,做一个假设:简单的模型只有一个卷基层,复杂模型有两个. 卷积层的实现代码如下: import tensorflow as tf # PS:本篇的重担是saver,不过为了方便阅读还是说明下参数 # 参数 # name:创建卷基层的代码这么多,必须要函数化,而为了防止变量冲突就需要用tf.name_scope

  • 浅谈Tensorflow模型的保存与恢复加载

    近期做了一些反垃圾的工作,除了使用常用的规则匹配过滤等手段,也采用了一些机器学习方法进行分类预测.我们使用TensorFlow进行模型的训练,训练好的模型需要保存,预测阶段我们需要将模型进行加载还原使用,这就涉及TensorFlow模型的保存与恢复加载. 总结一下Tensorflow常用的模型保存方式. 保存checkpoint模型文件(.ckpt) 首先,TensorFlow提供了一个非常方便的api,tf.train.Saver()来保存和还原一个机器学习模型. 模型保存 使用tf.trai

  • tensorflow 恢复指定层与不同层指定不同学习率的方法

    如下所示: #tensorflow 中从ckpt文件中恢复指定的层或将指定的层不进行恢复: #tensorflow 中不同的layer指定不同的学习率 with tf.Graph().as_default(): #存放的是需要恢复的层参数 variables_to_restore = [] #存放的是需要训练的层参数名,这里是没恢复的需要进行重新训练,实际上恢复了的参数也可以训练 variables_to_train = [] for var in slim.get_model_variable

  • tensorflow获取预训练模型某层参数并赋值到当前网络指定层方式

    已经有了一个预训练的模型,我需要从其中取出某一层,把该层的weights和biases赋值到新的网络结构中,可以使用tensorflow中的pywrap_tensorflow(用来读取预训练模型的参数值)结合Session.assign()进行操作. 这种需求即预训练模型可能为单分支网络,当前网络为多分支,我需要把单分支A复用到到多个分支去(B,C,D). 先导入对应的工具包 from tensorflow.python import pywrap_tensorflow 接下来的操作在一个tf.

  • jQuery实现指定区域外单击关闭指定层的方法【经典】

    本文实例讲述了jQuery实现指定区域外单击关闭指定层的方法.分享给大家供大家参考,具体如下: 在页面上指定区域外单击,关闭层.常见效果为弹出层外单击,关闭弹出层.今天遇到一个这样的效果,用jQuery实现起来挺简单的,顺便复习了一下相关知识. $(document).mouseup(function(e){ if($(e.target).parent("#big_map").length==0){ $("#big_map").hide("fast&quo

  • pytorch载入预训练模型后,实现训练指定层

    1.有了已经训练好的模型参数,对这个模型的某些层做了改变,如何利用这些训练好的模型参数继续训练: pretrained_params = torch.load('Pretrained_Model') model = The_New_Model(xxx) model.load_state_dict(pretrained_params.state_dict(), strict=False) strict=False 使得预训练模型参数中和新模型对应上的参数会被载入,对应不上或没有的参数被抛弃. 2.

  • python删除指定类型(或非指定)的文件实例详解

    本文实例分析了python删除指定类型(或非指定)的文件用法.分享给大家供大家参考.具体如下: 如下,删除目录下非源码文件 import os import string def del_files(dir,topdown=True): for root, dirs, files in os.walk(dir, topdown): for name in files: pathname = os.path.splitext(os.path.join(root, name)) if (pathna

  • Asp.net获取服务器指定文件夹目录文件并提供下载的方法

    本文实例讲述了Asp.net获取服务器指定文件夹目录文件并提供下载的方法.分享给大家供大家参考.具体实现方法如下: 复制代码 代码如下: string dirPath = HttpContext.Current.Server.MapPath("uploads/"); if (Directory.Exists(dirPath)) {        //获得目录信息        DirectoryInfo dir = new DirectoryInfo(dirPath);       

  • PHP递归遍历指定目录的文件并统计文件数量的方法

    本文实例讲述了PHP递归遍历指定目录的文件并统计文件数量的方法.分享给大家供大家参考.具体实现方法如下: <?php //递归函数实现遍历指定文件下的目录与文件数量 function total($dirname,&$dirnum,&$filenum){ $dir=opendir($dirname); echo readdir($dir)."<br>"; //读取当前目录文件 echo readdir($dir)."<br>&qu

  • python对指定目录下文件进行批量重命名的方法

    本文实例讲述了python对指定目录下文件进行批量重命名的方法.分享给大家供大家参考.具体如下: 这段python代码可对c:\temp目录下的所有文件名为"scroll_1"文件替换为"scroll_00" import os path = 'c:\\temp' for file in os.listdir(path): if os.path.isfile(os.path.join(path,file))==True: newname = file.replace

  • Android设置TextView显示指定个数字符,超过部分显示...(省略号)的方法

    本文实例讲述了Android设置TextView显示指定个数字符,超过部分显示...(省略号)的方法.分享给大家供大家参考,具体如下: 一.问题: 今天在公司遇到一个需求:TextView设置最多显示8个字符,超过部分显示...(省略号) 二.解决方法: 网上找了很多资料,有人说分别设置TextView的android:signature="true",并且设置android:ellipsize="end";但是我试了,并没有成功,最后自己试出一种方式如下:供大家参

  • php使用指定编码导出mysql数据到csv文件的方法

    本文实例讲述了php使用指定编码导出mysql数据到csv文件的方法.分享给大家供大家参考.具体实现方法如下: <?php /* * PHP code to export MySQL data to CSV * * Sends the result of a MySQL query as a CSV file for download * Easy to convert to UTF-8. */ /* * establish database connection */ $conn = mysq

随机推荐