pytorch中常用的乘法运算及相关的运算符(@和*)

目录
  • 前言
  • 1、torch.mm
  • 2、torch.bmm
  • 3、torch.mul
  • 4、torch.mv
  • 5、torch.matmul
  • 6、@运算符
  • 7、*运算符
  • 附:二维矩阵乘法
  • 总结

前言

这里总结一下pytorch常用的乘法运算以及相关的运算符(@、*)。

总结放前面:

torch.mm : 用于两个矩阵(不包括向量)的乘法。如维度为(l,m)和(m,n)相乘

torch.bmm : 用于带batch的三维向量的乘法。如维度为(b,l,m)和(b,m,n)相乘

torch.mul : 用于两个同维度矩阵的逐像素点相乘(点乘)。如维度为(l,m)和(l,m)相乘

torch.mv : 用于矩阵和向量之间的乘法(矩阵在前,向量在后)。如维度为(l,m)和(m)相乘,结果的维度为(l)。

torch.matmul : 用于两个张量(后两维满足矩阵乘法的维度)相乘或者是矩阵与向量间的乘法,因为其具有广播机制(broadcasting,自动补充维度)。如维度为(b,l,m)和(b,m,n);(l,m)和(b,m,n);(b,c,l,m)和(b,c,m,n);(l,m)和(m)相乘等。【其作用包含torch.mm、torch.bmm和torch.mv】

@运算符 : 其作用类似于torch.matmul。

*运算符 : 其作用类似于torch.mul。

1、torch.mm

import torch
a = torch.ones(1, 2)
print(a)
b = torch.ones(2, 3)
print(b)
output = torch.mm(a, b)
print(output)
print(output.size())
"""
tensor([[1., 1.]])
tensor([[1., 1., 1.],
        [1., 1., 1.]])
tensor([[2., 2., 2.]])
torch.Size([1, 3])
"""

2、torch.bmm

a = torch.randn(2, 1, 2)
print(a)
b = torch.randn(2, 2, 3)
print(b)
output = torch.bmm(a, b)
print(output)
print(output.size())
"""
tensor([[[-0.1187,  0.2110]],

        [[ 0.7463, -0.6136]]])
tensor([[[-0.1186,  1.5565,  1.3662],
         [ 1.0199,  2.4644,  1.1630]],

        [[-1.9483, -1.6258, -0.4654],
         [-0.1424,  1.3892,  0.7559]]])
tensor([[[ 0.2293,  0.3352,  0.0832]],

        [[-1.3666, -2.0657, -0.8111]]])
torch.Size([2, 1, 3])
"""

3、torch.mul

a = torch.ones(2, 3) * 2
print(a)
b = torch.randn(2, 3)
print(b)
output = torch.mul(a, b)
print(output)
print(output.size())
"""
tensor([[2., 2., 2.],
        [2., 2., 2.]])
tensor([[-0.1187,  0.2110,  0.7463],
        [-0.6136, -0.1186,  1.5565]])
tensor([[-0.2375,  0.4220,  1.4925],
        [-1.2271, -0.2371,  3.1130]])
torch.Size([2, 3])
"""

4、torch.mv

mat = torch.randn(3, 4)
print(mat)
vec = torch.randn(4)
print(vec)
output = torch.mv(mat, vec)
print(output)
print(output.size())
print(torch.mm(mat, vec.unsqueeze(1)).squeeze(1))
"""
tensor([[-0.1187,  0.2110,  0.7463, -0.6136],
        [-0.1186,  1.5565,  1.3662,  1.0199],
        [ 2.4644,  1.1630, -1.9483, -1.6258]])
tensor([-0.4654, -0.1424,  1.3892,  0.7559])
tensor([ 0.5982,  2.5024, -5.2481])
torch.Size([3])
tensor([ 0.5982,  2.5024, -5.2481])
"""

5、torch.matmul

# 其作用包含torch.mm、torch.bmm和torch.mv。其他类似,不一一举例。
a = torch.randn(2, 1, 2)
print(a)
b = torch.randn(2, 2, 3)
print(b)
output = torch.bmm(a, b)
print(output)
output1 = torch.matmul(a, b)
print(output1)
print(output1.size())
"""
tensor([[[-0.1187,  0.2110]],

        [[ 0.7463, -0.6136]]])
tensor([[[-0.1186,  1.5565,  1.3662],
         [ 1.0199,  2.4644,  1.1630]],

        [[-1.9483, -1.6258, -0.4654],
         [-0.1424,  1.3892,  0.7559]]])
tensor([[[ 0.2293,  0.3352,  0.0832]],

        [[-1.3666, -2.0657, -0.8111]]])
tensor([[[ 0.2293,  0.3352,  0.0832]],

        [[-1.3666, -2.0657, -0.8111]]])
torch.Size([2, 1, 3])
"""
# 维度为(b,l,m)和(b,m,n);(l,m)和(b,m,n);(b,c,l,m)和(b,c,m,n);(l,m)和(m)等
a = torch.randn(2, 3, 4)
b = torch.randn(2, 4, 5)
print(torch.matmul(a, b).size())
a = torch.randn(3, 4)
b = torch.randn(2, 4, 5)
print(torch.matmul(a, b).size())
a = torch.randn(2, 3, 3, 4)
b = torch.randn(2, 3, 4, 5)
print(torch.matmul(a, b).size())
a = torch.randn(2, 3)
b = torch.randn(3)
print(torch.matmul(a, b).size())
"""
torch.Size([2, 3, 5])
torch.Size([2, 3, 5])
torch.Size([2, 3, 3, 5])
torch.Size([2])
"""

6、@运算符

# @运算符:其作用类似于torch.matmul
a = torch.randn(2, 3, 4)
b = torch.randn(2, 4, 5)
print(torch.matmul(a, b).size())
print((a @ b).size())
a = torch.randn(3, 4)
b = torch.randn(2, 4, 5)
print(torch.matmul(a, b).size())
print((a @ b).size())
a = torch.randn(2, 3, 3, 4)
b = torch.randn(2, 3, 4, 5)
print(torch.matmul(a, b).size())
print((a @ b).size())
a = torch.randn(2, 3)
b = torch.randn(3)
print(torch.matmul(a, b).size())
print((a @ b).size())
"""
torch.Size([2, 3, 5])
torch.Size([2, 3, 5])
torch.Size([2, 3, 5])
torch.Size([2, 3, 5])
torch.Size([2, 3, 3, 5])
torch.Size([2, 3, 3, 5])
torch.Size([2])
torch.Size([2])
"""

7、*运算符

# *运算符:其作用类似于torch.mul
a = torch.ones(2, 3) * 2
print(a)
b = torch.ones(2, 3) * 3
print(b)
output = torch.mul(a, b)
print(output)
print(output.size())
output1 = a * b
print(output1)
print(output1.size())
"""
tensor([[2., 2., 2.],
        [2., 2., 2.]])
tensor([[3., 3., 3.],
        [3., 3., 3.]])
tensor([[6., 6., 6.],
        [6., 6., 6.]])
torch.Size([2, 3])
tensor([[6., 6., 6.],
        [6., 6., 6.]])
torch.Size([2, 3])
"""

附:二维矩阵乘法

神经网络中包含大量的 2D 张量矩阵乘法运算,而使用 torch.matmul 函数比较复杂,因此 PyTorch 提供了更为简单方便的 torch.mm(input, other, out = None) 函数。下表是 torch.matmul 函数和 torch.mm 函数的简单对比。

torch.matmul 函数支持广播,主要指的是当参与矩阵乘积运算的两个张量中其中有一个是 1D 张量,torch.matmul 函数会将其广播成 2D 张量参与运算,最后将广播添加的维度删除作为最终 torch.matmul 函数的返回结果。torch.mm 函数不支持广播,相对应的输入的两个张量必须为 2D。

import torch

input = torch.tensor([[1., 2.], [3., 4.]])
other = torch.tensor([[5., 6., 7.], [8., 9., 10.]])

result = torch.mm(input, other)
print(result)
# tensor([[21., 24., 27.],
#         [47., 54., 61.]])

总结

到此这篇关于pytorch中常用的乘法运算及相关的运算符(@和*)的文章就介绍到这了,更多相关pytorch常用乘法运算及运算符内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • 基于python及pytorch中乘法的使用详解

    numpy中的乘法 A = np.array([[1, 2, 3], [2, 3, 4]]) B = np.array([[1, 0, 1], [2, 1, -1]]) C = np.array([[1, 0], [0, 1], [-1, 0]]) A * B : # 对应位置相乘 np.array([[ 1, 0, 3], [ 4, 3, -4]]) A.dot(B) : # 矩阵乘法 ValueError: shapes (2,3) and (2,3) not aligned: 3 (dim

  • pytorch中常用的乘法运算及相关的运算符(@和*)

    目录 前言 1.torch.mm 2.torch.bmm 3.torch.mul 4.torch.mv 5.torch.matmul 6.@运算符 7.*运算符 附:二维矩阵乘法 总结 前言 这里总结一下pytorch常用的乘法运算以及相关的运算符(@.*). 总结放前面: torch.mm : 用于两个矩阵(不包括向量)的乘法.如维度为(l,m)和(m,n)相乘 torch.bmm : 用于带batch的三维向量的乘法.如维度为(b,l,m)和(b,m,n)相乘 torch.mul : 用于两

  • PyTorch中clone()、detach()及相关扩展详解

    clone() 与 detach() 对比 Torch 为了提高速度,向量或是矩阵的赋值是指向同一内存的,这不同于 Matlab.如果需要保存旧的tensor即需要开辟新的存储地址而不是引用,可以用 clone() 进行深拷贝, 首先我们来打印出来clone()操作后的数据类型定义变化: (1). 简单打印类型 import torch a = torch.tensor(1.0, requires_grad=True) b = a.clone() c = a.detach() a.data *=

  • pytorch中常用的损失函数用法说明

    1. pytorch中常用的损失函数列举 pytorch中的nn模块提供了很多可以直接使用的loss函数, 比如MSELoss(), CrossEntropyLoss(), NLLLoss() 等 官方链接: https://pytorch.org/docs/stable/_modules/torch/nn/modules/loss.html pytorch中常用的损失函数 损失函数 名称 适用场景 torch.nn.MSELoss() 均方误差损失 回归 torch.nn.L1Loss() 平

  • 整理Java编程中常用的基本描述符与运算符

    描述符 描述符是你添加到那些定义中来改变他们的意思的关键词.Java 语言有很多描述符,包括以下这些: 可访问描述符 不可访问描述符 应用描述符,你可以在类.方法.变量中加入相应关键字.描述符要先于声明,如下面的例子所示(斜体): public class className { // ... } private boolean myFlag; static final double weeks = 9.5; protected static final int BOXWIDTH = 42; p

  • PyTorch中常用的激活函数的方法示例

    神经网络只是由两个或多个线性网络层叠加,并不能学到新的东西,简单地堆叠网络层,不经过非线性激活函数激活,学到的仍然是线性关系. 但是加入激活函数可以学到非线性的关系,就具有更强的能力去进行特征提取. 构造数据 import torch import torch.nn.functional as F from torch.autograd import Variable import matplotlib.pyplot as plt x = torch.linspace(-5, 5, 200) #

  • Pytorch实现常用乘法算子TensorRT的示例代码

    目录 1.乘法运算总览 2.乘法算子实现 2.1矩阵乘算子实现 2.2点乘算子实现 本文介绍一下 Pytorch 中常用乘法的 TensorRT 实现. pytorch 用于训练,TensorRT 用于推理是很多 AI 应用开发的标配.大家往往更加熟悉 pytorch 的算子,而不太熟悉 TensorRT 的算子,这里拿比较常用的乘法运算在两种框架下的实现做一个对比,可能会有更加直观一些的认识. 1.乘法运算总览 先把 pytorch 中的一些常用的乘法运算进行一个总览: torch.mm:用于

  • PyTorch中Tensor的数据类型和运算的使用

    在使用Tensor时,我们首先要掌握如何使用Tensor来定义不同数据类型的变量.Tensor时张量的英文,表示多维矩阵,和numpy对应,PyTorch中的Tensor可以和numpy的ndarray相互转换,唯一不同的是PyTorch可以在GPU上运行,而numpy的ndarray只能在cpu上运行. 常用的不同数据类型的Tensor,有32位的浮点型torch.FloatTensor,   64位浮点型 torch.DoubleTensor,   16位整形torch.ShortTenso

  • PyTorch中的拷贝与就地操作详解

    前言 PyTroch中我们经常使用到Numpy进行数据的处理,然后再转为Tensor,但是关系到数据的更改时我们要注意方法是否是共享地址,这关系到整个网络的更新.本篇就In-palce操作,拷贝操作中的注意点进行总结. In-place操作 pytorch中原地操作的后缀为_,如.add_()或.scatter_(),就地操作是直接更改给定Tensor的内容而不进行复制的操作,即不会为变量分配新的内存.Python操作类似+=或*=也是就地操作.(我加了我自己~) 为什么in-place操作可以

  • pytorch中的nn.ZeroPad2d()零填充函数实例详解

    在卷积神经网络中,有使用设置padding的参数,配合卷积步长,可以使得卷积后的特征图尺寸大小不发生改变,那么在手动实现图片或特征图的边界零填充时,常用的函数是nn.ZeroPad2d(),可以指定tensor的四个方向上的填充,比如左边添加1dim.右边添加2dim.上边添加3dim.下边添加4dim,即指定paddin参数为(1,2,3,4),本文中代码设置的是(3,4,5,6)如下: import torch.nn as nn import cv2 import torchvision f

  • Pytorch中的gather使用方法

    官方说明 gather可以对一个Tensor进行聚合,声明为:torch.gather(input, dim, index, out=None) → Tensor 一般来说有三个参数:输入的变量input.指定在某一维上聚合的dim.聚合的使用的索引index,输出为Tensor类型的结果(index必须为LongTensor类型). #参数介绍: input (Tensor) – The source tensor dim (int) – The axis along which to ind

随机推荐