Pytorch中的gather使用方法

官方说明

gather可以对一个Tensor进行聚合,声明为:torch.gather(input, dim, index, out=None) → Tensor

一般来说有三个参数:输入的变量input、指定在某一维上聚合的dim、聚合的使用的索引index,输出为Tensor类型的结果(index必须为LongTensor类型)。

#参数介绍:
input (Tensor) – The source tensor
dim (int) – The axis along which to index
index (LongTensor) – The indices of elements to gather
out (Tensor, optional) – Destination tensor
#当输入为三维时的计算过程:
out[i][j][k] = input[index[i][j][k]][j][k]  # dim=0
out[i][j][k] = input[i][index[i][j][k]][k]  # dim=1
out[i][j][k] = input[i][j][index[i][j][k]]  # dim=2
#样例:
t = torch.Tensor([[1,2],[3,4]])
torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
#    1  1
#    4  3
#[torch.FloatTensor of size 2x2]

实验

用下面的代码在二维上做测试,以便更好地理解

t = torch.Tensor([[1,2,3],[4,5,6]])
index_a = torch.LongTensor([[0,0],[0,1]])
index_b = torch.LongTensor([[0,1,1],[1,0,0]])
print(t)
print(torch.gather(t,dim=1,index=index_a))
print(torch.gather(t,dim=0,index=index_b))

输出为:

>>tensor([[1., 2., 3.],
        [4., 5., 6.]])
>>tensor([[1., 1.],
        [4., 5.]])
>>tensor([[1., 5., 6.],
        [4., 2., 3.]])

由于官网给的计算过程不太直观,下面给出较为直观的解释:

对于index_a,dim为1表示在第二个维度上进行聚合,索引为列号,[[0,0],[0,1]]表示结果的第一行取原数组第一行列号为[0,0]的数,也就是[1,1],结果的第二行取原数组第二行列号为[0,1]的数,也就是[4,5],这样就得到了输出的结果[[1,1],[4,5]]。

对于index_b,dim为0表示在第一个维度上进行聚合,索引为行号,[[0,1,1],[1,0,0]]表示结果的第一行第d(d=0,1,2)列取原数组第d列行号为[0,1,1]的数,也就是[1,5,6],类似的,结果的第二行第d列取原数组第d列行号为[1,0,0]的数,也就是[4,2,3],这样就得到了输出的结果[[1,5,6],[4,2,3]]

接下来以index_a为例直接用官网的式子计算一遍加深理解:

output[0,0] = input[0,index[0,0]]  #1 = input[0,0]
output[0,1] = input[0,index[0,1]]  #1 = input[0,0]
output[1,0] = input[1,index[1,0]]  #4 = input[1,0]
output[1,1] = input[1,index[1,1]]  #5 = input[1,1]

以下两种写法得到的结果是一样的:

r1 = torch.gather(t,dim=1,index=index_a)

r2 = t.gather(1,index_a)

补充:Pytorch中的torch.gather函数的个人理解

最近在学习pytorch时遇到gather函数,开始没怎么理解,后来查阅网上相关资料后大概明白了原理。

gather()函数

在pytorch中,gather()函数的作用是将数据从input中按index提出,我们看gather函数的的官方文档说明如下:

torch.gather(input, dim, index, out=None) → Tensor
    Gathers values along an axis specified by dim.
    For a 3-D tensor the output is specified by:

    out[i][j][k] = input[index[i][j][k]][j][k]  # dim=0
    out[i][j][k] = input[i][index[i][j][k]][k]  # dim=1
    out[i][j][k] = input[i][j][index[i][j][k]]  # dim=2

    Parameters: 

        input (Tensor) – The source tensor
        dim (int) – The axis along which to index
        index (LongTensor) – The indices of elements to gather
        out (Tensor, optional) – Destination tensor

    Example:

    >>> t = torch.Tensor([[1,2],[3,4]])
    >>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
     1  1
     4  3
    [torch.FloatTensor of size 2x2]

可以看出,在gather函数中我们用到的主要有三个参数:

1)input:输入

2)dim:维度,常用的为0和1

3)index:索引位置

贴一段代码举例说明:

a=t.arange(0,16).view(4,4)
print(a)

index_1=t.LongTensor([[3,2,1,0]])
b=a.gather(0,index_1)
print(b)

index_2=t.LongTensor([[0,1,2,3]]).t()#tensor转置操作:(a)T=a.t()
c=a.gather(1,index_2)
print(c)

输出如下:

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
       
tensor([[12,  9,  6,  3]])

tensor([[ 0],
        [ 5],
        [10],
        [15]])

在gather中,我们是通过index对input进行索引把对应的数据提取出来的,而dim决定了索引的方式。

在上面的例子中,a是一个4×4矩阵:

1)当维度dim=0,索引index_1为[3,2,1,0]时,此时可将a看成1×4的矩阵,通过index_1对a每列进行行索引:第一列第四行元素为12,第二列第三行元素为9,第三列第二行元素为6,第四列第一行元素为3,即b=[12,9,6,3];

2)当维度dim=1,索引index_2为[0,1,2,3]T时,此时可将a看成4×1的矩阵,通过index_1对a每行进行列索引:第一行第一列元素为0,第二行第二列元素为5,第三行第三列元素为10,第四行第四列元素为15,即c=[0,5,10,15]T;

总结

gather函数在提取数据时主要靠dim和index这两个参数,dim=1时将input看为n×1阶矩阵,index看为k×1阶矩阵,取index每行元素对input中每行进行列索引(如:index某行为[1,3,0],对应的input行元素为[9,8,7,6],提取后的结果为[8,6,9]);

同理,dim=0时将input看为1×n阶矩阵,index看为1×k阶矩阵,取index每列元素对input中每列进行行索引。

gather函数提取后的矩阵阶数和对应的index阶数相同。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • 浅谈Pytorch中的torch.gather函数的含义

    pytorch中的gather函数 pytorch比tensorflow更加编程友好,所以准备用pytorch试着做最近要做的一些实验. 立个flag开始学习pytorch,新开一个分类整理学习pytorch中的一些踩到的泥坑. 今天刚开始接触,读了一下documentation,写一个一开始每太搞懂的函数gather b = torch.Tensor([[1,2,3],[4,5,6]]) print b index_1 = torch.LongTensor([[0,1],[2,0]]) ind

  • PyTorch中的拷贝与就地操作详解

    前言 PyTroch中我们经常使用到Numpy进行数据的处理,然后再转为Tensor,但是关系到数据的更改时我们要注意方法是否是共享地址,这关系到整个网络的更新.本篇就In-palce操作,拷贝操作中的注意点进行总结. In-place操作 pytorch中原地操作的后缀为_,如.add_()或.scatter_(),就地操作是直接更改给定Tensor的内容而不进行复制的操作,即不会为变量分配新的内存.Python操作类似+=或*=也是就地操作.(我加了我自己~) 为什么in-place操作可以

  • Pytorch高阶OP操作where,gather原理

    PyTorch是一个非常有可能改变深度学习领域前景的Python库.我尝试使用了几星期PyTorch,然后被它的易用性所震惊,在我使用过的各种深度学习库中,PyTorch是最灵活.最容易掌握的. 一.where 1)torch.where(condition, x, y) # condition是条件,满足条件就返回x,不满足就返回y 2)特点,相比for循环的优点是:可以布置在GPU上运行 二.gather 1)官方解释:根据指定的维度和索引值来筛选值 2)举例 以上就是本文的全部内容,希望对

  • Pytorch中的gather使用方法

    官方说明 gather可以对一个Tensor进行聚合,声明为:torch.gather(input, dim, index, out=None) → Tensor 一般来说有三个参数:输入的变量input.指定在某一维上聚合的dim.聚合的使用的索引index,输出为Tensor类型的结果(index必须为LongTensor类型). #参数介绍: input (Tensor) – The source tensor dim (int) – The axis along which to ind

  • 对pytorch中的梯度更新方法详解

    背景 使用pytorch时,有一个yolov3的bug,我认为涉及到学习率的调整.收集到tencent yolov3和mxnet开源的yolov3,两个优化器中的学习率设置不一样,而且使用GPU数目和batch的更新也不太一样.据此,我简单的了解了下pytorch的权重梯度的更新策略,看看能否一窥究竟. 对代码说明 共三个实验,分布写在代码中的(一)(二)(三)三个地方.运行实验时注释掉其他两个 实验及其结果 实验(三): 不使用zero_grad()时,grad累加在一起,官网是使用accum

  • 基于pytorch中的Sequential用法说明

    class torch.nn.Sequential(* args) 一个时序容器.Modules 会以他们传入的顺序被添加到容器中.当然,也可以传入一个OrderedDict. 为了更容易的理解如何使用Sequential, 下面给出了一个例子: # Example of using Sequential model = nn.Sequential( nn.Conv2d(1,20,5), nn.ReLU(), nn.Conv2d(20,64,5), nn.ReLU() ) # Example o

  • 在Pytorch中使用样本权重(sample_weight)的正确方法

    step: 1.将标签转换为one-hot形式. 2.将每一个one-hot标签中的1改为预设样本权重的值 即可在Pytorch中使用样本权重. eg: 对于单个样本:loss = - Q * log(P),如下: P = [0.1,0.2,0.4,0.3] Q = [0,0,1,0] loss = -Q * np.log(P) 增加样本权重则为loss = - Q * log(P) *sample_weight P = [0.1,0.2,0.4,0.3] Q = [0,0,sample_wei

  • pytorch中如何使用DataLoader对数据集进行批处理的方法

    最近搞了搞minist手写数据集的神经网络搭建,一个数据集里面很多个数据,不能一次喂入,所以需要分成一小块一小块喂入搭建好的网络. pytorch中有很方便的dataloader函数来方便我们进行批处理,做了简单的例子,过程很简单,就像把大象装进冰箱里一共需要几步? 第一步:打开冰箱门. 我们要创建torch能够识别的数据集类型(pytorch中也有很多现成的数据集类型,以后再说). 首先我们建立两个向量X和Y,一个作为输入的数据,一个作为正确的结果: 随后我们需要把X和Y组成一个完整的数据集,

  • pytorch中的embedding词向量的使用方法

    Embedding 词嵌入在 pytorch 中非常简单,只需要调用 torch.nn.Embedding(m, n) 就可以了,m 表示单词的总数目,n 表示词嵌入的维度,其实词嵌入就相当于是一个大矩阵,矩阵的每一行表示一个单词. emdedding初始化 默认是随机初始化的 import torch from torch import nn from torch.autograd import Variable # 定义词嵌入 embeds = nn.Embedding(2, 5) # 2

  • pytorch在fintune时将sequential中的层输出方法,以vgg为例

    有时候我们在fintune时发现pytorch把许多层都集合在一个sequential里,但是我们希望能把中间层的结果引出来做下一步操作,于是我自己琢磨了一个方法,以vgg为例,有点僵硬哈! 首先pytorch自带的vgg16模型的网络结构如下: VGG( (features): Sequential( (0): Conv2d (3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace) (2): Co

  • pytorch中tensor的合并与截取方法

    合并: torch.cat(inputs=(a, b), dimension=1) e.g. x = torch.cat((x,y), 0) 沿x轴合并 截取: x[:, 2:4] 以上这篇pytorch中tensor的合并与截取方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们.

  • 在Pytorch中计算卷积方法的区别详解(conv2d的区别)

    在二维矩阵间的运算: class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True) 对由多个特征平面组成的输入信号进行2D的卷积操作.详解 torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)

随机推荐