详解PyTorch中Tensor的高阶操作

条件选取:torch.where(condition, x, y) → Tensor

返回从 x 或 y 中选择元素的张量,取决于 condition

操作定义:

举个例子:

>>> import torch
>>> c = randn(2, 3)
>>> c
tensor([[ 0.0309, -1.5993, 0.1986],
    [-0.0699, -2.7813, -1.1828]])
>>> a = torch.ones(2, 3)
>>> a
tensor([[1., 1., 1.],
    [1., 1., 1.]])
>>> b = torch.zeros(2, 3)
>>> b
tensor([[0., 0., 0.],
    [0., 0., 0.]])
>>> torch.where(c > 0, a, b)
tensor([[1., 0., 1.],
    [0., 0., 0.]])

把张量中的每个数据都代入条件中,如果其大于 0 就得出 a,其它情况就得出 b,同样是把 a 和 b 的相同位置的数据导出。

查表搜集:torch.gather(input, dim, index, out=None) → Tensor

沿给定轴 dim,将输入索引张量 index 指定位置的值进行聚合

对一个3维张量,输出可以定义为:

  • out[i][j][k] = tensor[index[i][j][k]][j][k] # dim=0
  • out[i][j][k] = tensor[i][index[i][j][k]][k] # dim=1
  • out[i][j][k] = tensor[i][j][index[i][j][k]] # dim=3

举个例子:

>>> a = torch.randn(4, 10)
>>> b = a.topk(3, dim = 1)
>>> b
(tensor([[ 1.0134, 0.8785, -0.0373],
    [ 1.4378, 1.4022, 1.0115],
    [ 0.8985, 0.6795, 0.6439],
    [ 1.2758, 1.0294, 1.0075]]), tensor([[5, 7, 6],
    [2, 5, 8],
    [5, 9, 2],
    [7, 9, 6]]))
>>> index = b[1]
>>> index
tensor([[5, 7, 6],
    [2, 5, 8],
    [5, 9, 2],
    [7, 9, 6]])
>>> label = torch.arange(10) + 100
>>> label
tensor([100, 101, 102, 103, 104, 105, 106, 107, 108, 109])
>>> torch.gather(label.expand(4, 10), dim=1, index=index.long()) # 进行聚合操作
tensor([[105, 107, 106],
    [102, 105, 108],
    [105, 109, 102],
    [107, 109, 106]])

把 label 扩展为二维数据后,以 index 中的每个数据为索引,取出在 label 中索引位置的数据,再以 index 的的位置摆放。

比如,最后得出的结果中,第一行的 105 就是 label.expand(4, 10) 中第一行中索引为 5 的数据,提取出来后放在 5 所在的位置。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持我们。

(0)

相关推荐

  • Pytorch Tensor的索引与切片例子

    1. Pytorch风格的索引 根据Tensor的shape,从前往后索引,依次在每个维度上做索引. 示例代码: import torch a = torch.rand(4, 3, 28, 28) print(a[0].shape) #取到第一个维度 print(a[0, 0].shape) # 取到二个维度 print(a[1, 2, 2, 4]) # 具体到某个元素 上述代码创建了一个shape=[4, 3, 28, 28]的Tensor,我们可以理解为4张图片,每张图片有3个通道,每个通道

  • pytorch中tensor的合并与截取方法

    合并: torch.cat(inputs=(a, b), dimension=1) e.g. x = torch.cat((x,y), 0) 沿x轴合并 截取: x[:, 2:4] 以上这篇pytorch中tensor的合并与截取方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们.

  • pytorch: tensor类型的构建与相互转换实例

    Summary 主要包括以下三种途径: 使用独立的函数: 使用torch.type()函数: 使用type_as(tesnor)将张量转换为给定类型的张量. 使用独立函数 import torch tensor = torch.randn(3, 5) print(tensor) # torch.long() 将tensor投射为long类型 long_tensor = tensor.long() print(long_tensor) # torch.half()将tensor投射为半精度浮点类型

  • 在PyTorch中Tensor的查找和筛选例子

    本文源码基于版本1.0,交互界面基于0.4.1 import torch 按照指定轴上的坐标进行过滤 index_select() 沿着某tensor的一个轴dim筛选若干个坐标 >>> x = torch.randn(3, 4) # 目标矩阵 >>> x tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], [-0.4664, 0.2647, -0.1228, -1.1068], [-1.1734, -0.6571, 0.7230,

  • PyTorch之图像和Tensor填充的实例

    在PyTorch中可以对图像和Tensor进行填充,如常量值填充,镜像填充和复制填充等.在图像预处理阶段设置图像边界填充的方式如下: import vision.torchvision.transforms as transforms img_to_pad = transforms.Compose([ transforms.Pad(padding=2, padding_mode='symmetric'), transforms.ToTensor(), ]) 对Tensor进行填充的方式如下: i

  • 详解PyTorch中Tensor的高阶操作

    条件选取:torch.where(condition, x, y) → Tensor 返回从 x 或 y 中选择元素的张量,取决于 condition 操作定义: 举个例子: >>> import torch >>> c = randn(2, 3) >>> c tensor([[ 0.0309, -1.5993, 0.1986], [-0.0699, -2.7813, -1.1828]]) >>> a = torch.ones(2,

  • 详解pytorch中squeeze()和unsqueeze()函数介绍

    squeeze的用法主要就是对数据的维度进行压缩或者解压. 先看torch.squeeze() 这个函数主要对数据的维度进行压缩,去掉维数为1的的维度,比如是一行或者一列这种,一个一行三列(1,3)的数去掉第一个维数为一的维度之后就变成(3)行.squeeze(a)就是将a中所有为1的维度删掉.不为1的维度没有影响.a.squeeze(N) 就是去掉a中指定的维数为一的维度.还有一种形式就是b=torch.squeeze(a,N) a中去掉指定的定的维数为一的维度. 再看torch.unsque

  • 详解Angular路由动画及高阶动画函数

    一.路由动画 路由动画需要在host元数据中指定触发器.动画注意不要过多,否则适得其反. 内容优先,引导用户去注意到某个内容.动画只是辅助手段. 在router.animate.ts中定义一个进场动画,一个离场动画. 因为进场动画和离场动画用的特别频繁,有一个别名叫:enter和:leave. import { trigger, state, transition, style, animate} from '@angular/animations'; export const slideToR

  • 详解Go语言如何利用高阶函数写出优雅的代码

    目录 前言 问题 白银 黄金 王者 总结 前言 go项目中经常需要查询db,按照以前java开发经验,会根据查询条件写很多方法,如: GetUserByUserID GetUsersByName GetUsersByAge 每一种查询条件写一个方法,这种方式对外是挺好的,对外遵循严格原则,让每个对外的方法接口是明确的.但是对内的话,应该尽可能的通用,做到代码复用,少写代码,让代码看起来更优雅.整洁. 问题 在review代码的时候,针对上面3个方法,一般写法是 func GetUserByUse

  • 详解python中的文件与目录操作

    详解python中的文件与目录操作 一 获得当前路径 1.代码1 >>>import os >>>print('Current directory is ',os.getcwd()) Current directory is D:\Python36 2.代码2 如果将上面的脚本写入到文件再运行 Current directory is E:\python\work 二 获得目录的内容 Python代码 >>> os.listdir (os.getcwd

  • 详解Matlab中自带的Java操作合集

    目录 1 获取鼠标在全屏位置 2 获取当前剪切板内容 3 内容复制到剪切板 4 获取鼠标处像素颜色 5 获取屏幕截图 6 创建java窗口(并使其永远在最上方) 7 透明窗口 1 获取鼠标在全屏位置 屏幕左上角为坐标原点,获取鼠标位置和获取鼠标像素颜色建议和while循环或者timer函数结合使用: import java.awt.MouseInfo; mousepoint=MouseInfo.getPointerInfo().getLocation(); mousepoint=[mousepo

  • 详解Python中键盘鼠标的相关操作

    目录 一.前言 二.pyautogui模块 三.鼠标相关操作 1.鼠标移动 2.获取鼠标位置 3.鼠标点击 4.按松鼠标 5.拖动窗口 6.上下滑动 7.小程序——鼠标操控术2.0 8.小程序——连点器 四.键盘相关操作 1.按键的按松 2.键入字符串 3.热键 4.小程序——轰炸器 5.小程序——520个我爱你 五.尾声 一.前言 恭喜你,学明白类,你已经学会所有基本知识了. 这章算是一个娱乐篇,十分简单,了解一下pyautogui模块,这算是比较好学还趣味性十足的,而且可以做许多小程序. 本

  • 详解Pytorch中的tensor数据结构

    目录 torch.Tensor Tensor 数据类型 view 和 reshape 的区别 Tensor 与 ndarray 创建 Tensor 传入维度的方法 torch.Tensor torch.Tensor 是一种包含单一数据类型元素的多维矩阵,类似于 numpy 的 array.Tensor 可以使用 torch.tensor() 转换 Python 的 list 或序列数据生成,生成的是dtype 默认是 torch.FloatTensor. 注意 torch.tensor() 总是

  • 详解python内置常用高阶函数(列出了5个常用的)

    高阶函数是在Python中一个非常有用的功能函数,所谓高阶函数就是一个函数可以用来接收另一个函数作为参数,这样的函数叫做高阶函数. python内置常用高阶函数: 一.函数式编程 •函数本身可以赋值给变量,赋值后变量为函数: •允许将函数本身作为参数传入另一个函数: •允许返回一个函数. 1.map()函数 是 Python 内置的高阶函数,它接收一个函数 f 和一个 list, 并通过把函数 f 依次作用在 list 的每个元素上,得到一个新的 list 并返回 def add(x): ret

  • 详解PHP中的序列化、反序列化操作

    数据(变量)序列化(持久化) 将一个变量的数据"转换为"字符串,但并不是类型转换,目的是将该字符串存储在本地.相反的行为成为反序列化. 流程: //序列化 $str = serialize($r1); //保存到本地 file_put_contents("文本文件路径",$str); //从本地取出 $str2 = file_get_contents("文本文件路径"); //反序列化为之前的对象 $v1 = unserialize($str2)

随机推荐