pytorch中torch.topk()函数的快速理解

目录
  • 函数作用:
  • 举个栗子:
  • 实例演示
  • 总结

函数作用:

该函数的作用即按字面意思理解,topk:取数组的前k个元素进行排序。

通常该函数返回2个值,第一个值为排序的数组,第二个值为该数组中获取到的元素在原数组中的位置标号。

举个栗子:

import numpy as np
import torch
import torch.utils.data.dataset as Dataset
from torch.utils.data import Dataset,DataLoader

####################准备一个数组#########################
tensor1=torch.tensor([[10,1,2,1,1,1,1,1,1,1,10],
             [3,4,5,1,1,1,1,1,1,1,1],
             [7,8,9,1,1,1,1,1,1,1,1],
             [1,4,7,1,1,1,1,1,1,1,1]],dtype=torch.float32)

####################打印这个原数组#########################
print('tensor1:')
print(tensor1)

#################使用torch.topk()这个函数##################
print('使用torch.topk()这个函数得到:')

'''k=3代表从原数组中取得3个元素,dim=1表示从原数组中的第一维获取元素
(在本例中是分别从[10,1,2,1,1,1,1,1,1,1,10]、[3,4,5,1,1,1,1,1,1,1,1]、
  [7,8,9,1,1,1,1,1,1,1,1]、[1,4,7,1,1,1,1,1,1,1,1]这四个数组中获取3个元素)
其中largest=True表示从大到小取元素'''
print(torch.topk(tensor1, k=3, dim=1, largest=True))

#################打印这个函数第一个返回值####################
print('函数第一个返回值topk[0]如下')
print(torch.topk(tensor1, k=3, dim=1, largest=True)[0])

#################打印这个函数第二个返回值####################
print('函数第二个返回值topk[1]如下')
print(torch.topk(tensor1, k=3, dim=1, largest=True)[1])
'''

#######################运行结果##########################
tensor1:
tensor([[10.,  1.,  2.,  1.,  1.,  1.,  1.,  1.,  1.,  1., 10.],
        [ 3.,  4.,  5.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
        [ 7.,  8.,  9.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
        [ 1.,  4.,  7.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.]])

使用torch.topk()这个函数得到:

'得到的values是原数组dim=1的四组从大到小的三个元素值;
得到的indices是获取到的元素值在原数组dim=1中的位置。'

torch.return_types.topk(
values=tensor([[10., 10.,  2.],
        [ 5.,  4.,  3.],
        [ 9.,  8.,  7.],
        [ 7.,  4.,  1.]]),
indices=tensor([[ 0, 10,  2],
        [ 2,  1,  0],
        [ 2,  1,  0],
        [ 2,  1,  0]]))

函数第一个返回值topk[0]如下
tensor([[10., 10.,  2.],
        [ 5.,  4.,  3.],
        [ 9.,  8.,  7.],
        [ 7.,  4.,  1.]])

函数第二个返回值topk[1]如下
tensor([[ 0, 10,  2],
        [ 2,  1,  0],
        [ 2,  1,  0],
        [ 2,  1,  0]])
'''

该函数功能经常用来获取张量或者数组中最大或者最小的元素以及索引位置,是一个经常用到的基本函数。

实例演示

任务一:

取top1(最大值):

pred = torch.tensor([[-0.5816, -0.3873, -1.0215, -1.0145,  0.4053],
        [ 0.7265,  1.4164,  1.3443,  1.2035,  1.8823],
        [-0.4451,  0.1673,  1.2590, -2.0757,  1.7255],
        [ 0.2021,  0.3041,  0.1383,  0.3849, -1.6311]])
print(pred)
values, indices = pred.topk(1, dim=0, largest=True, sorted=True)
print(indices)
print(values)
# 用max得到的结果,设置keepdim为True,避免降维。因为topk函数返回的index不降维,shape和输入一致。
_, indices_max = pred.max(dim=0, keepdim=True)
print(indices_max)
print(indices_max == indices)
输出:
tensor([[-0.5816, -0.3873, -1.0215, -1.0145,  0.4053],
        [ 0.7265,  1.4164,  1.3443,  1.2035,  1.8823],
        [-0.4451,  0.1673,  1.2590, -2.0757,  1.7255],
        [ 0.2021,  0.3041,  0.1383,  0.3849, -1.6311]])
tensor([[1, 1, 1, 1, 1]])
tensor([[0.7265, 1.4164, 1.3443, 1.2035, 1.8823]])
tensor([[1, 1, 1, 1, 1]])
tensor([[True, True, True, True, True]])

任务二:

按行取出topk,将小于topk的置为inf:

pred = torch.tensor([[-0.5816, -0.3873, -1.0215, -1.0145,  0.4053],
        [ 0.7265,  1.4164,  1.3443,  1.2035,  1.8823],
        [-0.4451,  0.1673,  1.2590, -2.0757,  1.7255],
        [ 0.2021,  0.3041,  0.1383,  0.3849, -1.6311]])
print(pred)
top_k = 2  # 按行求出每一行的最大的前两个值
filter_value=-float('Inf')
indices_to_remove = pred < torch.topk(pred, top_k)[0][..., -1, None]
print(indices_to_remove)
pred[indices_to_remove] = filter_value  # 对于topk之外的其他元素的logits值设为负无穷
print(pred)

输出:
tensor([[-0.5816, -0.3873, -1.0215, -1.0145,  0.4053],
        [ 0.7265,  1.4164,  1.3443,  1.2035,  1.8823],
        [-0.4451,  0.1673,  1.2590, -2.0757,  1.7255],
        [ 0.2021,  0.3041,  0.1383,  0.3849, -1.6311]])
tensor([[4],
        [4],
        [4],
        [3]])
tensor([[0.4053],
        [1.8823],
        [1.7255],
        [0.3849]])
tensor([[ True, False,  True,  True, False],
        [ True, False,  True,  True, False],
        [ True,  True, False,  True, False],
        [ True, False,  True, False,  True]])
tensor([[   -inf, -0.3873,    -inf,    -inf,  0.4053],
        [   -inf,  1.4164,    -inf,    -inf,  1.8823],
        [   -inf,    -inf,  1.2590,    -inf,  1.7255],
        [   -inf,  0.3041,    -inf,  0.3849,    -inf]])

任务三:

import numpy as np
import torch
import torch.utils.data.dataset as Dataset
from torch.utils.data import Dataset,DataLoader
tensor1=torch.tensor([[10,1,2,1,1,1,1,1,1,1,10],
             [3,4,5,1,1,1,1,1,1,1,1],
             [7,8,9,1,1,1,1,1,1,1,1],
             [1,4,7,1,1,1,1,1,1,1,1]],dtype=torch.float32)
# tensor2=torch.tensor([[3,2,1],
#                       [6,5,4],
#                       [1,4,7],
#                       [9,8,7]],dtype=torch.float32)
#
print('tensor1:')
print(tensor1)
print('直接输出topk,会得到两个东西,我们需要的是第二个indices')
print(torch.topk(tensor1, k=3, dim=1, largest=True))
print('topk[0]如下')
print(torch.topk(tensor1, k=3, dim=1, largest=True)[0])
print('topk[1]如下')
print(torch.topk(tensor1, k=3, dim=1, largest=True)[1])
'''
tensor1:
tensor([[10.,  1.,  2.,  1.,  1.,  1.,  1.,  1.,  1.,  1., 10.],
        [ 3.,  4.,  5.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
        [ 7.,  8.,  9.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
        [ 1.,  4.,  7.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.]])
直接输出topk,会得到两个东西,我们需要的是第二个indices
torch.return_types.topk(
values=tensor([[10., 10.,  2.],
        [ 5.,  4.,  3.],
        [ 9.,  8.,  7.],
        [ 7.,  4.,  1.]]),
indices=tensor([[ 0, 10,  2],
        [ 2,  1,  0],
        [ 2,  1,  0],
        [ 2,  1,  0]]))
topk[0]如下
tensor([[10., 10.,  2.],
        [ 5.,  4.,  3.],
        [ 9.,  8.,  7.],
        [ 7.,  4.,  1.]])
topk[1]如下
tensor([[ 0, 10,  2],
        [ 2,  1,  0],
        [ 2,  1,  0],
        [ 2,  1,  0]])
'''

总结

到此这篇关于pytorch中torch.topk()函数快速理解的文章就介绍到这了,更多相关pytorch torch.topk()函数理解内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • PyTorch中topk函数的用法详解

    听名字就知道这个函数是用来求tensor中某个dim的前k大或者前k小的值以及对应的index. 用法 torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor) input:一个tensor数据 k:指明是得到前k个数据以及其index dim: 指定在哪个维度上排序, 默认是最后一个维度 largest:如果为True,按照大到小排序: 如果为False,按照小到大排序

  • pytorch中torch.topk()函数的快速理解

    目录 函数作用: 举个栗子: 实例演示 总结 函数作用: 该函数的作用即按字面意思理解,topk:取数组的前k个元素进行排序. 通常该函数返回2个值,第一个值为排序的数组,第二个值为该数组中获取到的元素在原数组中的位置标号. 举个栗子: import numpy as np import torch import torch.utils.data.dataset as Dataset from torch.utils.data import Dataset,DataLoader ########

  • Pytorch中torch.stack()函数的深入解析

    目录 一. torch.stack()函数解析 1. 函数说明: 2. 代码举例 总结 一. torch.stack()函数解析 1. 函数说明: 1.1 官网:torch.stack(),函数定义及参数说明如下图所示: 1.2 函数功能 沿一个新维度对输入一系列张量进行连接,序列中所有张量应为相同形状,stack 函数返回的结果会新增一个维度.也即是把多个2维的张量凑成一个3维的张量:多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度上面进行堆叠. 1.3 参数列表 tensors

  • 浅谈pytorch中torch.max和F.softmax函数的维度解释

    在利用torch.max函数和F.Ssoftmax函数时,对应该设置什么维度,总是有点懵,遂总结一下: 首先看看二维tensor的函数的例子: import torch import torch.nn.functional as F input = torch.randn(3,4) print(input) tensor([[-0.5526, -0.0194, 2.1469, -0.2567], [-0.3337, -0.9229, 0.0376, -0.0801], [ 1.4721, 0.1

  • pytorch中torch.max和Tensor.view函数用法详解

    torch.max() 1. torch.max()简单来说是返回一个tensor中的最大值. 例如: >>> si=torch.randn(4,5) >>> print(si) tensor([[ 1.1659, -1.5195, 0.0455, 1.7610, -0.2064], [-0.3443, 2.0483, 0.6303, 0.9475, 0.4364], [-1.5268, -1.0833, 1.6847, 0.0145, -0.2088], [-0.86

  • Pytorch中torch.nn.Softmax的dim参数用法说明

    Pytorch中torch.nn.Softmax的dim参数使用含义 涉及到多维tensor时,对softmax的参数dim总是很迷,下面用一个例子说明 import torch.nn as nn m = nn.Softmax(dim=0) n = nn.Softmax(dim=1) k = nn.Softmax(dim=2) input = torch.randn(2, 2, 3) print(input) print(m(input)) print(n(input)) print(k(inp

  • PyTorch中torch.utils.data.DataLoader简单介绍与使用方法

    目录 一.torch.utils.data.DataLoader 简介 二.实例 参考链接 总结 一.torch.utils.data.DataLoader 简介 作用:torch.utils.data.DataLoader 主要是对数据进行 batch 的划分. 数据加载器,结合了数据集和取样器,并且可以提供多个线程处理数据集. 在训练模型时使用到此函数,用来 把训练数据分成多个小组 ,此函数 每次抛出一组数据 .直至把所有的数据都抛出.就是做一个数据的初始化. 好处: 使用DataLoade

  • PyTorch中torch.tensor与torch.Tensor的区别详解

    PyTorch最近几年可谓大火.相比于TensorFlow,PyTorch对于Python初学者更为友好,更易上手. 众所周知,numpy作为Python中数据分析的专业第三方库,比Python自带的Math库速度更快.同样的,在PyTorch中,有一个类似于numpy的库,称为Tensor.Tensor自称为神经网络界的numpy. 一.numpy和Tensor二者对比 对比项 numpy Tensor 相同点 可以定义多维数组,进行切片.改变维度.数学运算等 可以定义多维数组,进行切片.改变

  • pytorch中的 .view()函数的用法介绍

    目录 一.普通用法(手动调整size) 二.特殊用法:参数-1(自动调整size) 一.普通用法 (手动调整size) view()相当于reshape.resize,重新调整Tensor的形状. import torch a1 = torch.arange(0,16) print(a1) # tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]) a2 = a1.view(8, 2) a3 = a1.vi

  • PyTorch中torch.nn.functional.cosine_similarity使用详解

    目录 概述 按照dim=0求余弦相似: 按照dim=1求余弦相似: 总结 概述 根据官网文档的描述,其中 dim表示沿着对应的维度计算余弦相似.那么怎么理解呢? 首先,先介绍下所谓的dim: a = torch.tensor([[ [1, 2], [3, 4] ], [ [5, 6], [7, 8] ] ], dtype=torch.float) print(a.shape) """ [ [ [1, 2], [3, 4] ], [ [5, 6], [7, 8] ] ] &qu

  • PyTorch中torch.manual_seed()的用法实例详解

    目录 一.torch.manual_seed(seed) 介绍 torch.manual_seed(seed) 功能描述 语法 参数 返回 二.类似函数的功能 三.实例 实例 1 :不设随机种子,生成随机数 实例 2 :设置随机种子,使得每次运行代码生成的随机数都一样 实例 3 :不同的随机种子生成不同的值 总结 一.torch.manual_seed(seed) 介绍 torch.manual_seed(seed) 功能描述 设置 CPU 生成随机数的 种子 ,方便下次复现实验结果. 为 CP

随机推荐