Python torch.flatten()函数案例详解

先看函数参数:

torch.flatten(input, start_dim=0, end_dim=-1)

input: 一个 tensor,即要被“推平”的 tensor。

start_dim: “推平”的起始维度。

end_dim: “推平”的结束维度。

首先如果按照 start_dim 和 end_dim 的默认值,那么这个函数会把 input 推平成一个 shape 为 [n][n] 的tensor,其中 nn 即 input 中元素个数。

如果我们要自己设定起始维度和结束维度呢?

我们要先来看一下 tensor 中的 shape 是怎么样的:

t = torch.tensor([[[1, 2, 2, 1],
                   [3, 4, 4, 3],
                   [1, 2, 3, 4]],
                  [[5, 6, 6, 5],
                   [7, 8, 8, 7],
                   [5, 6, 7, 8]]])
print(t, t.shape)

运行结果:

tensor([[[1, 2, 2, 1],
         [3, 4, 4, 3],
         [1, 2, 3, 4]],

        [[5, 6, 6, 5],
         [7, 8, 8, 7],
         [5, 6, 7, 8]]])
torch.Size([2, 3, 4])

我们可以看到,最外层的方括号内含两个元素,因此 shape 的第一个值是 2;类似地,第二层方括号里面含三个元素,shape 的第二个值就是 3;最内层方括号里含四个元素,shape 的第二个值就是 4。

示例代码:

x = torch.flatten(t, start_dim=1)
print(x, x.shape)

y = torch.flatten(t, start_dim=0, end_dim=1)
print(y, y.shape)

运行结果:

tensor([[1, 2, 2, 1, 3, 4, 4, 3, 1, 2, 3, 4],
        [5, 6, 6, 5, 7, 8, 8, 7, 5, 6, 7, 8]])
torch.Size([2, 12])

tensor([[1, 2, 2, 1],
        [3, 4, 4, 3],
        [1, 2, 3, 4],
        [5, 6, 6, 5],
        [7, 8, 8, 7],
        [5, 6, 7, 8]])
torch.Size([6, 4])

可以看到,当 start_dim = 11 而 end_dim = −1−1 时,它把第 11 个维度到最后一个维度全部推平合并了。而当 start_dim = 00 而 end_dim = 11 时,它把第 00 个维度到第 11 个维度全部推平合并了。pytorch中的 torch.nn.Flatten 类和 torch.Tensor.flatten 方法其实都是基于上面的 torch.flatten 函数实现的。

到此这篇关于Python torch.flatten()函数案例详解的文章就介绍到这了,更多相关Python torch.flatten()函数内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • python PyTorch参数初始化和Finetune

    前言 这篇文章算是论坛PyTorch Forums关于参数初始化和finetune的总结,也是我在写代码中用的算是"最佳实践"吧.最后希望大家没事多逛逛论坛,有很多高质量的回答. 参数初始化 参数的初始化其实就是对参数赋值.而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了data,grad等借口,这就意味着我们可以直接对这些参数进行操作赋值了.这就是PyTorch简洁高效所在. 所以我们可以进行如下操作进行初始化,当然其实有其他的方法,但是这种方法

  • 基于python及pytorch中乘法的使用详解

    numpy中的乘法 A = np.array([[1, 2, 3], [2, 3, 4]]) B = np.array([[1, 0, 1], [2, 1, -1]]) C = np.array([[1, 0], [0, 1], [-1, 0]]) A * B : # 对应位置相乘 np.array([[ 1, 0, 3], [ 4, 3, -4]]) A.dot(B) : # 矩阵乘法 ValueError: shapes (2,3) and (2,3) not aligned: 3 (dim

  • python 如何查看pytorch版本

    看代码吧~ import torch print(torch.__version__) 补充:pytorch不同版本安装以及版本查看 一:基于conda安装 conda create --name pytorch_learn python=3.6.7#创建一个名为pytorch_learn的环境 source activate pytorch_learn #进入环境 conda install pytorch=0.3.1 cuda80 -c soumith #安装pytorch0.3.1+ cu

  • 浅谈pytorch、cuda、python的版本对齐问题

    在使用深度学习模型训练的过程中,工具的准备也算是一个良好的开端吧.熟话说完事开头难,磨刀不误砍柴工,先把前期的问题搞通了,能为后期节省不少精力. 以pytorch工具为例: pytorch版本为1.0.1,自带python版本为3.6.2 服务器上GPU的CUDA_VERSION=9000 注意:由于GPU上的CUDA_VERSION为9000,所以至少要安装cuda版本>=9.0,虽然cuda=7.0~8.0也能跑,但是一开始可能会遇到各种各样的问题,本人cuda版本为10.0,安装cuda的

  • 简述python&pytorch 随机种子的实现

    随机数广泛应用在科学研究, 但是计算机无法产生真正的随机数, 一般成为伪随机数. 它的产生过程: 给定一个随机种子(一个正整数), 根据随机算法和种子产生随机序列. 给定相同的随机种子, 计算机产生的随机数列是一样的(这也许是伪随机的原因). 随机种子是什么? 随机种子是针对随机方法而言的. 随机方法:常见的随机方法有 生成随机数,以及其他的像 随机排序 之类的,后者本质上也是基于生成随机数来实现的.在深度学习中,比较常用的随机方法的应用有:网络的随机初始化,训练集的随机打乱等. 随机种子的取值

  • Python torch.flatten()函数案例详解

    先看函数参数: torch.flatten(input, start_dim=0, end_dim=-1) input: 一个 tensor,即要被"推平"的 tensor. start_dim: "推平"的起始维度. end_dim: "推平"的结束维度. 首先如果按照 start_dim 和 end_dim 的默认值,那么这个函数会把 input 推平成一个 shape 为 [n][n] 的tensor,其中 nn 即 input 中元素个数

  • Python之基础函数案例详解

    函数就是把具有独立功能的代码块封装成一个小模块,可以直接调用,从而提高代码的编写效率以及重用性, 需要注意的是, 函数需要被调用才会执行, 而调用函数需要根据函数名调用  函数的定义格式: def 函数名(): 函数代码 使用当前文件的函数 我们直接定义一个函数然后运行程序, 函数并不会被调用 def hello(): print('hello') 想要函数被执行, 需要使用函数名来调用函数 # 定义函数 def hello(): print('hello') # 调用函数 hello()  需

  • Python ord函数()案例详解

    python中ord函数 Python ord()函数 (Python ord() function) ord() function is a library function in Python, it is used to get number value from given character value, it accepts a character and returns an integer i.e. it is used to convert a character to an

  • Python字典中items()函数案例详解

    Python3:字典中的items()函数 一.Python2.x中items():   和之前一样,本渣渣先贴出来python中help的帮助信息: >>> help(dict.items) Help on method_descriptor: items(...) D.items() -> list of D's (key, value) pairs, as 2-tuples >>> help(dict.iteritems) Help on method_de

  • Python中return用法案例详解

    python中return的用法 1.return语句就是把执行结果返回到调用的地方,并把程序的控制权一起返回 程序运行到所遇到的第一个return即返回(退出def块),不会再运行第二个return. 例如: def haha(x,y): if x==y: return x,y print(haha(1,1)) 已改正: 结果:这种return传参会返回元组(1, 1) 2.但是也并不意味着一个函数体中只能有一个return 语句,例如: def test_return(x): if x >

  • Python 概率生成问题案例详解

    概率生成问题 有一枚不均匀的硬币,要求产生均匀的概率分布 有一枚均匀的硬币,要求产生不均匀的概率分布,如 0.25 和 0.75 利用 Rand7() 实现 Rand10() 不均匀硬币 产生等概率 现有一枚不均匀的硬币 coin(),能够返回 0.1 两个值,其概率分别为 0.6.0.4.要求使用这枚硬币,产生均匀的概率分布.即编写一个函数 coin_new() 使得它返回 0.1 的概率均为 0.5. # 不均匀硬币,返回 0.1 的概率分别为 0.6.0.4 def coin(): ret

  • Python 实现静态链表案例详解

    静态链表和动态链表区别 静态链表和动态链表的共同点是,数据之间"一对一"的逻辑关系都是依靠指针(静态链表中称"游标")来维持. 静态链表 使用静态链表存储数据,需要预先申请足够大的一整块内存空间,也就是说,静态链表存储数据元素的个数从其创建的那一刻就已经确定,后期无法更改. 不仅如此,静态链表是在固定大小的存储空间内随机存储各个数据元素,这就造成了静态链表中需要使用另一条链表(通常称为"备用链表")来记录空间存储空间的位置,以便后期分配给新添加元

  • Python之re模块案例详解

    一.正则表达式   re模块是python独有的匹配字符串的模块,该模块中提供的很多功能是基于正则表达式实现的,而正则表达式是对字符串进行模糊匹配,提取自己需要的字符串部分,他对所有的语言都通用.注意: re模块是python独有的 正则表达式所有编程语言都可以使用 re模块.正则表达式是对字符串进行操作 因为,re模块中的方法大都借助于正则表达式,故先学习正则表达式. (一)常用正则  1.字符组 在同一个位置可能出现的各种字符组成了一个字符组,在正则表达式中用[]表示 正则 待匹配字符 匹配

  • Python自动化办公实战案例详解(Word、Excel、Pdf、Email邮件)

    目录 背景 实现过程 1)替换Word模板生成对应邀请函 2)将Word邀请函转化为Pdf格式 4)自动发送邮件 5)完整代码 总结 背景 想象一下,现在你有一份Word邀请函模板,然后你有一份客户列表,上面有客户的姓名.联系方式.邮箱等基本信息,然后你的老板现在需要替换邀请函模板中的姓名,然后将Word邀请函模板生成Pdf格式,之后编辑统一的邀请话术(邮件正文),再依次发送邀请函附件到客户邮箱,你会怎么做? 正常情况下,我们肯定是复制粘贴Excel表格中的客户姓名,之后挨个Word文档进行替换

  • Python实现地图可视化案例详解

    目录 ​前言 一.pyecharts Map Geo Bmap 二.folium 结 语 ​前言 Python的地图可视化库很多,Matplotlib库虽然作图很强大,但只能做静态地图.而我今天要讲的是交互式地图库,分别为pyecharts.folium,掌握这两个库,基本可以解决你的地图可视化需求. 一.pyecharts 首先,必须说说强大的pyecharts库,简单易用又酷炫,几乎可以制作任何图表.pyecharts有v0.5和v1两个版本,两者不兼容,最新的v1版本开始支持链式调用,采用

随机推荐