深入浅析Pytorch中stack()方法

目录
  • 1. 概念
  • 2. 参数
  • 3. 举例
    • 3.1 四个shape为[3, 3]的张量
      • 3.1.1 dim=0的情况下,直接来看结果。
      • 3.1.2 dim=1的情况下
      • 3.1.2 dim=2的情况下
      • 3.1.3 总结
    • 3.2 7个shape为[5, 7, 4, 2]的张量
  • 4. 理解

Torch.stack()

1. 概念

在一个新的维度上连接一个张量序列

2. 参数

  • tensors (sequence)需要连接的张量序列
  • dim (int)在第dim个维度上连接

注意输入的张量shape要完全一致,且dim必须小于len(tensors)。

3. 举例

3.1 四个shape为[3, 3]的张量

a = torch.Tensor([[1,2,3],[4,5,6],[7,8,9]])
 b = torch.Tensor([[10,20,30],[40,50,60],[70,80,90]])
 c = torch.Tensor([[100,200,300],[400,500,600],[700,800,900]])
 d = torch.Tensor([[1000,2000,3000],[4000,5000,6000],[7000,8000,9000]])

以下面这4个张量,每个张量shape为[3, 3]。

3.1.1 dim=0的情况下,直接来看结果。

torch.stack((a,b,c,d),dim=0)

此时在第0个维度上连接,新张量的shape可以发现为[4, 3, 3],4代表在第0个维度有4项。

观察可以得知:即初始的四个张量,即a、b、c、d四个初始张量。

可以理解为新张量的第0个维度上连接a、b、c、d。

3.1.2 dim=1的情况下

torch.stack((a,b,c,d),dim=1)

此时在第1个维度上连接,新张量的shape可以发现为[3,4, 3],4代表在第1个维度有4项。

观察可以得知:

  • 新张量[0][0]为a[0],[0][1]为b[0],[0][2]为c[0],[0][3]为d[0]
  • 新张量[1][0]为a[1],[1][1]为b[1],[1][2]为c[1],[1][3]为d[1]
  • 新张量[2][0]为a[2],[2][1]为b[2],[2][2]为c[2],[2][3]为d[2]

可以理解为新张量的第1个维度上连接a、b、c、d的第0个维度单位,具体地说,在新张量[i]中连接a[i]、b[i]、c[i]、d[i],即将a[i]赋给新张量[i][0]、b[i]赋给新张量[i][1]、c[i]赋给新张量[i][2]、d[i]赋给新张量[i][3]。

3.1.2 dim=2的情况下

此时在第2个维度上连接,新张量的shape可以发现为[3,3,4],4代表在第2个维度有4项。

观察可以得知:

新张量[0][0][0]为a[0][0],[0][0][1]为b[0][0],[0][0][2]为c[0][0],[0][0][3]为d[0][0]
新张量[0][1][0]为a[0][1],[0][1][1]为b[0][1],[0][1][2]为c[0][1],[0][1][3]为d[0][1]
新张量[0][2][0]为a[0][2],[0][2][1]为b[0][2],[0][2][2]为c[0][2],[0][2][3]为d[0][2]
新张量[1][0][0]为a[1][0],[1][0][1]为b[1][0],[1][0][2]为c[1][0],[1][0][3]为d[1][0]
新张量[1][1][0]为a[1][1],[1][1][1]为b[1][1],[1][1][2]为c[1][1],[1][1][3]为d[1][1]
新张量[1][2][0]为a[1][2],[1][2][1]为b[1][2],[1][2][2]为c[1][2],[1][2][3]为d[1][2]
新张量[2][0][0]为a[2][0],[2][0][1]为b[2][0],[2][0][2]为c[2][0],[2][0][3]为d[2][0]
新张量[2][1][0]为a[2][1],[2][1][1]为b[2][1],[2][1][2]为c[2][1],[2][1][3]为d[2][1]
新张量[2][2][0]为a[2][2],[2][2][1]为b[2][2],[2][2][2]为c[2][2],[2][2][3]为d[2][2]

可以理解为新张量的第2个维度上连接a、b、c、d的第1个维度的单位,具体地说,在新张量[i][j]中连接a[i][j]、b[i][j]、c[i][j]、d[i][]j。

3.1.3 总结

通过dim=0、1、2的情况,可以总结并推涨出规律:

假设有n个[x,y]的张量,当dim=z时。新张量在第z个维度上连接n个张量第z-1维度的单位,具体来说,新张量[i][i+1]..[i+z-1]中依次连接n个向量[i][i+1]..[i+z-1]。

3.2 7个shape为[5, 7, 4, 2]的张量

a1 = torch.rand([5, 7, 4, 3])
a2 = a1 + 1
a3 = a2 + 1
a4 = a3 + 1
a5 = a4 + 1
a6 = a5 + 1
a7 = a6 + 1

假设dim=3时连接

test = torch.stack((a1, a2, a3, a4, a5, a6, a7), dim=3)

7个张量在第3个维度连接后形成的新张量赋为test,test的shape为[5, 7, 4,7, 3],代表在第3个维度有7项。

随机(在新张量[0][0][0]到新张量[4][6][3]区间内)查看一个新张量第3维度上的单位:

a = test[0][1][2]

再根据总结的规律,将7个向量中的[0][1][2]连接起来,再次查看,验证了规律。

b = torch.zeros(0)
for i in (a1, a2, a3, a4, a5, a6, a7):
    b = torch.cat((b, i[0][1][2]), dim=0)

4. 理解

通过shape来看,假设shape为[a, b, c... z],有n个shape相同的张量,在dim=x时连接n个张量,可以得到新张量,shape为[a, b, c, ... n, ...z],其中n所在维度即为第x个维度。

然后即可通过新张量[i][i+1]..[i+x-1]看作索引,对应的数据为n个张量[i][i+1][i+x-1]按顺序连接。

到此这篇关于Pytorch中stack()方法的总结及理解的文章就介绍到这了,更多相关Pytorch中stack()方法的总结及理解内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • 对PyTorch torch.stack的实例讲解

    不是concat的意思 import torch a = torch.ones([1,2]) b = torch.ones([1,2]) torch.stack([a,b],1) (0 ,.,.) = 1 1 1 1 [torch.FloatTensor of size 1x2x2] 以上这篇对PyTorch torch.stack的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们.

  • 浅谈pytorch中stack和cat的及to_tensor的坑

    初入计算机视觉遇到的一些坑 1.pytorch中转tensor x=np.random.randint(10,100,(10,10,10)) x=TF.to_tensor(x) print(x) 这个函数会对输入数据进行自动归一化,比如有时候我们需要将0-255的图片转为numpy类型的数据,则会自动转为0-1之间 2.stack和cat之间的差别 stack x=torch.randn((1,2,3)) y=torch.randn((1,2,3)) z=torch.stack((x,y))#默

  • 聊聊Pytorch torch.cat与torch.stack的区别

    torch.cat()函数可以将多个张量拼接成一个张量.torch.cat()有两个参数,第一个是要拼接的张量的列表或是元组:第二个参数是拼接的维度. torch.cat()的示例如下图1所示 图1 torch.cat() torch.stack()函数同样有张量列表和维度两个参数.stack与cat的区别在于,torch.stack()函数要求输入张量的大小完全相同,得到的张量的维度会比输入的张量的大小多1,并且多出的那个维度就是拼接的维度,那个维度的大小就是输入张量的个数. torch.st

  • 深入浅析Pytorch中stack()方法

    目录 1. 概念 2. 参数 3. 举例 3.1 四个shape为[3, 3]的张量 3.1.1 dim=0的情况下,直接来看结果. 3.1.2 dim=1的情况下 3.1.2 dim=2的情况下 3.1.3 总结 3.2 7个shape为[5, 7, 4, 2]的张量 4. 理解 Torch.stack() 1. 概念 在一个新的维度上连接一个张量序列 2. 参数 tensors (sequence)需要连接的张量序列 dim (int)在第dim个维度上连接 注意输入的张量shape要完全一

  • python DataFrame中stack()方法、unstack()方法和pivot()方法浅析

    目录 1.stack() 2. unstack() 3. pivot() 总结 1.stack() stack()用于将列索引转换为最内层的行索引,这样叙述比较抽象,看示例就容易理解啦: 准备一组数据,给其设置双索引. import pandas as pd data = [['A类', 'a1', 123, 224, 254], ['A类', 'a2', 234, 135, 444], ['A类', 'a3', 345, 241, 324], ['B类', 'b1', 112, 412, 46

  • 浅析Javascript中bind()方法的使用与实现

    在讨论bind()方法之前我们先来看一道题目: var altwrite = document.write;  altwrite("hello");  //1.以上代码有什么问题 //2.正确操作是怎样的 //3.bind()方法怎么实现 对于上面这道题目,答案并不是太难,主要考点就是this指向的问题,altwrite()函数改变this的指向global或window对象,导致执行时提示非法调用异常,正确的方案就是使用bind()方法: altwrite.bind(document

  • 浅析Java中clone()方法浅克隆与深度克隆

    现在Clone已经不是一个新鲜词语了,伴随着"多莉"的产生这个词语确实很"火"过一阵子,在Java中也有这么一个概念,它可以让我们很方便的"制造"出一个对象的副本来,下面来具体看看Java中的Clone机制是如何工作的?      1. Clone&Copy 假设现在有一个Employee对象,Employee tobby =new Employee("CMTobby",5000),通 常我们会有这样的赋值Employ

  • 浅析JAVA中toString方法的作用

    因为它是Object里面已经有了的方法,而所有类都是继承Object,所以"所有对象都有这个方法". 它通常只是为了方便输出,比如System.out.println(xx),括号里面的"xx"如果不是String类型的话,就自动调用xx的toString()方法 总而言之,它只是sun公司开发java的时候为了方便所有类的字符串操作而特意加入的一个方法 回答补充:写这个方法的用途就是为了方便操作,所以在文件操作里面可用可不用例子1: 复制代码 代码如下: publ

  • 浅析PyTorch中nn.Module的使用

    torch.nn.Modules 相当于是对网络某种层的封装,包括网络结构以及网络参数和一些操作 torch.nn.Module 是所有神经网络单元的基类 查看源码 初始化部分: def __init__(self): self._backend = thnn_backend self._parameters = OrderedDict() self._buffers = OrderedDict() self._backward_hooks = OrderedDict() self._forwa

  • 浅析PyTorch中nn.Linear的使用

    查看源码 Linear 的初始化部分: class Linear(Module): ... __constants__ = ['bias'] def __init__(self, in_features, out_features, bias=True): super(Linear, self).__init__() self.in_features = in_features self.out_features = out_features self.weight = Parameter(to

  • 在Pytorch中使用样本权重(sample_weight)的正确方法

    step: 1.将标签转换为one-hot形式. 2.将每一个one-hot标签中的1改为预设样本权重的值 即可在Pytorch中使用样本权重. eg: 对于单个样本:loss = - Q * log(P),如下: P = [0.1,0.2,0.4,0.3] Q = [0,0,1,0] loss = -Q * np.log(P) 增加样本权重则为loss = - Q * log(P) *sample_weight P = [0.1,0.2,0.4,0.3] Q = [0,0,sample_wei

  • pytorch中如何使用DataLoader对数据集进行批处理的方法

    最近搞了搞minist手写数据集的神经网络搭建,一个数据集里面很多个数据,不能一次喂入,所以需要分成一小块一小块喂入搭建好的网络. pytorch中有很方便的dataloader函数来方便我们进行批处理,做了简单的例子,过程很简单,就像把大象装进冰箱里一共需要几步? 第一步:打开冰箱门. 我们要创建torch能够识别的数据集类型(pytorch中也有很多现成的数据集类型,以后再说). 首先我们建立两个向量X和Y,一个作为输入的数据,一个作为正确的结果: 随后我们需要把X和Y组成一个完整的数据集,

随机推荐