Pytorch中torch.flatten()和torch.nn.Flatten()实例详解

torch.flatten(x)等于torch.flatten(x,0)默认将张量拉成一维的向量,也就是说从第一维开始平坦化,torch.flatten(x,1)代表从第二维开始平坦化。

import torch
x=torch.randn(2,4,2)
print(x)

z=torch.flatten(x)
print(z)

w=torch.flatten(x,1)
print(w)

输出为:
tensor([[[-0.9814,  0.8251],
         [ 0.8197, -1.0426],
         [-0.8185, -1.3367],
         [-0.6293,  0.6714]],

        [[-0.5973, -0.0944],
         [ 0.3720,  0.0672],
         [ 0.2681,  1.8025],
         [-0.0606,  0.4855]]])

tensor([-0.9814,  0.8251,  0.8197, -1.0426, -0.8185, -1.3367, -0.6293,  0.6714,
        -0.5973, -0.0944,  0.3720,  0.0672,  0.2681,  1.8025, -0.0606,  0.4855])

tensor([[-0.9814,  0.8251,  0.8197, -1.0426, -0.8185, -1.3367, -0.6293,  0.6714]
,
        [-0.5973, -0.0944,  0.3720,  0.0672,  0.2681,  1.8025, -0.0606,  0.4855]
])

torch.flatten(x,0,1)代表在第一维和第二维之间平坦化

import torch
x=torch.randn(2,4,2)
print(x)

w=torch.flatten(x,0,1) #第一维长度2,第二维长度为4,平坦化后长度为2*4
print(w.shape)

print(w)

输出为:
tensor([[[-0.5523, -0.1132],
         [-2.2659, -0.0316],
         [ 0.1372, -0.8486],
         [-0.3593, -0.2622]],

        [[-0.9130,  1.0038],
         [-0.3996,  0.4934],
         [ 1.7269,  0.8215],
         [ 0.1207, -0.9590]]])

torch.Size([8, 2])

tensor([[-0.5523, -0.1132],
        [-2.2659, -0.0316],
        [ 0.1372, -0.8486],
        [-0.3593, -0.2622],
        [-0.9130,  1.0038],
        [-0.3996,  0.4934],
        [ 1.7269,  0.8215],
        [ 0.1207, -0.9590]])

对于torch.nn.Flatten(),因为其被用在神经网络中,输入为一批数据,第一维为batch,通常要把一个数据拉成一维,而不是将一批数据拉为一维。所以torch.nn.Flatten()默认从第二维开始平坦化。

import torch
#随机32个通道为1的5*5的图
x=torch.randn(32,1,5,5)

model=torch.nn.Sequential(
    #输入通道为1,输出通道为6,3*3的卷积核,步长为1,padding=1
    torch.nn.Conv2d(1,6,3,1,1),
    torch.nn.Flatten()
)
output=model(x)
print(output.shape)  # 6*(7-3+1)*(7-3+1)

输出为:

torch.Size([32, 150])

总结

到此这篇关于Pytorch中torch.flatten()和torch.nn.Flatten()的文章就介绍到这了,更多相关Pytorch torch.flatten()和torch.nn.Flatten()内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • Python torch.flatten()函数案例详解

    先看函数参数: torch.flatten(input, start_dim=0, end_dim=-1) input: 一个 tensor,即要被"推平"的 tensor. start_dim: "推平"的起始维度. end_dim: "推平"的结束维度. 首先如果按照 start_dim 和 end_dim 的默认值,那么这个函数会把 input 推平成一个 shape 为 [n][n] 的tensor,其中 nn 即 input 中元素个数

  • PyTorch中torch.nn.Linear实例详解

    目录 前言 1. nn.Linear的原理: 2. nn.Linear的使用: 3. nn.Linear的源码定义: 补充:许多细节需要声明 总结 前言 在学习transformer时,遇到过非常频繁的nn.Linear()函数,这里对nn.Linear进行一个详解.参考:https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html 1. nn.Linear的原理: 从名称就可以看出来,nn.Linear表示的是线性变

  • pytorch中的卷积和池化计算方式详解

    TensorFlow里面的padding只有两个选项也就是valid和same pytorch里面的padding么有这两个选项,它是数字0,1,2,3等等,默认是0 所以输出的h和w的计算方式也是稍微有一点点不同的:tf中的输出大小是和原来的大小成倍数关系,不能任意的输出大小:而nn输出大小可以通过padding进行改变 nn里面的卷积操作或者是池化操作的H和W部分都是一样的计算公式:H和W的计算 class torch.nn.MaxPool2d(kernel_size, stride=Non

  • 关于pytorch中全连接神经网络搭建两种模式详解

    pytorch搭建神经网络是很简单明了的,这里介绍两种自己常用的搭建模式: import torch import torch.nn as nn first: class NN(nn.Module): def __init__(self): super(NN,self).__init__() self.model=nn.Sequential( nn.Linear(30,40), nn.ReLU(), nn.Linear(40,60), nn.Tanh(), nn.Linear(60,10), n

  • PyTorch中 tensor.detach() 和 tensor.data 的区别详解

    PyTorch0.4中,.data 仍保留,但建议使用 .detach(), 区别在于 .data 返回和 x 的相同数据 tensor, 但不会加入到x的计算历史里,且require s_grad = False, 这样有些时候是不安全的, 因为 x.data 不能被 autograd 追踪求微分 . .detach() 返回相同数据的 tensor ,且 requires_grad=False ,但能通过 in-place 操作报告给 autograd 在进行反向传播的时候. 举例: ten

  • C++中回调函数及函数指针的实例详解

    C++中回调函数及函数指针的实例详解 如何获取到类中函数指针 实现代码: //A类与B类的定义 class A { public: void Test() { cout << "A::Test()" << endl; } }; class B : public A { public: void Test() { cout << "B::Test()" << endl; } }; //定义类的成员函数指针 typedef

  • thinkphp中的多表关联查询的实例详解

    thinkphp中的多表关联查询的实例详解 在进行后端管理系统的编程的时候一般会使用框架来进行页面的快速搭建,我最近使用比较多的就是thinkphp框架,thinkphp框架的应用其实就是把前端和后端进行分割管理,前端用户登录查询系统放在thinkphp中的home文件夹中进行管理,后端管理系统放在thinkphp中的admin文件夹中进行管理.对了,在使用thinkphp框架的时候是是要用到mvc架构的,mvc架构就是model(数据模型).view(视图).controller(控制器)的结

  • iOS开发中以application/json上传文件实例详解

    本文通过实例代码给大家讲解iOS中以application/json上传文件的形式,具体内容详情大家参考下本文. 在和sever后台交互的过程中.有时候.他们需要我们iOS开发者以"application/json"形式上传. NSString *accessUrl = [NSString stringWithFormat:@"%@/xxx",@"https://www.xxxxx.com:xxxx"]; NSMutableURLRequest

  • oracle中创建序列及序列补零实例详解

    oracle中创建序列及序列补零实例详解 我们经常会在在DB中创建序列: -- Create sequence create sequence COMMON_SEQ minvalue 1 maxvalue 999999999 start with 1 increment by 1 cache 20 cycle; 我们的序列的最小值是从1开始,但是我们想让这种顺序取出来的序列的位数都一样,按照最大数的位数来算,我们需要8位的序列,那么我们就需要在1的前面补上7个零,只需要用下面的方法即可完成 se

  • JSP 中Spring组合注解与元注解实例详解

    JSP 中Spring组合注解与元注解实例详解 摘要: 注解(Annotation),也叫元数据.一种代码级别的说明.它与类.接口.枚举是在同一个层次.它可以声明在包.类.字段.方法.局部变量.方法参数等的前面,用来对这些元素进行说明 1. 可以注解到别的注解上的注解称为元注解,被注解的注解称为组合注解,通过组合注解可以很好的简化好多重复性的注解操作 2. 示例组合注解 import org.springframework.context.annotation.ComponentScan; im

  • C++ 中字符串操作--宽窄字符转换的实例详解

    C++ 中字符串操作--宽窄字符转换的实例详解 MultiByteToWideChar int MultiByteToWideChar( _In_ UINT CodePage, _In_ DWORD dwFlags, _In_ LPCSTR lpMultiByteStr, _In_ int cbMultiByte, _Out_opt_ LPWSTR lpWideCharStr, _In_ int cchWideChar ); 参数描述: CodePage:常用CP_ACP.CP_UTF8 dwF

随机推荐