详解torch.Tensor的4种乘法

torch.Tensor有4种常见的乘法:*, torch.mul, torch.mm, torch.matmul. 本文抛砖引玉,简单叙述一下这4种乘法的区别,具体使用还是要参照官方文档

点乘

a与b做*乘法,原则是如果a与b的size不同,则以某种方式将a或b进行复制,使得复制后的a和b的size相同,然后再将a和b做element-wise的乘法

下面以*标量和*一维向量为例展示上述过程。

* 标量

Tensor与标量k做*乘法的结果是Tensor的每个元素乘以k(相当于把k复制成与lhs大小相同,元素全为k的Tensor).

>>> a = torch.ones(3,4)
>>> a
tensor([[1., 1., 1., 1.],
    [1., 1., 1., 1.],
    [1., 1., 1., 1.]])
>>> a * 2
tensor([[2., 2., 2., 2.],
    [2., 2., 2., 2.],
    [2., 2., 2., 2.]])

* 一维向量

Tensor与行向量做*乘法的结果是每列乘以行向量对应列的值(相当于把行向量的行复制,成为与lhs维度相同的Tensor). 注意此时要求Tensor的列数与行向量的列数相等。

>>> a = torch.ones(3,4)
>>> a
tensor([[1., 1., 1., 1.],
    [1., 1., 1., 1.],
    [1., 1., 1., 1.]])
>>> b = torch.Tensor([1,2,3,4])
>>> b
tensor([1., 2., 3., 4.])
>>> a * b
tensor([[1., 2., 3., 4.],
    [1., 2., 3., 4.],
    [1., 2., 3., 4.]])

Tensor与列向量做*乘法的结果是每行乘以列向量对应行的值(相当于把列向量的列复制,成为与lhs维度相同的Tensor). 注意此时要求Tensor的行数与列向量的行数相等。

>>> a = torch.ones(3,4)
>>> a
tensor([[1., 1., 1., 1.],
    [1., 1., 1., 1.],
    [1., 1., 1., 1.]])
>>> b = torch.Tensor([1,2,3]).reshape((3,1))
>>> b
tensor([[1.],
    [2.],
    [3.]])
>>> a * b
tensor([[1., 1., 1., 1.],
    [2., 2., 2., 2.],
    [3., 3., 3., 3.]])

* 矩阵

Arsmart在评论区提醒,增补一个矩阵 * 矩阵的例子,感谢Arsmart的热心评论!
如果两个二维矩阵A与B做点积A * B,则要求A与B的维度完全相同,即A的行数=B的行数,A的列数=B的列数

>>> a = torch.tensor([[1, 2], [2, 3]])
>>> a * a
tensor([[1, 4],
    [4, 9]])

broadcast

点积是broadcast的。broadcast是torch的一个概念,简单理解就是在一定的规则下允许高维Tensor和低维Tensor之间的运算。broadcast的概念稍显复杂,在此不做展开,可以参考官方文档关于broadcast的介绍. 在torch.matmul里会有关于broadcast的应用的一个简单的例子。

这里举一个点积broadcast的例子。在例子中,a是二维Tensor,b是三维Tensor,但是a的维度与b的后两位相同,那么a和b仍然可以做点积,点积结果是一个和b维度一样的三维Tensor,运算规则是:若c = a * b, 则c[i,*,*] = a * b[i, *, *],即沿着b的第0维做二维Tensor点积,或者可以理解为运算前将a沿着b的第0维也进行了expand操作,即a = a.expand(b.size()); a * b

>>> a = torch.tensor([[1, 2], [2, 3]])
>>> b = torch.tensor([[[1,2],[2,3]],[[-1,-2],[-2,-3]]])
>>> a * b
tensor([[[ 1, 4],
     [ 4, 9]],

    [[-1, -4],
     [-4, -9]]])
>>> b * a
tensor([[[ 1, 4],
     [ 4, 9]],

    [[-1, -4],
     [-4, -9]]])

其实,上面提到的二维Tensor点积标量、二维Tensor点积行向量,都是发生在高维向量和低维向量之间的,也可以看作是broadcast.

torch.mul

官方文档关于torch.mul的介绍. 用法与*乘法相同,也是element-wise的乘法,也是支持broadcast的。

下面是几个torch.mul的例子.

乘标量

>>> a = torch.ones(3,4)
>>> a
tensor([[1., 1., 1., 1.],
    [1., 1., 1., 1.],
    [1., 1., 1., 1.]])
>>> a * 2
tensor([[2., 2., 2., 2.],
    [2., 2., 2., 2.],
    [2., 2., 2., 2.]])

乘行向量

>>> a = torch.ones(3,4)
>>> a
tensor([[1., 1., 1., 1.],
    [1., 1., 1., 1.],
    [1., 1., 1., 1.]])
>>> b = torch.Tensor([1,2,3,4])
>>> b
tensor([1., 2., 3., 4.])
>>> torch.mul(a, b)
tensor([[1., 2., 3., 4.],
    [1., 2., 3., 4.],
    [1., 2., 3., 4.]])

乘列向量

>>> a = torch.ones(3,4)
>>> a
tensor([[1., 1., 1., 1.],
    [1., 1., 1., 1.],
    [1., 1., 1., 1.]])
>>> b = torch.Tensor([1,2,3]).reshape((3,1))
>>> b
tensor([[1.],
    [2.],
    [3.]])
>>> torch.mul(a, b)
tensor([[1., 1., 1., 1.],
    [2., 2., 2., 2.],
    [3., 3., 3., 3.]])

乘矩阵

例1:二维矩阵 mul 二维矩阵

>>> a = torch.tensor([[1, 2], [2, 3]])
>>> torch.mul(a,a)
tensor([[1, 4],
    [4, 9]])

例2:二维矩阵 mul 三维矩阵(broadcast)

>>> a = torch.tensor([[1, 2], [2, 3]])
>>> b = torch.tensor([[[1,2],[2,3]],[[-1,-2],[-2,-3]]])
>>> torch.mul(a,b)
tensor([[[ 1, 4],
     [ 4, 9]],

    [[-1, -4],
     [-4, -9]]])

torch.mm

官方文档关于torch.mm的介绍. 数学里的矩阵乘法,要求两个Tensor的维度满足矩阵乘法的要求.

例子:

>>> a = torch.ones(3,4)
>>> b = torch.ones(4,2)
>>> torch.mm(a, b)
tensor([[4., 4.],
    [4., 4.],
    [4., 4.]])

torch.matmul

官方文档关于torch.matmul的介绍. torch.mm的broadcast版本.

例子:

>>> a = torch.ones(3,4)
>>> b = torch.ones(5,4,2)
>>> torch.matmul(a, b)
tensor([[[4., 4.],
     [4., 4.],
     [4., 4.]],

    [[4., 4.],
     [4., 4.],
     [4., 4.]],

    [[4., 4.],
     [4., 4.],
     [4., 4.]],

    [[4., 4.],
     [4., 4.],
     [4., 4.]],

    [[4., 4.],
     [4., 4.],
     [4., 4.]]])

同样的a和b,使用torch.mm相乘会报错

>>> torch.mm(a, b)
Traceback (most recent call last):
 File "<stdin>", line 1, in <module>
RuntimeError: matrices expected, got 2D, 3D tensors at /pytorch/aten/src/TH/generic/THTensorMath.cpp:2065

到此这篇关于详解torch.Tensor的4种乘法的文章就介绍到这了,更多相关torch.Tensor 乘法内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • pytorch查看torch.Tensor和model是否在CUDA上的实例

    今天训练faster R-CNN时,发现之前跑的很好的程序(是指在运行程序过程中,显卡利用率能够一直维持在70%以上),今天看的时候,显卡利用率很低,所以在想是不是我的训练数据torch.Tensor或者模型model没有加载到GPU上训练,于是查找如何查看tensor和model所在设备的命令. import torch import torchvision.models as models model=models.vgg11(pretrained=False) print(next(mod

  • PyTorch中torch.tensor与torch.Tensor的区别详解

    PyTorch最近几年可谓大火.相比于TensorFlow,PyTorch对于Python初学者更为友好,更易上手. 众所周知,numpy作为Python中数据分析的专业第三方库,比Python自带的Math库速度更快.同样的,在PyTorch中,有一个类似于numpy的库,称为Tensor.Tensor自称为神经网络界的numpy. 一.numpy和Tensor二者对比 对比项 numpy Tensor 相同点 可以定义多维数组,进行切片.改变维度.数学运算等 可以定义多维数组,进行切片.改变

  • 详解torch.Tensor的4种乘法

    torch.Tensor有4种常见的乘法:*, torch.mul, torch.mm, torch.matmul. 本文抛砖引玉,简单叙述一下这4种乘法的区别,具体使用还是要参照官方文档. 点乘 a与b做*乘法,原则是如果a与b的size不同,则以某种方式将a或b进行复制,使得复制后的a和b的size相同,然后再将a和b做element-wise的乘法. 下面以*标量和*一维向量为例展示上述过程. * 标量 Tensor与标量k做*乘法的结果是Tensor的每个元素乘以k(相当于把k复制成与l

  • 详解pytorch tensor和ndarray转换相关总结

    在使用pytorch的时候,经常会涉及到两种数据格式tensor和ndarray之间的转换,这里总结一下两种格式的转换: 1. tensor cpu 和tensor gpu之间的转化: tensor cpu 转为tensor gpu: tensor_gpu = tensor_cpu.cuda() >>> tensor_cpu = torch.ones((2,2)) tensor([[1., 1.], [1., 1.]]) >>> tensor_gpu = tensor_

  • 详解TensorFlow训练网络两种方式

    TensorFlow训练网络有两种方式,一种是基于tensor(array),另外一种是迭代器 两种方式区别是: 第一种是要加载全部数据形成一个tensor,然后调用model.fit()然后指定参数batch_size进行将所有数据进行分批训练 第二种是自己先将数据分批形成一个迭代器,然后遍历这个迭代器,分别训练每个批次的数据 方式一:通过迭代器 IMAGE_SIZE = 1000 # step1:加载数据集 (train_images, train_labels), (val_images,

  • 详解JavaScript中的4种类型识别方法

    具体内容如下: 1.typeof [输出]首字母小写的字符串形式 [功能] [a]可以识别标准类型(将Null识别为object) [b]不能识别具体的对象类型(Function除外) [实例] console.log(typeof "jerry");//"string" console.log(typeof 12);//"number" console.log(typeof true);//"boolean" console

  • 详解java中的四种代码块

    在java中用{}括起来的称为代码块,代码块可分为以下四种: 一.简介 1.普通代码块: 类中方法的方法体 2.构造代码块: 构造块会在创建对象时被调用,每次创建时都会被调用,优先于类构造函数执行. 3.静态代码块: 用static{}包裹起来的代码片段,只会执行一次.静态代码块优先于构造块执行. 4.同步代码块: 使用synchronized(){}包裹起来的代码块,在多线程环境下,对共享数据的读写操作是需要互斥进行的,否则会导致数据的不一致性.同步代码块需要写在方法中. 二.静态代码块和构造

  • 详解C++ 多态的两种形式(静态、动态)

    1.多态的概念与分类 多态(Polymorphisn)是面向对象程序设计(OOP)的一个重要特征.多态字面意思为多种状态.在面向对象语言中,一个接口,多种实现即为多态.C++中的多态性具体体现在编译和运行两个阶段.编译时多态是静态多态,在编译时就可以确定使用的接口.运行时多态是动态多态,具体引用的接口在运行时才能确定. 静态多态和动态多态的区别其实只是在什么时候将函数实现和函数调用关联起来,是在编译时期还是运行时期,即函数地址是早绑定还是晚绑定的.静态多态是指在编译期间就可以确定函数的调用地址,

  • 详解python中的三种命令行模块(sys.argv,argparse,click)

    Python作为一门脚本语言,经常作为脚本接受命令行传入参数,Python接受命令行参数大概有三种方式.因为在日常工作场景会经常使用到,这里对这几种方式进行总结. 命令行参数模块 这里命令行参数模块平时工作中用到最多就是这三种模块:sys.argv,argparse,click.sys.argv和argparse都是内置模块,click则是第三方模块. sys.argv模块(内置模块) 先看一个简单的示例: #!/usr/bin/python import sys def hello(name,

  • 详解js中的几种常用设计模式

    工厂模式 function createPerson(name, age){ var o = new Object(); // 创建一个对象 o.name = name; o.age = age; o.sayName = function(){ console.log(this.name) } return o; // 返回这个对象 } var person1 = createPerson('ccc', 18) var person2 = createPerson('www', 18) 工厂函数

  • 详解servlet调用的几种简单方式总结

    servlet调用的几种简单方式 这里总结的是我在学习web开发的过程中需要用到的几种比较常见的用于转发和调用servlet的方式,这些方式的使用率非常高.在网上总结了相关的方法,大多对于初学者不是特别的友好,自己总结了一下. 1.servlet直接转发到另一个servlet 我们在进行jsp页面点击按钮进行登录的时候,首先需要登录到进行登录检查的servlet,但是在下个jsp页面,我们需要那个页面通过servlet进行转发,所以需要从servlet直接跳转到另一个servlet,其实写法很简

  • 详解js创建对象的几种方式和对象方法

    这篇文章是看js红宝书第8章,记的关于对象的笔记(第二篇). 创建对象的几种模式: 工厂模式: 工厂是函数的意思.工厂模式核心是定义一个返回全新对象的函数. function getObj(name, age) { let obj = {} obj.name = name obj.age = age return obj } let person1 = getObj("cc", 31) 缺点:不知道新创建的对象是什么类型 构造函数模式: 通过一个构造函数,得到一个对象实例. 构造函数和

随机推荐