pytorch中nn.Flatten()函数详解及示例

torch.nn.Flatten(start_dim=1, end_dim=- 1)

作用:将连续的维度范围展平为张量。 经常在nn.Sequential()中出现,一般写在某个神经网络模型之后,用于对神经网络模型的输出进行处理,得到tensor类型的数据。

有俩个参数,start_dim和end_dim,分别表示开始的维度和终止的维度,默认值分别是1和-1,其中1表示第一维度,-1表示最后的维度。结合起来看意思就是从第一维度到最后一个维度全部给展平为张量。(注意:数据的维度是从0开始的,也就是存在第0维度,第一维度并不是真正意义上的第一个)

同理,如果我这么写:

self.flat = nn.Flatten(start_dim=2, end_dim=3)

那么意思就是从第二维度开始,到第三维度全部给展平,也就是将2、3两个维度展平。

官网给出的示例:

input = torch.randn(32, 1, 5, 5)
# With default parameters
m = nn.Flatten()
output = m(input)
output.size()
#torch.Size([32, 25])
# With non-default parameters
m = nn.Flatten(0, 2)
output = m(input)
output.size()
#torch.Size([160, 5])

#开头的代码是注释

整段代码的意思是:给定一个维度为(32,1,5,5)的随机数据。

1.先使用一次nn.Flatten(),使用默认参数:

m = nn.Flatten()

也就是说从第一维度展平到最后一个维度,数据的维度是从0开始的,第一维度实际上是数据的第二个位置代表的维度,也就是样例中的1。

因此进行展平后的结果也就是[32,1×5×5][32,25]

2.接着再使用一次指定参数的nn.Flatten(),即

m = nn.Flatten(0, 2)

也就是说从第0维度展平到第2维度,0~2,对应的也就是前三个维度。

因此结果就是[32×1×5,5][160,5]

因此进行展平后的结果也就是[32,1*5*5][32,25]

示例1

卷积公式

import torch
import torch.nn as nn
input = torch.randn(32, 1, 5, 5)
m = nn.Sequential(
    nn.Conv2d(1, 32, 5, 1, 1),  # 通过卷积,得到torch.size([32, 32, 3, 3]
    nn.Flatten())

output = m(input)
print(output.size())

>> torch.Size([32, 288])

示例2

import torch
import torch.nn as nn
input = torch.randn(32, 1, 5, 5)
m = nn.Sequential(
    nn.Conv2d(1, 32, 5, 1, 1),  # 通过卷积,得到torch.size([32, 32, 3, 3]
    nn.Flatten(start_dim=0))

output = m(input)
print(output.size())

>>torch.Size([9216])

总结

到此这篇关于pytorch中nn.Flatten()函数详解的文章就介绍到这了,更多相关pytorch nn.Flatten()函数详解内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • Pytorch中torch.flatten()和torch.nn.Flatten()实例详解

    torch.flatten(x)等于torch.flatten(x,0)默认将张量拉成一维的向量,也就是说从第一维开始平坦化,torch.flatten(x,1)代表从第二维开始平坦化. import torch x=torch.randn(2,4,2) print(x) z=torch.flatten(x) print(z) w=torch.flatten(x,1) print(w) 输出为: tensor([[[-0.9814, 0.8251], [ 0.8197, -1.0426], [-

  • pytorch中nn.Flatten()函数详解及示例

    torch.nn.Flatten(start_dim=1, end_dim=- 1) 作用:将连续的维度范围展平为张量. 经常在nn.Sequential()中出现,一般写在某个神经网络模型之后,用于对神经网络模型的输出进行处理,得到tensor类型的数据. 有俩个参数,start_dim和end_dim,分别表示开始的维度和终止的维度,默认值分别是1和-1,其中1表示第一维度,-1表示最后的维度.结合起来看意思就是从第一维度到最后一个维度全部给展平为张量.(注意:数据的维度是从0开始的,也就是

  • pytorch中的自定义数据处理详解

    pytorch在数据中采用Dataset的数据保存方式,需要继承data.Dataset类,如果需要自己处理数据的话,需要实现两个基本方法. :.getitem:返回一条数据或者一个样本,obj[index] = obj.getitem(index). :.len:返回样本的数量 . len(obj) = obj.len(). Dataset 在data里,调用的时候使用 from torch.utils import data import os from PIL import Image 数

  • PyTorch中permute的用法详解

    permute(dims) 将tensor的维度换位. 参数:参数是一系列的整数,代表原来张量的维度.比如三维就有0,1,2这些dimension. 例: import torch import numpy as np a=np.array([[[1,2,3],[4,5,6]]]) unpermuted=torch.tensor(a) print(unpermuted.size()) # --> torch.Size([1, 2, 3]) permuted=unpermuted.permute(

  • Pytorch 中retain_graph的用法详解

    用法分析 在查看SRGAN源码时有如下损失函数,其中设置了retain_graph=True,其作用是什么? ############################ # (1) Update D network: maximize D(x)-1-D(G(z)) ########################### real_img = Variable(target) if torch.cuda.is_available(): real_img = real_img.cuda() z = V

  • pytorch中index_select()的用法详解

    pytorch中index_select()的用法 index_select(input, dim, index) 功能:在指定的维度dim上选取数据,不如选取某些行,列 参数介绍 第一个参数input是要索引查找的对象 第二个参数dim是要查找的维度,因为通常情况下我们使用的都是二维张量,所以可以简单的记忆: 0代表行,1代表列 第三个参数index是你要索引的序列,它是一个tensor对象 刚开始学习pytorch,遇到了index_select(),一开始不太明白几个参数的意思,后来查了一

  • javascript中Array()数组函数详解

    在程序语言中数组的重要性不言而喻,JavaScript中数组也是最常使用的对象之一,数组是值的有序集合,由于弱类型的原因,JavaScript中数组十分灵活.强大,不像是Java等强类型高级语言数组只能存放同一类型或其子类型元素,JavaScript在同一个数组中可以存放多种类型的元素,而且是长度也是可以动态调整的,可以随着数据增加或减少自动对数组长度做更改. Array()是一个用来构建数组的内建构造器函数.数组主要由如下三种创建方式: array = new Array() array =

  • COM组件中调用JavaScript函数详解及实例

    COM组件中调用JavaScript函数详解及实例 要求是很简单的,即有COM组件A在IE中运行,使用JavaScript(JS)调用A的方法longCalc(),该方法是一个耗时的操作,要求通知IE当前的进度.这就要求使用回调函数,设其名称为scriptCallbackFunc.实现这个技术很简单: 1 .组件方(C++) 组件A 的方法在IDL中定义: [id(2)] HRESULT longCalc([in] DOUBLE v1, [in] DOUBLE v2, [in, optional

  • 对Python3中的input函数详解

    下面介绍python3中的input函数及其在python2及pyhton3中的不同. python3中的ininput函数,首先利用help(input)函数查看函数信息: 以上信息说明input函数在python中是一个内建函数,其从标准输入中读入一个字符串,并自动忽略换行符. 也就是说所有形式的输入按字符串处理,如果想要得到其他类型的数据进行强制类型转化.默认情况下没有 提示字符串(prompt  string),在给定提示字符串下,会在读入标准输入前标准输出提示字符串.如果遇 文件结束符

  • 对TensorFlow中的variables_to_restore函数详解

    variables_to_restore函数,是TensorFlow为滑动平均值提供.之前,也介绍过通过使用滑动平均值可以让神经网络模型更加的健壮.我们也知道,其实在TensorFlow中,变量的滑动平均值都是由影子变量所维护的,如果你想要获取变量的滑动平均值需要获取的是影子变量而不是变量本身. 1.滑动平均值模型文件的保存 import tensorflow as tf if __name__ == "__main__": v = tf.Variable(0.,name="

  • 对Tensorflow中的矩阵运算函数详解

    tf.diag(diagonal,name=None) #生成对角矩阵 import tensorflowas tf; diagonal=[1,1,1,1] with tf.Session() as sess: print(sess.run(tf.diag(diagonal))) #输出的结果为[[1 0 0 0] [0 1 0 0] [0 0 1 0] [0 0 0 1]] tf.diag_part(input,name=None) #功能与tf.diag函数相反,返回对角阵的对角元素 imp

随机推荐