Pytorch获取无梯度TorchTensor中的值

获取无梯度Tensor

遇到的问题:

使用两个网络并行运算,一个网络的输出值要给另一个网络反馈。而反馈的输出值带有网络权重的梯度,即grad_fn=<XXXBackward0>.

这时候如果把反馈值扔到第二网络中更新,会出现第一个计算图丢失无法更新的错误。哎哟喂,我根本不需要第一个网络的梯度好吗?

一开始用了一个笨办法,先转numpy,然后再转回torch.Tensor。因为numpy数据是不带梯度的。

但是我的原始tensor的放在cuda上的,

cuda的张量是不能直接转Tensor,所以

t_error = td_error.cuda().data.cpu().numpy()
t_error = torch.FloatTensor(t_error).to(device)

从cuda转回了cpu,变成numpy,又转成了tensor,又回到了cuda上,坑爹呢这是,可能只有我才能写出如此低效的辣鸡代码了。

后来发现,其实直接在返回的时候添加

with torch.no_grad():
 td_error = reward + GAMMA * v_ - v

即可.

补充:在pytorch中取一个tensor的均值,然后该张量中的所有值与其对比!

Pytorch中的Tensor的shape是(B, C, W, H),

对该tensor取均值并与所有值做对比代码如下:

C, H, W = tensor.shape[1], tensor.shape[2], tensor.shape[3]
for c in range(C):
 mean = torch.mean(x[0][c])
 for h in range(H):
  for w in range(W):
  if x[0][c][h][w] >= mean:
  x[0][c][h][w] = mean

以上为个人经验,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • pytorch 获取tensor维度信息示例

    我就废话不多说了,直接上代码吧! >>> import torch >>> from torch.autograd import Variable >>> from torch import IntTensor >>> var = Variable(IntTensor([[1,0],[0,1]])) >>> var Variable containing: 1 0 0 1 [torch.IntTensor of si

  • 对pytorch中的梯度更新方法详解

    背景 使用pytorch时,有一个yolov3的bug,我认为涉及到学习率的调整.收集到tencent yolov3和mxnet开源的yolov3,两个优化器中的学习率设置不一样,而且使用GPU数目和batch的更新也不太一样.据此,我简单的了解了下pytorch的权重梯度的更新策略,看看能否一窥究竟. 对代码说明 共三个实验,分布写在代码中的(一)(二)(三)三个地方.运行实验时注释掉其他两个 实验及其结果 实验(三): 不使用zero_grad()时,grad累加在一起,官网是使用accum

  • 在PyTorch中Tensor的查找和筛选例子

    本文源码基于版本1.0,交互界面基于0.4.1 import torch 按照指定轴上的坐标进行过滤 index_select() 沿着某tensor的一个轴dim筛选若干个坐标 >>> x = torch.randn(3, 4) # 目标矩阵 >>> x tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], [-0.4664, 0.2647, -0.1228, -1.1068], [-1.1734, -0.6571, 0.7230,

  • 详解PyTorch中Tensor的高阶操作

    条件选取:torch.where(condition, x, y) → Tensor 返回从 x 或 y 中选择元素的张量,取决于 condition 操作定义: 举个例子: >>> import torch >>> c = randn(2, 3) >>> c tensor([[ 0.0309, -1.5993, 0.1986], [-0.0699, -2.7813, -1.1828]]) >>> a = torch.ones(2,

  • Pytorch获取无梯度TorchTensor中的值

    获取无梯度Tensor 遇到的问题: 使用两个网络并行运算,一个网络的输出值要给另一个网络反馈.而反馈的输出值带有网络权重的梯度,即grad_fn=<XXXBackward0>. 这时候如果把反馈值扔到第二网络中更新,会出现第一个计算图丢失无法更新的错误.哎哟喂,我根本不需要第一个网络的梯度好吗? 一开始用了一个笨办法,先转numpy,然后再转回torch.Tensor.因为numpy数据是不带梯度的. 但是我的原始tensor的放在cuda上的, cuda的张量是不能直接转Tensor,所以

  • jquery获取input type=text中的值的各种方式(总结)

     实例如下: <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml> <html xmlns="http://www.w3.org/1999/xhtml"> <head> <title>JQuery获取文本框的值</title> <meta h

  • layui获取多选框中的值方法

    HTML: <input type="checkbox" name="standard" value="<?=$key;?>" <?=in_array($key, $standards) ? 'checked' : '';?> title="<?=$value;?>"> js: $("input:checkbox[name='standard']:checked&quo

  • javascript 获取iframe里页面中元素值的方法

    IE方法:document.frames['myFrame'].document.getElementById('test').value; 火狐方法:document.getElementById('myFrame').contentWindow.document.getElementById('test').value; IE.火狐方法: 复制代码 代码如下: function getValue(){ var tmp = ''; if(document.frames){ tmp += 'ie

  • vue获取或者改变vuex中的值方式

    目录 vue获取或改变vuex的值 store–>index.js 在页面中使用或者修改vuex中的值 监听vuex值变化实时改变 问题如图 思路 vue获取或改变vuex的值 store–>index.js import Vue from 'vue' import Vuex from 'vuex' Vue.use(Vuex) export default new Vuex.Store({ state: { isLogin:localStorage.getItem("isLogin&

  • 如何更优雅地获取spring boot yml中的值

    前言 偶然看到国外论坛有人在吐槽同事从配置文件获取值的方式,因此查阅了相关资料发现确实有更便于管理更优雅的获取方式. github demo地址: springboot-yml-value 1.什么是yml文件 application.yml取代application.properties,用来配置数据可读性更强,尤其是当我们已经制定了很多的层次结构配置的时候. 下面是一个非常基本的yml文件: server: url: http://localhost myapp: name: MyAppli

  • jquery获取input输入框中的值

    如何用javascript获取input输入框中的值,js/jq通过name.id.class获取input输入框中的value 先准备一段 HTML <input type="text" id="CN_NAME" name="CN_NAME" class="CN_NAME"> 一.jquery获取input文本框中的值 通过 name var name = $('input[name="CN_NAME&

  • js与jquery获取input输入框中的值实例讲解

    如何用javascript获取input输入框中的值,js/jq通过name.id.class获取input输入框中的value 先准备一段 HTML <input type="text" name"username" id="user" placeholder="用户名" class="uusr"><br> 一.jquery获取input文本框中的值 通过 name: $('inp

  • php获取数组中键值最大数组项的索引值 原创

    本文实例讲述了php获取数组中键值最大数组项的索引值的方法.分享给大家供大家参考.具体分析如下: 一.问题: 从给定数组中获取值最大的数组项的键值.用途如:获取班级得分最高的学生的姓名. 二.解决方法: <?php /* * Created on 2015-3-17 * Created by www.jb51.net */ $arr=array('tom'=>9,'jack'=>3,'kim'=>5,'hack'=>4); asort($arr); //print_r($ar

  • JS获取多维数组中相同键的值实现方法示例

    本文实例讲述了JS获取多维数组中相同键的值实现方法.分享给大家供大家参考,具体如下: <!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.0 Transitional//EN"> <HTML> <HEAD> <TITLE> Demo </TITLE> <META NAME="Keywords" CONTENT=""> <META NAME

随机推荐