解析Pytorch中的torch.gather()函数

参数说明

以官方说明为例,gather()函数需要三个参数,输入input,维度dim,以及索引index

input必须为Tensor类型

dim为int类型,代表从哪个维度进行索引

index为LongTensor类型

举例说明

input=torch.tensor([[1,2,3],[4,5,6]]) #作为输入

index1=torch.tensor([[0,1,1],[0,1,1]]) #作为索引矩阵

# dim=0时,按列进行索引
print (torch.gather(input,dim=0,index=index1))

# dim=1时,按行进行索引
print (torch.gather(input,dim=1,index=index1))

结果如下图所示:

# 按列进行索引
tensor([[1, 5, 6],
        [4, 2, 6]])

# 按行进行索引
tensor([[1, 2, 2],
        [5, 4, 5]])

画图说明

官方文档

def gather(self, input, dim, index, *args, **kwargs): 

        For a 3-D tensor the output is specified by::

            out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
            out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
            out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2        

        Args:
            input (Tensor): the source tensor
            dim (int): the axis along which to index
            index (LongTensor): the indices of elements to gather     

        Example::

            >>> t = torch.tensor([[1, 2], [3, 4]])
            >>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
            tensor([[ 1,  1],
                    [ 4,  3]])

到此这篇关于Pytorch中的torch.gather()函数的文章就介绍到这了,更多相关Pytorch torch.gather()函数内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • 浅谈Pytorch中的torch.gather函数的含义

    pytorch中的gather函数 pytorch比tensorflow更加编程友好,所以准备用pytorch试着做最近要做的一些实验. 立个flag开始学习pytorch,新开一个分类整理学习pytorch中的一些踩到的泥坑. 今天刚开始接触,读了一下documentation,写一个一开始每太搞懂的函数gather b = torch.Tensor([[1,2,3],[4,5,6]]) print b index_1 = torch.LongTensor([[0,1],[2,0]]) ind

  • 解析Pytorch中的torch.gather()函数

    参数说明 以官方说明为例,gather()函数需要三个参数,输入input,维度dim,以及索引index input必须为Tensor类型 dim为int类型,代表从哪个维度进行索引 index为LongTensor类型 举例说明 input=torch.tensor([[1,2,3],[4,5,6]]) #作为输入 index1=torch.tensor([[0,1,1],[0,1,1]]) #作为索引矩阵 # dim=0时,按列进行索引 print (torch.gather(input,

  • 详解pytorch中squeeze()和unsqueeze()函数介绍

    squeeze的用法主要就是对数据的维度进行压缩或者解压. 先看torch.squeeze() 这个函数主要对数据的维度进行压缩,去掉维数为1的的维度,比如是一行或者一列这种,一个一行三列(1,3)的数去掉第一个维数为一的维度之后就变成(3)行.squeeze(a)就是将a中所有为1的维度删掉.不为1的维度没有影响.a.squeeze(N) 就是去掉a中指定的维数为一的维度.还有一种形式就是b=torch.squeeze(a,N) a中去掉指定的定的维数为一的维度. 再看torch.unsque

  • pytorch中with torch.no_grad():的用法实例

    目录 1.关于with 2.关于withtorch.no_grad(): 附:pytorch使用模型测试使用withtorch.no_grad(): 总结 1.关于with with是python中上下文管理器,简单理解,当要进行固定的进入,返回操作时,可以将对应需要的操作,放在with所需要的语句中.比如文件的写入(需要打开关闭文件)等. 以下为一个文件写入使用with的例子. with open (filename,'w') as sh: sh.write("#!/bin/bash\n&qu

  • PyTorch中的torch.cat简单介绍

    目录 1.toych简单介绍 2.张量Tensors 3.torch.cat 1.toych简单介绍 包torch包含了多维疑是的数据结构及基于其上的多种数学操作. torch包含了多维张量的数据结构以及基于其上的多种数学运算.此外,它也提供了多种实用工具,其中一些可以更有效地对张量和任意类型进行序列化的工具. 它具有CUDA的对应实现,可以在NVIDIA GPU上进行张量运算(计算能力>=3.0) 2. 张量Tensors torch.is_tensor(obj):如果obj是一个pytorc

  • pytorch中的torch.nn.Conv2d()函数图文详解

    目录 一.官方文档介绍 二.torch.nn.Conv2d()函数详解 参数dilation——扩张卷积(也叫空洞卷积) 参数groups——分组卷积 总结 一.官方文档介绍 官网 nn.Conv2d:对由多个输入平面组成的输入信号进行二维卷积 二.torch.nn.Conv2d()函数详解 参数详解 torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,

  • Pytorch中如何调用forward()函数

    目录 Pytorch调用forward()函数 Pytorch函数调用的问题和源码解读 总结 Pytorch调用forward()函数 Module类是nn模块里提供的一个模型构造类,是所有神经网络模块的基类,我们可以继承它来定义我们想要的模型. 下面继承Module类构造本节开头提到的多层感知机. 这里定义的MLP类重载了Module类的__init__函数和forward函数. 它们分别用于创建模型参数和定义前向计算. 前向计算也即正向传播. import torch from torch

  • 深入解析JavaScript中的立即执行函数

    它是什么 在 JavaScript 里,每个函数,当被调用时,都会创建一个新的执行上下文.因为在函数里定义的变量和函数是唯一在内部被访问的变量,而不是在外部被访问的变量,当调用函数时,函数提供的上下文提供了一个非常简单的方法创建私有变量. function makeCounter() { var i = 0; return function(){ console.log(++i); }; } //记住:`counter`和`counter2`都有他们自己的变量 `i` var counter =

  • 解析C++中的字符串处理函数和指针

    C++字符串处理函数 字符串连接函数 strcat 其函数原型为 strcat(char[],const char[]); strcat是string catenate(字符串连接)的缩写.该函数有两个字符数组的参数,函数的作用是:将第二个字符数组中的字符串连接到前面字符数组的字符串的后面.第二个字符数组被指定为const,以保证该数组中的内容不会在函数调用期间修改.连接后的字符串放在第一个字符数组中,函数调用后得到的函数值,就是第一个字符数组的地址.例如: char str1[30]=″Peo

  • 解析php中如何调用用户自定义函数

    先放上来别人的例子吧:call_user_func函数类似于一种特别的调用函数的方法,使用方法如下:    复制代码 代码如下: function a($b,$c)    {    echo $b;    echo $c;    }    call_user_func('a', "111","222");    call_user_func('a', "333","444");    //显示 111 222 333 444 

随机推荐