Pytorch中的广播机制详解(Broadcast)

目录
  • 1. Pytorch中的广播机制
  • 2. 广播机制的理解
  • 3. 两个张量进行广播机制的条件
  • 4. 当两个张量满足可广播条件后
  • 5. 从空间上理解广播机制
  • 总结

1. Pytorch中的广播机制

如果一个Pytorch运算支持广播的话,那么就意味着传给这个运算的参数会被自动扩张成相同的size,在不复制数据的情况下就能进行运算,整个过程可以做到避免无用的复制,达到更高效的运算。

广播机制实际上是在运算过程中,去处理两个形状不同向量的一种手段。

pytorch中的广播机制和numpy中的广播机制一样, 因为都是数组的广播机制。

2. 广播机制的理解

以数组A和数组B的相加为例, 其余数学运算同理

核心:如果相加的两个数组的shape不同, 就会触发广播机制:

1)程序会自动执行操作使得A.shape==B.shape;

2)对应位置进行相加运算,结果的shape是:A.shape和B.shape对应位置的最大值,比如:A.shape=(1,9,4),B.shape=(15,1,4),那么A+B的shape是(15,9,4)

3. 两个张量进行广播机制的条件

3.1 两个张量都至少有一个维度

#像下面这种情况下就不行,因为x不满足这个条件。
x=torch.empty((0,))
y=torch.empty(2,2)

3.2 按从右往左顺序看两个张量的每一个维度,x和y每个对应着的两个维度都需要能够匹配上

什么情况下算是匹配上了?满足下面的条件就可以:

  • a.这两个维度的大小相等
  • b. 某个维度 一个张量有,一个张量没有
  • c.某个维度 一个张量有,一个张量也有但大小是1

如下举例:

x=torch.empty(5,3,4,1)
y=torch.empty( 3,1,1)

如上面代码中,首先将两个张量维度向右靠齐,从右往左看,两个张量第四维大小相等,都为1,满足上面条件a;第三个维度大小不相等,但第二个张量第三维大小为1,满足上面条件b;第二个维度大小相等都为3,满足上面条件a;第一个维度第一个张量有,第二个张量没有,满足上面条件b,因此两个张量每个维度都符合上面广播条件,因此可以进行广播

两个张量维度从右往左看,如果出现两个张量在某个维度位置上面,维度大小不相等,且两个维度大小没有一个是1,那么这两个张量一定不能进行广播。

4. 当两个张量满足可广播条件后

具体怎么进行广播

x=torch.empty(5,3,4,1)
y=torch.empty( 3,1,1)

如上面代码所示:

a. 首先第一步,将上面条件b的类型变成条件c的类型,也即是把第二个张量在缺失维度的位置上新增一个维度,维度大小为1,新增的维度如下面所示。

统一前:
x=torch.empty(5,3,4,1)
y=torch.empty( 3,1,1)
统一后:
x=torch.empty(5,3,4,1)
y=torch.empty(1,3,1,1)

b. 第二步,x、y对应维度不等的位置,把size为1的维度会被广播得和对应维度一样大,比如y中0维的1会变成5,y中2维的1会变成4,最后两个张量的维度大小变成一样,然后再进行张量运算,转变的维度如下所示

统一前:
x=torch.empty(5,3,4,1)
y=torch.empty(1,3,1,1)
统一后:
x=torch.empty(5,3,4,1)
y=torch.empty(5,3,4,1)

5. 从空间上理解广播机制

5.1 一维张量进行广播,b被自动广播得和a一样的维度大小,完成了张量相乘运算,如下图所示。

a = torch.tensor([1,2,3])
b = torch.tensor([2])
c = a*b
a,a.shape,b,b.shape,c,c.shape

输出结果如下:

(tensor([1, 2, 3]),
 torch.Size([3]),
 tensor([2]),
 torch.Size([1]),
 tensor([2, 4, 6]),
 torch.Size([3]))

5.1 二维张量进行广播,b被自动广播得和a一样的维度大小,完成了张量相加运算,如下图所示。

a = torch.tensor([[0],[10],[20],[30]])
b = torch.tensor([1,2,3])
c = a+b
a,a.shape,b,b.shape,c,c.shape

输出结果如下:

(tensor([[ 0],
         [10],
         [20],
         [30]]),
 torch.Size([4, 1]),
 tensor([1, 2, 3]),
 torch.Size([3]),
 tensor([[ 1,  2,  3],
         [11, 12, 13],
         [21, 22, 23],
         [31, 32, 33]]),
 torch.Size([4, 3]))

上面二维张量和一维张量相加运算进行广播过程为:a的形状是(4,1),b的形状是(3),如果a和b要匹配上,第一步给b新添一个维度,我们有:a的形状是(4,1),b的形状是(1,3);第二步二者各自把为1的维度进行广播,就如上图中那样进行广播,最后运算完成。

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • Python 的矩阵传播机制Broadcasting和矩阵运算

    目录 一.Python的矩阵传播机制(Broadcasting) 二.下面展示什么是python的传播机制 三.利用numpy的内置函数对矩阵进行操作 四.定义自己的函数来处理矩阵 五.总结 一.Python的矩阵传播机制(Broadcasting) 我们知道在深度学习中经常要操作各种矩阵(matrix) .回想一下,我们在操作数组(list)的时候,经常习惯于用**for循环(for-loop)**来对数组的每一个元素进行操作.例如: my_list = [1,2,3,4] new_list 

  • python的广播机制详解

    目录 为什么会有广播机制 在矩阵或向量相关运算中的广播机制 1.一般的运算 2.一个矩阵一个向量的情况 3.两个向量 4.矩阵乘法的广播机制 总结 为什么会有广播机制 python语言在设计的时候,就就考虑到用于两个运算的矩阵或向量维度不匹配的问题.例如,我们有矩阵A,让矩阵每个元素都加1,直接使用A+1,就可以完成目的.这其中就用到了python的广播机制,所以在很多python的第三方库中,都支持广播机制,例如Numpy.pytorch. 在矩阵或向量相关运算中的广播机制 1.一般的运算 假

  • 关于PyTorch 自动求导机制详解

    自动求导机制 从后向中排除子图 每个变量都有两个标志:requires_grad和volatile.它们都允许从梯度计算中精细地排除子图,并可以提高效率. requires_grad 如果有一个单一的输入操作需要梯度,它的输出也需要梯度.相反,只有所有输入都不需要梯度,输出才不需要.如果其中所有的变量都不需要梯度进行,后向计算不会在子图中执行. >>> x = Variable(torch.randn(5, 5)) >>> y = Variable(torch.rand

  • Java中的反射机制详解

    Java中的反射机制详解 反射,当时经常听他们说,自己也看过一些资料,也可能在设计模式中使用过,但是感觉对它没有一个较深入的了解,这次重新学习了一下,感觉还行吧! 一,先看一下反射的概念: 主要是指程序可以访问,检测和修改它本身状态或行为的一种能力,并能根据自身行为的状态和结果,调整或修改应用所描述行为的状态和相关的语义. 反射是Java中一种强大的工具,能够使我们很方便的创建灵活的代码,这些代码可以再运行时装配,无需在组件之间进行源代码链接.但是反射使用不当会成本很高! 看概念很晕的,继续往下

  • 基于python及pytorch中乘法的使用详解

    numpy中的乘法 A = np.array([[1, 2, 3], [2, 3, 4]]) B = np.array([[1, 0, 1], [2, 1, -1]]) C = np.array([[1, 0], [0, 1], [-1, 0]]) A * B : # 对应位置相乘 np.array([[ 1, 0, 3], [ 4, 3, -4]]) A.dot(B) : # 矩阵乘法 ValueError: shapes (2,3) and (2,3) not aligned: 3 (dim

  • PyTorch中permute的用法详解

    permute(dims) 将tensor的维度换位. 参数:参数是一系列的整数,代表原来张量的维度.比如三维就有0,1,2这些dimension. 例: import torch import numpy as np a=np.array([[[1,2,3],[4,5,6]]]) unpermuted=torch.tensor(a) print(unpermuted.size()) # --> torch.Size([1, 2, 3]) permuted=unpermuted.permute(

  • pytorch中的自定义数据处理详解

    pytorch在数据中采用Dataset的数据保存方式,需要继承data.Dataset类,如果需要自己处理数据的话,需要实现两个基本方法. :.getitem:返回一条数据或者一个样本,obj[index] = obj.getitem(index). :.len:返回样本的数量 . len(obj) = obj.len(). Dataset 在data里,调用的时候使用 from torch.utils import data import os from PIL import Image 数

  • Pytorch 中retain_graph的用法详解

    用法分析 在查看SRGAN源码时有如下损失函数,其中设置了retain_graph=True,其作用是什么? ############################ # (1) Update D network: maximize D(x)-1-D(G(z)) ########################### real_img = Variable(target) if torch.cuda.is_available(): real_img = real_img.cuda() z = V

  • PyTorch中的Variable变量详解

    一.了解Variable 顾名思义,Variable就是 变量 的意思.实质上也就是可以变化的量,区别于int变量,它是一种可以变化的变量,这正好就符合了反向传播,参数更新的属性. 具体来说,在pytorch中的Variable就是一个存放会变化值的地理位置,里面的值会不停发生片花,就像一个装鸡蛋的篮子,鸡蛋数会不断发生变化.那谁是里面的鸡蛋呢,自然就是pytorch中的tensor了.(也就是说,pytorch都是有tensor计算的,而tensor里面的参数都是Variable的形式).如果

  • Pytorch中.new()的作用详解

    一.作用 创建一个新的Tensor,该Tensor的type和device都和原有Tensor一致,且无内容. 二.使用方法 如果随机定义一个大小的Tensor,则新的Tensor有两种创建方法,如下: inputs = torch.randn(m, n) new_inputs = inputs.new() new_inputs = torch.Tensor.new(inputs) 三.具体代码 import torch rectangle_height = 1 rectangle_width

  • pytorch中index_select()的用法详解

    pytorch中index_select()的用法 index_select(input, dim, index) 功能:在指定的维度dim上选取数据,不如选取某些行,列 参数介绍 第一个参数input是要索引查找的对象 第二个参数dim是要查找的维度,因为通常情况下我们使用的都是二维张量,所以可以简单的记忆: 0代表行,1代表列 第三个参数index是你要索引的序列,它是一个tensor对象 刚开始学习pytorch,遇到了index_select(),一开始不太明白几个参数的意思,后来查了一

随机推荐