Pytorch mask_select 函数的用法详解

非常简单的函数,但是官网的介绍令人(令我)迷惑,所以稍加解释。

mask_select会将满足mask(掩码、遮罩等等,随便翻译)的指示,将满足条件的点选出来。

根据掩码张量mask中的二元值,取输入张量中的指定项( mask为一个 ByteTensor),将取值返回到一个新的1D张量,

张量 mask须跟input张量有相同数量的元素数目,但形状或维度不需要相同

x = torch.randn(3, 4)
x
1.2045 2.4084 0.4001 1.1372
0.5596 1.5677 0.6219 -0.7954
1.3635 -1.2313 -0.5414 -1.8478
[torch.FloatTensor of size 3x4]
mask = x.ge(0.5)
mask
1 1 0 1
1 1 1 0
1 0 0 0
[torch.ByteTensor of size 3x4]
torch.masked_select(x, mask)
1.2045
2.4084
1.1372
0.5596
1.5677
0.6219
1.3635
[torch.FloatTensor of size 7]

以上这篇Pytorch mask_select 函数的用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • pytorch实现seq2seq时对loss进行mask的方式

    如何对loss进行mask pytorch官方教程中有一个Chatbot教程,就是利用seq2seq和注意力机制实现的,感觉和机器翻译没什么不同啊,如果对话中一句话有下一句,那么就把这一对句子加入模型进行训练.其中在训练阶段,损失函数通常需要进行mask操作,因为一个batch中句子的长度通常是不一样的,一个batch中不足长度的位置需要进行填充(pad)补0,最后生成句子计算loss时需要忽略那些原本是pad的位置的值,即只保留mask中值为1位置的值,忽略值为0位置的值,具体演示如下: im

  • python_mask_array的用法

    掩码数组 数据很大形况下是凌乱的,并且含有空白的或者无法处理的字符,掩码式数组可以很好的忽略残缺的或者是无效的数据点.掩码式数组由一个正常数组与一个布尔式数组组成,若布尔数组中为Ture,则表示正常数组中对应下标的值无效,反之False表示对应正常数组的值有效. numpy.ma模块中提供掩码数组的处理,这个模块中几乎完整复制了numpy中的所有函数,并提供掩码数组的功能: >>>import numpy.ma as ma >>>x = np.array([1,2,3,

  • python实现根据给定坐标点生成多边形mask的例子

    处理数据集的过程中用到了mask 但是源数据集中只给了mask顶点的坐标值,那么在python中怎么实现生成只有0.1表示的mask区域呢? 主要借鉴cv2中的方法: (我使用的数据情况是将顶点坐标分别存储在roi.mat中的x和y元素) matfn = 'roi.mat' data = sio.loadmat(matfn) x_cor = data['x'] y_cor = data['y'] im = np.zeros(图像对应尺寸, dtype="uint8") cor_xy =

  • Pytorch mask_select 函数的用法详解

    非常简单的函数,但是官网的介绍令人(令我)迷惑,所以稍加解释. mask_select会将满足mask(掩码.遮罩等等,随便翻译)的指示,将满足条件的点选出来. 根据掩码张量mask中的二元值,取输入张量中的指定项( mask为一个 ByteTensor),将取值返回到一个新的1D张量, 张量 mask须跟input张量有相同数量的元素数目,但形状或维度不需要相同 x = torch.randn(3, 4) x 1.2045 2.4084 0.4001 1.1372 0.5596 1.5677

  • PyTorch中topk函数的用法详解

    听名字就知道这个函数是用来求tensor中某个dim的前k大或者前k小的值以及对应的index. 用法 torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor) input:一个tensor数据 k:指明是得到前k个数据以及其index dim: 指定在哪个维度上排序, 默认是最后一个维度 largest:如果为True,按照大到小排序: 如果为False,按照小到大排序

  • PyTorch中permute的用法详解

    permute(dims) 将tensor的维度换位. 参数:参数是一系列的整数,代表原来张量的维度.比如三维就有0,1,2这些dimension. 例: import torch import numpy as np a=np.array([[[1,2,3],[4,5,6]]]) unpermuted=torch.tensor(a) print(unpermuted.size()) # --> torch.Size([1, 2, 3]) permuted=unpermuted.permute(

  • Pytorch 中retain_graph的用法详解

    用法分析 在查看SRGAN源码时有如下损失函数,其中设置了retain_graph=True,其作用是什么? ############################ # (1) Update D network: maximize D(x)-1-D(G(z)) ########################### real_img = Variable(target) if torch.cuda.is_available(): real_img = real_img.cuda() z = V

  • pytorch中index_select()的用法详解

    pytorch中index_select()的用法 index_select(input, dim, index) 功能:在指定的维度dim上选取数据,不如选取某些行,列 参数介绍 第一个参数input是要索引查找的对象 第二个参数dim是要查找的维度,因为通常情况下我们使用的都是二维张量,所以可以简单的记忆: 0代表行,1代表列 第三个参数index是你要索引的序列,它是一个tensor对象 刚开始学习pytorch,遇到了index_select(),一开始不太明白几个参数的意思,后来查了一

  • Oracle中的INSTR,NVL和SUBSTR函数的用法详解

    Oracle中INSTR的用法: INSTR方法的格式为 INSTR(源字符串, 要查找的字符串, 从第几个字符开始, 要找到第几个匹配的序号) 返回找到的位置,如果找不到则返回0. 例如:INSTR('CORPORATE FLOOR','OR', 3, 2)中,源字符串为'CORPORATE FLOOR', 在字符串中查找'OR',从第三个字符位置开始查找"OR",取第三个字后第2个匹配项的位置. 默认查找顺序为从左到右.当起始位置为负数的时候,从右边开始查找. 所以SELECT I

  • ES6中Array.find()和findIndex()函数的用法详解

    ES6为Array增加了find(),findIndex函数. find()函数用来查找目标元素,找到就返回该元素,找不到返回undefined. findIndex()函数也是查找目标元素,找到就返回元素的位置,找不到就返回-1. 他们的都是一个查找回调函数. [1, 2, 3, 4].find((value, index, arr) => { }) 查找函数有三个参数. value:每一次迭代查找的数组元素. index:每一次迭代查找的数组元素索引. arr:被查找的数组. 例: 1.查找

  • C语言中memcpy 函数的用法详解

    C语言中memcpy 函数的用法详解 memcpy(内存拷贝函数) c和c++使用的内存拷贝函数,memcpy函数的功能是从源src所指的内存地址的起始位置开始拷贝n个字节到目标dest所指的内存地址的起始位置中. void* memcpy(void* destination, const void* source, size_t num); void* dest 目标内存 const void* src 源内存 size_t num 字节个数 库中实现的memcpy函数 struct { ch

  • 对pandas中apply函数的用法详解

    最近在使用apply函数,总结一下用法. apply函数可以对DataFrame对象进行操作,既可以作用于一行或者一列的元素,也可以作用于单个元素. 例:列元素 行元素 列 行 以上这篇对pandas中apply函数的用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们. 您可能感兴趣的文章: 浅谈Pandas中map, applymap and apply的区别

  • python 函数中的内置函数及用法详解

    今天来介绍一下Python解释器包含的一系列的内置函数,下面表格按字母顺序列出了内置函数: 下面就一一介绍一下内置函数的用法: 1.abs() 返回一个数值的绝对值,可以是整数或浮点数等. print(abs(-18)) print(abs(0.15)) result: 18 0.15 2.all(iterable) 如果iterable的所有元素不为0.''.False或者iterable为空,all(iterable)返回True,否则返回False. print(all(['a','b',

随机推荐