PyTorch中torch.tensor与torch.Tensor的区别详解

PyTorch最近几年可谓大火。相比于TensorFlow,PyTorch对于Python初学者更为友好,更易上手。

众所周知,numpy作为Python中数据分析的专业第三方库,比Python自带的Math库速度更快。同样的,在PyTorch中,有一个类似于numpy的库,称为Tensor。Tensor自称为神经网络界的numpy。

一、numpy和Tensor二者对比

对比项 numpy Tensor
相同点 可以定义多维数组,进行切片、改变维度、数学运算等 可以定义多维数组,进行切片、改变维度、数学运算等
不同点
1、产生的数组类型为numpy.ndarray;

2、会将ndarray放入CPU中进行运算;

3、导入方式为import numpy as np,后续通过np.array([1,2])建立数组;

4、numpy中没有x.type()的用法,只能使用type(x)。


1、产生的数组类型为torch.Tensor;

2、会将tensor放入GPU中进行加速运算(如果有GPU);

3、导入方式为import torch,后续通过torch.tensor([1,2])或torch.Tensor([1,2])建立数组;

4、Tensor中查看数组类型既可以使用type(x),也可以使用x.type()。但是更加推荐采用x.type(),具体原因详见下文。

举例(以下代码均在Jupyter Notebook上运行且通过):

numpy:

import numpy as np
x = np.array([1,2])
#之所以这么写,是为了告诉大家,在Jupyter Notebook中,是否带有print()函数打印出来的效果是不一样的~
x       #array([1, 2])
print(x)     #[1 2]
type(x)     #numpy.ndarray
print(type(x))   #<class 'numpy.ndarray'>
#注意:numpy中没有x.type()的用法,只能使用type(x)!!!

Tensor:

import torch    #注意,这里是import torch,不是import Tensor!!!
x = torch.tensor([1,2])
x       #tensor([1, 2])
print(x)     #tensor([1, 2]),注意,这里与numpy就不一样了!

type(x)     #torch.Tensor
print(type(x))    #<class 'torch.Tensor'>
x.type()     #'torch.LongTensor',注意:numpy中不可以这么写,会报错!!!
print(x.type())   #torch.LongTensor,注意:numpy中不可以这么写,会报错!!!

numpy与Tensor在使用上还有其他差别。由于不是本文的重点,故暂不详述。后续可能会更新~

二、torch.tensor与torch.Tensor的区别

细心的读者可能注意到了,通过Tensor建立数组有torch.tensor([1,2])或torch.Tensor([1,2])两种方式。那么,这两种方式有什么区别呢?

(1)torch.tensor是从数据中推断数据类型,而torch.Tensor是torch.empty(会随机产生垃圾数组,详见实例)和torch.tensor之间的一种混合。但是,当传入数据时,torch.Tensor使用全局默认dtype(FloatTensor);

(2)torch.tensor(1)返回一个固定值1,而torch.Tensor(1)返回一个大小为1的张量,它是初始化的随机值。

import torch    #注意,这里是import torch,不是import Tensor!!!

x = torch.tensor([1,2])

x       #tensor([1, 2])
print(x)     #tensor([1, 2]),注意,这里与numpy就不一样了!
type(x)     #torch.Tensor
print(type(x))    #<class 'torch.Tensor'>
x.type()     #'torch.LongTensor',注意:numpy中不可以这么写,会报错!!!
print(x.type())   #torch.LongTensor,注意:numpy中不可以这么写,会报错!!!

y = torch.Tensor([1,2])

y       #tensor([1., 2.]),因为torch.Tensor使用全局默认dtype(FloatTensor)
print(y)     #tensor([1., 2.]),因为torch.Tensor使用全局默认dtype(FloatTensor)
type(y)     #torch.Tensor
print(type(y))    #<class 'torch.Tensor'>
y.type()     #'torch.FloatTensor',注意:这里就与上面不一样了!tensor->LongTensor,Tensor->FloatTensor!!!
print(y.type())   #torch.FloatTensor,注意:这里就与上面不一样了!tensor->LongTensor,Tensor->FloatTensor!!!

z = torch.empty([1,2]) 

z       #随机运行两次,结果不同:tensor([[0., 0.]]),tensor([[1.4013e-45, 0.0000e+00]])
print(z)     #随机运行两次,结果不同:tensor([[0., 0.]]),tensor([[1.4013e-45, 0.0000e+00]])
type(z)     #torch.Tensor
print(type(z))    #<class 'torch.Tensor'>
z.type()     #'torch.FloatTensor',注意:empty()默认为torch.FloatTensor而不是torch.LongTensor
print(z.type())   #torch.FloatTensor,注意:empty()默认为torch.FloatTensor而不是torch.LongTensor

#torch.tensor(1)、torch.Tensor(1)和torch.empty(1)的对比:
t1 = torch.tensor(1)
t2 = torch.Tensor(1)
t3 = torch.empty(1)

t1       #tensor(1)
print(t1)     #tensor(1)
type(t1)     #torch.Tensor
print(type(t1))   #<class 'torch.Tensor'>
t1.type()     #'torch.LongTenso'
print(t1.type())   #torch.LongTensor

t2       #随机运行两次,结果不同:tensor([2.8026e-45]),tensor([0.])
print(t2)     #随机运行两次,结果不同:tensor([2.8026e-45]),tensor([0.])
type(t2)     #torch.Tensor
print(type(t2))   #<class 'torch.Tensor'>
t2.type()     #'torch.FloatTensor'
print(t2.type())   #torch.FloatTensor

t3       #随机运行两次,结果不同:tensor([0.]),tensor([1.4013e-45])
print(t3)     #随机运行两次,结果不同:tensor([0.]),tensor([1.4013e-45])
type(t3)     #torch.Tensor
print(type(t3))   #<class 'torch.Tensor'>
t3.type()     #'torch.FloatTensor'
print(t3.type())   #torch.FloatTensor

上文提到过,对于Tensor,更推荐采用x.type()来查看数据类型。是因为x.type()的输出结果为'torch.LongTensor'或'torch.FloatTensor',可以看出两个数组的种类区别。而采用type(x),则清一色的输出结果都是torch.Tensor,无法体现类型区别。

PyTorch是个神奇的工具,其中的Tensor用法要远比numpy丰富。大家可以在练习中多多总结,逐渐提高~

到此这篇关于PyTorch中torch.tensor与torch.Tensor的区别详解的文章就介绍到这了,更多相关PyTorch中torch.tensor与torch.Tensor内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • pytorch查看torch.Tensor和model是否在CUDA上的实例

    今天训练faster R-CNN时,发现之前跑的很好的程序(是指在运行程序过程中,显卡利用率能够一直维持在70%以上),今天看的时候,显卡利用率很低,所以在想是不是我的训练数据torch.Tensor或者模型model没有加载到GPU上训练,于是查找如何查看tensor和model所在设备的命令. import torch import torchvision.models as models model=models.vgg11(pretrained=False) print(next(mod

  • pytorch中的nn.ZeroPad2d()零填充函数实例详解

    在卷积神经网络中,有使用设置padding的参数,配合卷积步长,可以使得卷积后的特征图尺寸大小不发生改变,那么在手动实现图片或特征图的边界零填充时,常用的函数是nn.ZeroPad2d(),可以指定tensor的四个方向上的填充,比如左边添加1dim.右边添加2dim.上边添加3dim.下边添加4dim,即指定paddin参数为(1,2,3,4),本文中代码设置的是(3,4,5,6)如下: import torch.nn as nn import cv2 import torchvision f

  • pytorch中的卷积和池化计算方式详解

    TensorFlow里面的padding只有两个选项也就是valid和same pytorch里面的padding么有这两个选项,它是数字0,1,2,3等等,默认是0 所以输出的h和w的计算方式也是稍微有一点点不同的:tf中的输出大小是和原来的大小成倍数关系,不能任意的输出大小:而nn输出大小可以通过padding进行改变 nn里面的卷积操作或者是池化操作的H和W部分都是一样的计算公式:H和W的计算 class torch.nn.MaxPool2d(kernel_size, stride=Non

  • 关于pytorch中全连接神经网络搭建两种模式详解

    pytorch搭建神经网络是很简单明了的,这里介绍两种自己常用的搭建模式: import torch import torch.nn as nn first: class NN(nn.Module): def __init__(self): super(NN,self).__init__() self.model=nn.Sequential( nn.Linear(30,40), nn.ReLU(), nn.Linear(40,60), nn.Tanh(), nn.Linear(60,10), n

  • pytorch:model.train和model.eval用法及区别详解

    使用PyTorch进行训练和测试时一定注意要把实例化的model指定train/eval,eval()时,框架会自动把BN和DropOut固定住,不会取平均,而是用训练好的值,不然的话,一旦test的batch_size过小,很容易就会被BN层导致生成图片颜色失真极大!!!!!! Class Inpaint_Network() ...... Model = Inpaint_Nerwoek() #train: Model.train(mode=True) ..... #test: Model.ev

  • MySQL中Decimal类型和Float Double的区别(详解)

    MySQL中存在float,double等非标准数据类型,也有decimal这种标准数据类型. 其区别在于,float,double等非标准类型,在DB中保存的是近似值,而Decimal则以字符串的形式保存数值. float,double类型是可以存浮点数(即小数类型),但是float有个坏处,当你给定的数据是整数的时候,那么它就以整数给你处理.这样我们在存取货币值的时候自然遇到问题,我的default值为:0.00而实际存储是0,同样我存取货币为12.00,实际存储是12. 幸好mysql提供

  • java 中同步方法和同步代码块的区别详解

    java 中同步方法和同步代码块的区别详解 在Java语言中,每一个对象有一把锁.线程可以使用synchronized关键字来获取对象上的锁.synchronized关键字可应用在方法级别(粗粒度锁)或者是代码块级别(细粒度锁). 问题的由来: 看到这样一个面试题: //下列两个方法有什么区别 public synchronized void method1(){} public void method2(){ synchronized (obj){} } synchronized用于解决同步问

  • include包含头文件的语句中,双引号和尖括号的区别(详解)

    #include <>格式:引用标准库头文件,编译器从标准库目录开始搜索 #incluce ""格式:引用非标准库的头文件,编译器从用户的工作目录开始搜索 预处理器发现 #include 指令后,就会寻找后跟的文件名并把这个文件的内容包含到当前文件中.被包含文件中的文本将替换源代码文件中的#include指令,就像你把被包含文件中的全部内容键入到源文件中的这个位置一样. #include 指令有两种使用形式 #include <stdio.h> 文件名放在尖括号

  • js删除数组中的元素delete和splice的区别详解

    例如有一个数组是 :var textArr = ['a','b','c','d']; 这时我想删除这个数组中的b元素: 方法一:delete 删除数组 delete textArr[1]  结果为: ["a",undefined,"c","d"] 只是被删除的元素变成了 undefined 其他的元素的键值还是不变. 方法二:aplice 删除数组 splice(index,len,[item]) 注释:该方法会改变原始数组. index:数组开

  • 对Django 中request.get和request.post的区别详解

    Django 中request.get和request.post的区别 POST和GET差异: POST和GET是HTTP协议定义的与服务器交互的方法.GET一般用于获取/查询资源信息,而POST一般用于更新资源信息.另外,还有PUT和DELETE方法. POST和GET都可以与服务器完成查,改,增,删操作. GET提交,请求的数据会附在URL之后,以?分割URL和传输数据,多个参数用&连接: POST提交,把提交的数据放置在HTTP包的包体中:因此,GET提交的数据会在地址栏中显示出来,而PO

  • python中的数组赋值与拷贝的区别详解

    具体的注解我已经写在了程序里面:通俗的解释了python里面的浅拷贝与深拷贝的不同,请看程序. # -*- coding: utf-8 -*- import numpy as np import copy as cp import matplotlib.pyplot as plt import time import math fig = plt.figure() ax = fig.add_subplot(241) # 定义一个多维数组 x = np.array([[1, 2, 3], [4,

随机推荐