Tensorflow 自定义loss的情况下初始化部分变量方式

一般情况下,tensorflow里面变量初始化过程为:

  #variables ...........
  #.....................
  init = tf.initialize_all_variables()
  sess.run(init)

这里 tf.initialize_all_variables() 会初始化所有的变量。

实际过程中,假设有a, b, c三个变量,其中a已经被初始化了,只想单独初始化b,c,那么:

  #variables ...
  ...
  init = tf.variables_initializer([b,c])
  sess.run(init)

此外,如果自行修改了optimizer,如下代码就会报错:

  #definition of variables a, b, c ...
  ....
  my_optimizer = tf.train.RMSProp(learning_rate = 0.1).minimize(my_cost)
  init = tf.variables_initializer([b,c])
  sess.run(init)

这是因为自己定义的optimizer会生成新的variables,但是在init里面并没有初始化,所以无法访问,会报错。解决方法如下:

  a = tf.Variables(...)      #line N
  temp = set(tf.all_variables())
  b = tf.Variables(...)
  c = tf.Variables(...)
  #definition of my optimizer
  optimizer = tf.train.......
  init = tf.variables_initializer(set(tf.all_varialbles())-temp) # line M
  sess.run(init)

首先,temp = set(tf.all_variables()) 将该行(line N)代码之前的所有变量保存在temp中,接下来定义变量b, c,以及自定义的optimizer,然后 set(tf.all_varialbles()存储了改行(line M)之前的所有变量(包括optimizer生成的变量以及temp中所含的变量),set(tf.all_varialbles())-temp相减得到line N~M这几行定义的变量。

以上这篇Tensorflow 自定义loss的情况下初始化部分变量方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • TensorFlow损失函数专题详解

    一.分类问题损失函数--交叉熵(crossentropy) 交叉熵刻画了两个概率分布之间的距离,是分类问题中使用广泛的损失函数.给定两个概率分布p和q,交叉熵刻画的是两个概率分布之间的距离: 我们可以通过Softmax回归将神经网络前向传播得到的结果变成交叉熵要求的概率分布得分.在TensorFlow中,Softmax回归的参数被去掉了,只是一个额外的处理层,将神经网络的输出变成一个概率分布. 代码实现: import tensorflow as tf y_ = tf.constant([[1.

  • 关于tensorflow的几种参数初始化方法小结

    在tensorflow中,经常会遇到参数初始化问题,比如在训练自己的词向量时,需要对原始的embeddigs矩阵进行初始化,更一般的,在全连接神经网络中,每层的权值w也需要进行初始化. tensorlfow中应该有一下几种初始化方法 1. tf.constant_initializer() 常数初始化 2. tf.ones_initializer() 全1初始化 3. tf.zeros_initializer() 全0初始化 4. tf.random_uniform_initializer()

  • tensorflow实现tensor中满足某一条件的数值取出组成新的tensor

    首先使用tf.where()将满足条件的数值索引取出来,在numpy中,可以直接用矩阵引用索引将满足条件的数值取出来,但是在tensorflow中这样是不行的.所幸,tensorflow提供了tf.gather()和tf.gather_nd()函数. 看下面这一段代码: import tensorflow as tf sess = tf.Session() def get_tensor(): x = tf.random_uniform((5, 4)) ind = tf.where(x>0.5)

  • Tensorflow 自定义loss的情况下初始化部分变量方式

    一般情况下,tensorflow里面变量初始化过程为: #variables ........... #..................... init = tf.initialize_all_variables() sess.run(init) 这里 tf.initialize_all_variables() 会初始化所有的变量. 实际过程中,假设有a, b, c三个变量,其中a已经被初始化了,只想单独初始化b,c,那么: #variables ... ... init = tf.vari

  • mysql不重启的情况下修改参数变量

    通常来说,更新mysql配置my.cnf需要重启mysql才能生效,但是有些时候mysql在线上,不一定允许你重启,这时候应该怎么办呢? 看一个例子: mysql> show variables like 'log_slave_updates'; +-------------------+-------+| Variable_name     | Value |+-------------------+-------+| log_slave_updates | OFF   |+---------

  • Android不使用自定义布局情况下实现自定义通知栏图标的方法

    本文实例讲述了Android不使用自定义布局情况下实现自定义通知栏图标的方法.分享给大家供大家参考,具体如下: 自定义通知栏图标?不是很简单么.自定义布局都不在话下! 是的,有xml布局文件当然一切都很简单,如果不给你布局文件用呢? 听我慢慢道来! 首先怎么创建一个通知呢? 1.new 一个 复制代码 代码如下: Notification n = new Notification(android.R.drawable.ic_menu_share, null, System.currentTime

  • tensorflow 自定义损失函数示例代码

    这个自定义损失函数的背景:(一般回归用的损失函数是MSE, 但要看实际遇到的情况而有所改变) 我们现在想要做一个回归,来预估某个商品的销量,现在我们知道,一件商品的成本是1元,售价是10元. 如果我们用均方差来算的话,如果预估多一个,则损失一块钱,预估少一个,则损失9元钱(少赚的). 显然,我宁愿预估多了,也不想预估少了. 所以,我们就自己定义一个损失函数,用来分段地看,当yhat 比 y大时怎么样,当yhat比y小时怎么样. (yhat沿用吴恩达课堂中的叫法) import tensorflo

  • keras 自定义loss损失函数,sample在loss上的加权和metric详解

    首先辨析一下概念: 1. loss是整体网络进行优化的目标, 是需要参与到优化运算,更新权值W的过程的 2. metric只是作为评价网络表现的一种"指标", 比如accuracy,是为了直观地了解算法的效果,充当view的作用,并不参与到优化过程 在keras中实现自定义loss, 可以有两种方式,一种自定义 loss function, 例如: # 方式一 def vae_loss(x, x_decoded_mean): xent_loss = objectives.binary_

  • TensorFlow自定义损失函数来预测商品销售量

    在预测商品销量时,如果预测多了(预测值比真实销量大),商家损失的是生产商品的成本:而如果预测少了(预测值比真实销量小),损失的则是商品的利润.因为一般商品的成本和商品的利润不会严格相等,比如如果一个商品的成本是1元,但是利润是10元,那么少预测一个就少挣10元:而多预测一个才少挣1元,所以如果神经网络模型最小化的是均方误差损失函数,那么很有可能此模型就无法最大化预期的销售利润. 为了最大化预期利润,需要将损失函数和利润直接联系起来,需要注意的是,损失函数定义的是损失,所以要将利润最大化,定义的损

  • keras 自定义loss层+接受输入实例

    loss函数如何接受输入值 keras封装的比较厉害,官网给的例子写的云里雾里, 在stackoverflow找到了答案 You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function). def custom_loss_wrapper(input_

  • J2ee 高并发情况下监听器实例详解

    J2ee 高并发情况下监听器实例详解 引言:在高并发下限制最大并发次数,在web.xml中用过滤器设置参数(最大并发数),并设置其他相关参数.详细见代码. 第一步:配置web.xml配置,不懂的地方解释一下:参数50通过参数名maxConcurrent用在filter的实现类中获取,filter-class就是写的实现类, url-pattern就是限制并发时间的url,结束! <filter> <filter-name>ConcurrentCountFilter</filt

  • Android开发使用自定义view实现ListView下拉的视差特效功能

    本文实例讲述了Android开发使用自定义view实现ListView下拉的视差特效功能.分享给大家供大家参考,具体如下: 一.概述: 现在流型的APP如微信朋友圈,QQ空间,微博个人展示都有视差特效的影子. 如图:下拉图片会产生图片拉升的效果,放手后图片有弹回到原处: 那我们如何实现呢? 1)重写ListView控件: 2)重写里面的overScrollBy方法 3)在松手后执行值动画 二.具体实现: 1.创建ParallaListView 自定义ListView public class P

  • SpringCloud Zuul在何种情况下使用Hystrix及问题小结

    首先,引入spring-cloud-starter-zuul之后会间接引入: hystrix依赖已经引入,那么何种情况下使用hystrix呢? 在Zuul的自动配置类ZuulServerAutoConfiguration和ZuulProxyAutoConfiguration中总共会向Spring容器注入3个Zuul的RouteFilter,分别是 •SimpleHostRoutingFilter 简单路由,通过HttpClient向预定的URL发送请求 生效条件: RequestContext.

随机推荐