PyTorch中permute的用法详解

permute(dims)

将tensor的维度换位。

参数:参数是一系列的整数,代表原来张量的维度。比如三维就有0,1,2这些dimension。

例:

import torch
import numpy as np
a=np.array([[[1,2,3],[4,5,6]]])
unpermuted=torch.tensor(a)
print(unpermuted.size()) # ——> torch.Size([1, 2, 3])
permuted=unpermuted.permute(2,0,1)
print(permuted.size()) # ——> torch.Size([3, 1, 2])

再比如图片img的size比如是(28,28,3)就可以利用img.permute(2,0,1)得到一个size为(3,28,28)的tensor。

利用这个函数permute(1,3,2)可以把Tensor([[[1,2,3],[4,5,6]]]) 转换成

tensor([[[1., 4.],
[2., 5.],
[3., 6.]]])

如果使用view(1,3,2),可以得到

tensor([[[1., 2.],
[3., 4.],
[5., 6.]]])

以上这篇PyTorch中permute的用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • pytorch中的卷积和池化计算方式详解

    TensorFlow里面的padding只有两个选项也就是valid和same pytorch里面的padding么有这两个选项,它是数字0,1,2,3等等,默认是0 所以输出的h和w的计算方式也是稍微有一点点不同的:tf中的输出大小是和原来的大小成倍数关系,不能任意的输出大小:而nn输出大小可以通过padding进行改变 nn里面的卷积操作或者是池化操作的H和W部分都是一样的计算公式:H和W的计算 class torch.nn.MaxPool2d(kernel_size, stride=Non

  • pytorch AvgPool2d函数使用详解

    我就废话不多说了,直接上代码吧! import torch import torch.nn as nn import torch.nn.functional as F from torch.autograd import Variable import numpy as np input = Variable(torch.Tensor([[[1, 3, 3, 4, 5, 6, 7], [1, 2, 3, 4, 5, 6, 7]], [[1, 3, 3, 4, 5, 6, 7], [1, 2, 3

  • Pytorch实现各种2d卷积示例

    普通卷积 使用nn.Conv2d(),一般还会接上BN和ReLu 参数量NNCin*Cout+Cout(如果有bias,相对来说表示对参数量影响很小,所以后面不考虑) class ConvBNReLU(nn.Module): def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): super(ConvBNReLU, self).__init__() self.op = nn.Sequential( n

  • Pytorch 的损失函数Loss function使用详解

    1.损失函数 损失函数,又叫目标函数,是编译一个神经网络模型必须的两个要素之一.另一个必不可少的要素是优化器. 损失函数是指用于计算标签值和预测值之间差异的函数,在机器学习过程中,有多种损失函数可供选择,典型的有距离向量,绝对值向量等. 损失Loss必须是标量,因为向量无法比较大小(向量本身需要通过范数等标量来比较). 损失函数一般分为4种,平方损失函数,对数损失函数,HingeLoss 0-1 损失函数,绝对值损失函数. 我们先定义两个二维数组,然后用不同的损失函数计算其损失值. import

  • PyTorch中permute的用法详解

    permute(dims) 将tensor的维度换位. 参数:参数是一系列的整数,代表原来张量的维度.比如三维就有0,1,2这些dimension. 例: import torch import numpy as np a=np.array([[[1,2,3],[4,5,6]]]) unpermuted=torch.tensor(a) print(unpermuted.size()) # --> torch.Size([1, 2, 3]) permuted=unpermuted.permute(

  • Pytorch 中retain_graph的用法详解

    用法分析 在查看SRGAN源码时有如下损失函数,其中设置了retain_graph=True,其作用是什么? ############################ # (1) Update D network: maximize D(x)-1-D(G(z)) ########################### real_img = Variable(target) if torch.cuda.is_available(): real_img = real_img.cuda() z = V

  • pytorch中index_select()的用法详解

    pytorch中index_select()的用法 index_select(input, dim, index) 功能:在指定的维度dim上选取数据,不如选取某些行,列 参数介绍 第一个参数input是要索引查找的对象 第二个参数dim是要查找的维度,因为通常情况下我们使用的都是二维张量,所以可以简单的记忆: 0代表行,1代表列 第三个参数index是你要索引的序列,它是一个tensor对象 刚开始学习pytorch,遇到了index_select(),一开始不太明白几个参数的意思,后来查了一

  • pytorch中的dataset用法详解

    目录 1.torch.utils.data 里面的dataset使用方法 2.torchvision.datasets的使用方法 用法1:使用官方数据集 用法2:ImageFolder通用的自己数据集加载器 1.torch.utils.data 里面的dataset使用方法 当我们继承了一个 Dataset类之后,我们需要重写 len 方法,该方法提供了dataset的大小: getitem 方法, 该方法支持从 0 到 len(self)的索引 from torch.utils.data im

  • pytorch中permute()函数用法实例详解

    目录 前言 三维情况 变化一:不改变任何参数 变化二:1与2交换 变化三:0与1交换 变化四:0与2交换 变化五:0与1交换,1与2交换 变化六:0与1交换,0与2交换 总结 前言 本文只讨论二维三维中的permute用法 最近的Attention学习中的一个permute函数让我不理解 这个光说太抽象 我就结合代码与图片解释一下 首先创建一个三维数组小实例 import torch x = torch.linspace(1, 30, steps=30).view(3,2,5) # 设置一个三维

  • Pytorch mask_select 函数的用法详解

    非常简单的函数,但是官网的介绍令人(令我)迷惑,所以稍加解释. mask_select会将满足mask(掩码.遮罩等等,随便翻译)的指示,将满足条件的点选出来. 根据掩码张量mask中的二元值,取输入张量中的指定项( mask为一个 ByteTensor),将取值返回到一个新的1D张量, 张量 mask须跟input张量有相同数量的元素数目,但形状或维度不需要相同 x = torch.randn(3, 4) x 1.2045 2.4084 0.4001 1.1372 0.5596 1.5677

  • 基于C++中setiosflags()的用法详解

    cout<<setiosflags(ios::fixed)<<setiosflags(ios::right)<<setprecision(2); setiosflags 是包含在命名空间iomanip 中的C++ 操作符,该操作符的作用是执行由有参数指定区域内的动作:   iso::fixed 是操作符setiosflags 的参数之一,该参数指定的动作是以带小数点的形式表示浮点数,并且在允许的精度范围内尽可能的把数字移向小数点右侧:   iso::right 也是se

  • Angular 中 select指令用法详解

    最近在angular中使用select指令时,出现了很多问题,搞得很郁闷.查看了很多资料后,发现select指令并不简单,决定总结一下. select用法: <select ng-model="" [name=""] [required=""] [ng-required=""] [ng-options=""]> </select> 属性说明: 发现并没有ng-change属性 ng-

  • java 中 ChannelHandler的用法详解

    java 中 ChannelHandler的用法详解 前言: ChannelHandler处理一个I/O event或者拦截一个I/O操作,在它的ChannelPipeline中将其递交给相邻的下一个handler. 通过继承ChannelHandlerAdapter来代替 因为这个接口有许多的方法需要实现,你或许希望通过继承ChannelHandlerAdapter来代替. context对象 一个ChannelHandler和一个ChannelHandlerContext对象一起被提供.一个

  • Java中isAssignableFrom的用法详解

    class1.isAssignableFrom(class2) 判定此 Class 对象所表示的类或接口与指定的 Class 参数所表示的类或接口是否相同,或是否是其超类或超接口.如果是则返回 true:否则返回 false.如果该 Class 表示一个基本类型,且指定的 Class 参数正是该 Class 对象,则该方法返回 true:否则返回 false. 1. class2是不是class1的子类或者子接口 2. Object是所有类的父类 一个例子搞定: package com.auuz

随机推荐