pytorch 带batch的tensor类型图像显示操作

项目场景

pytorch训练时我们一般把数据集放到数据加载器里,然后分批拿出来训练。训练前我们一般还要看一下训练数据长啥样,也就是训练数据集可视化。

那么如何显示dataloader里面带batch的tensor类型的图像呢?

显示图像

绘图最常用的库就是matplotlib:

pip install matplotlib

显示图像会用到matplotlib.pyplot.imshow方法。查阅官方文档可知,该方法接收的图像的通道数要放到后面:

数据加载器中数据的维度是[B, C, H, W],我们每次只拿一个数据出来就是[C, H, W],而matplotlib.pyplot.imshow要求的输入维度是[H, W, C],所以我们需要交换一下数据维度,把通道数放到最后面,这里用到pytorch里面的permute方法(transpose方法也行,不过要交换两次,没这个方便,numpy中的transpose方法倒是可以一次交换完成)

用法示例如下:

>>> x = torch.randn(2, 3, 5)
>>> x.size()
torch.Size([2, 3, 5])
>>> x.permute(1, 2, 0).size()
torch.Size([3, 5, 2])

代码示例

#%% 导入模块
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
#%% 下载数据集
train_file = datasets.MNIST(
    root='./dataset/',
    train=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ]),
    download=True
)
#%% 制作数据加载器
train_loader = DataLoader(
    dataset=train_file,
    batch_size=9,
    shuffle=True
)
#%% 训练数据可视化
images, labels = next(iter(train_loader))
print(images.size())  # torch.Size([9, 1, 28, 28])
plt.figure(figsize=(9, 9))
for i in range(9):
    plt.subplot(3, 3, i+1)
    plt.title(labels[i].item())
    plt.imshow(images[i].permute(1, 2, 0), cmap='gray')
    plt.axis('off')
plt.show()

这里以mnist数据集为例,演示一下显示效果。我这个代码其实还有一点小问题。数据增强的时候我不是进行标准化了嘛,就是在第7行代码:Normalize((0.1307,), (0.3081,))。

所以,如果你想查看训练集的原始图像,还得反标准化。

标准化:image = (image-mean)/std

反标准化:image = image*std+mean

我拿imagenet中的一个蚂蚁和蜜蜂的子集做了一下实验,标准化前后的区别还是很明显的:

最终效果

补充:PIL,plt显示tensor类型的图像

该方法针对显示Dataloader读取的图像

PIL 与plt中对应操作不同,但原理是一样的,我试过用下方代码Image的方法在plt上show失败了,原因暂且不知。

 # 方法1:Image.show()
 # transforms.ToPILImage()中有一句
 # npimg = np.transpose(pic.numpy(), (1, 2, 0))
 # 因此pic只能是3-D Tensor,所以要用image[0]消去batch那一维
 img = transforms.ToPILImage(image[0])
 img.show()

 # 方法2:plt.imshow(ndarray)
 img = image[0] # plt.imshow()只能接受3-D Tensor,所以也要用image[0]消去batch那一维
 img = img.numpy() # FloatTensor转为ndarray
 img = np.transpose(img, (1,2,0)) # 把channel那一维放到最后
 # 显示图片
 plt.imshow(img)
 plt.show()
 cnt += 1

以上为个人经验,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • pytorch的batch normalize使用详解

    torch.nn.BatchNorm1d() 1.BatchNorm1d(num_features, eps = 1e-05, momentum=0.1, affine=True) 对于2d或3d输入进行BN.在训练时,该层计算每次输入的均值和方差,并进行平行移动.移动平均默认的动量为0.1.在验证时,训练求得的均值/方差将用于标准化验证数据. num_features:表示输入的特征数.该期望输入的大小为'batch_size x num_features [x width]' Shape: 

  • pytorch Variable与Tensor合并后 requires_grad()默认与修改方式

    pytorch更新完后合并了Variable与Tensor torch.Tensor()能像Variable一样进行反向传播的更新,返回值为Tensor Variable自动创建tensor,且返回值为Tensor,(所以以后不需要再用Variable) Tensor创建后,默认requires_grad=Flase 可以通过xxx.requires_grad_()将默认的Flase修改为True 下面附代码及官方文档代码: import torch from torch.autograd im

  • Pytorch中TensorBoard及torchsummary的使用详解

    1.TensorBoard神经网络可视化工具 TensorBoard是一个强大的可视化工具,在pytorch中有两种调用方法: 1.from tensorboardX import SummaryWriter 这种方法是在官方还不支持tensorboard时网上有大神写的 2.from torch.utils.tensorboard import SummaryWriter 这种方法是后来更新官方加入的 1.1 调用方法 1.1.1 创建接口SummaryWriter 功能:创建接口 调用方法:

  • 解决pytorch下只打印tensor的数值不打印出device等信息的问题

    torch.Tensor类型的数据loss和acc打印时 如果写成以下写法 print('batch_loss: '+str(loss.data)+'batch acc: '+str(acc.data)) 则不仅会打印出loss和acc的值,还会打印出device信息和 tensor字样,如下: 如果仅想打印出数值,使得打印出的信息更加简洁 则要用以下写法 print('batch_loss: {:.3f} batch acc: {:.3f}'.format(loss.data, acc.dat

  • pytorch dataloader 取batch_size时候出现bug的解决方式

    1. RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 342 and 281 in dimension 3 at /pytorch/aten/src/TH/generic/THTensorMoreMath.cpp:1333 2. RuntimeError: invalid argument 0: Sizes of tensors must match except i

  • 浅谈pytorch中stack和cat的及to_tensor的坑

    初入计算机视觉遇到的一些坑 1.pytorch中转tensor x=np.random.randint(10,100,(10,10,10)) x=TF.to_tensor(x) print(x) 这个函数会对输入数据进行自动归一化,比如有时候我们需要将0-255的图片转为numpy类型的数据,则会自动转为0-1之间 2.stack和cat之间的差别 stack x=torch.randn((1,2,3)) y=torch.randn((1,2,3)) z=torch.stack((x,y))#默

  • pytorch方法测试详解——归一化(BatchNorm2d)

    测试代码: import torch import torch.nn as nn m = nn.BatchNorm2d(2,affine=True) #权重w和偏重将被使用 input = torch.randn(1,2,3,4) output = m(input) print("输入图片:") print(input) print("归一化权重:") print(m.weight) print("归一化的偏重:") print(m.bias)

  • pytorch 带batch的tensor类型图像显示操作

    项目场景 pytorch训练时我们一般把数据集放到数据加载器里,然后分批拿出来训练.训练前我们一般还要看一下训练数据长啥样,也就是训练数据集可视化. 那么如何显示dataloader里面带batch的tensor类型的图像呢? 显示图像 绘图最常用的库就是matplotlib: pip install matplotlib 显示图像会用到matplotlib.pyplot.imshow方法.查阅官方文档可知,该方法接收的图像的通道数要放到后面: 数据加载器中数据的维度是[B, C, H, W],

  • pytorch常见的Tensor类型详解

    Tensor有不同的数据类型,每种类型分别有对应CPU和GPU版本(HalfTensor除外).默认的Tensor是FloatTensor,可通过torch.set_default_tensor_type修改默认tensor类型(如果默认类型为GPU tensor,则所有操作都将在GPU上进行). Tensor的类型对分析内存占用很有帮助,例如,一个size为(1000,1000,1000)的FloatTensor,它有1000*1000*1000=10^9个元素,每一个元素占用32bit/8=

  • pytorch: tensor类型的构建与相互转换实例

    Summary 主要包括以下三种途径: 使用独立的函数: 使用torch.type()函数: 使用type_as(tesnor)将张量转换为给定类型的张量. 使用独立函数 import torch tensor = torch.randn(3, 5) print(tensor) # torch.long() 将tensor投射为long类型 long_tensor = tensor.long() print(long_tensor) # torch.half()将tensor投射为半精度浮点类型

  • Pytorch 使用 nii数据做输入数据的操作

    使用pix2pix-gan做医学图像合成的时候,如果把nii数据转成png格式会损失很多信息,以为png格式图像的灰度值有256阶,因此直接使用nii的医学图像做输入会更好一点. 但是Pythorch中的Dataloader是不能直接读取nii图像的,因此加一个CreateNiiDataset的类. 先来了解一下pytorch中读取数据的主要途径--Dataset类.在自己构建数据层时都要基于这个类,类似于C++中的虚基类. 自己构建的数据层包含三个部分 class Dataset(object

  • Java 手动解析不带引号的JSON字符串的操作

    1 需求说明 项目中遇到了一批不带引号的类JSON格式的字符串: {Name:Heal,Age:20,Tag:[Coding,Reading]} 需要将其解析成JSON对象, 然后插入到Elasticsearch中, 当作Object类型的对象存储起来. 在对比了阿里的FastJson.Google的Gson, 没找到想要的功能 ( 可能是博主不够仔细, 有了解的童学留言告诉我下呀

  • mybatis查询实现返回List<Map>类型数据操作

    如下所示: **只要设定resultType而不设定resultMap就可以了**: < select id = "selectByPage" parameterType = "java.util.Map" resultType="java.util.Map" > select rs.*, rssetting.*, cp.STOCK_CODE, cp.UNAME from RS rs left join T_COMPANY cp on

  • 解决Pytorch中Batch Normalization layer踩过的坑

    1. 注意momentum的定义 Pytorch中的BN层的动量平滑和常见的动量法计算方式是相反的,默认的momentum=0.1 BN层里的表达式为: 其中γ和β是可以学习的参数.在Pytorch中,BN层的类的参数有: CLASS torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 每个参数具体含义参见文档,需要注意的是,affine定义了BN层的

  • Java Date时间类型的操作实现

    本文主要介绍Java Date 日期类型,以及Calendar的怎么获取时间,然后写成时间工具类里面有下面这些方法: - 时间转字符串(有默认时间格式,带时间格式) - 字符串转时间(有默认时间格式,带时间格式) - 计算两个日期之间相差的天数 - 计算当前时间多少天以后的日期 - 判断是否是日期格式 代码 很多说明都注释在代码上: package datedemo; import java.text.SimpleDateFormat; import java.util.Calendar; im

  • 浅谈Python 集合(set)类型的操作——并交差

    阅读目录 •介绍 •基本操作 •函数操作 介绍 python的set是一个无序不重复元素集,基本功能包括关系测试和消除重复元素. 集合对象还支持并.交.差.对称差等. sets 支持 x in set. len(set).和 for x in set.作为一个无序的集合,sets不记录元素位置或者插入点.因此,sets不支持 indexing, slicing, 或其它类序列(sequence-like)的操作. 基本操作 >>> x = set("jihite")

随机推荐