PyTorch中permute的基本用法示例
目录
- permute(dims)
- 附:permute(多维数组,[维数的组合])
- 总结
permute(dims)
将tensor的维度换位。
参数:参数是一系列的整数,代表原来张量的维度。比如三维就有0,1,2这些dimension。
例:
import torch import numpy as np a=np.array([[[1,2,3],[4,5,6]]]) unpermuted=torch.tensor(a) print(unpermuted.size()) # ——> torch.Size([1, 2, 3]) permuted=unpermuted.permute(2,0,1) print(permuted.size()) # ——> torch.Size([3, 1, 2])
再比如图片img的size比如是(28,28,3)就可以利用img.permute(2,0,1)得到一个size为(3,28,28)的tensor。
利用这个函数permute(0,2,1)可以把Tensor([[[1,2,3],[4,5,6]]]) 转换成
tensor([[[1., 4.], [2., 5.], [3., 6.]]])
如果使用view,可以得到
tensor([[[1., 2.], [3., 4.], [5., 6.]]])
关于view的用法:参见PyTorch中view的用法
附:permute(多维数组,[维数的组合])
比如:
a=rand(2,3,4); %这是一个三维数组,各维的长度分别为:2,3,4
%现在交换第一维和第二维:
permute(A,[2,1,3]) %变成3*2*4的矩阵
import torch import numpy as np a=np.array([[[1,2,3],[4,5,6]]]) unpermuted=torch.tensor(a) print(unpermuted.size()) # ——> torch.Size([1, 2, 3]) tensor([[[1., 4.], [2., 5.], [3., 6.]]]) permuted=unpermuted.permute(2,0,1) print(permuted.size()) # ——> torch.Size([3, 1, 2]) tensor([[[1., 2.], [3., 4.], [5., 6.]]])
总结
到此这篇关于PyTorch中permute的基本用法的文章就介绍到这了,更多相关PyTorch permute的用法内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!
赞 (0)