Pytorch DataLoader shuffle验证方式

shuffle = False时,不打乱数据顺序

shuffle = True,随机打乱

import numpy as np
import h5py
import torch
from torch.utils.data import DataLoader, Dataset
h5f = h5py.File('train.h5', 'w');
data1 = np.array([[1,2,3],
               [2,5,6],
              [3,5,6],
              [4,5,6]])
data2 = np.array([[1,1,1],
                   [1,2,6],
                  [1,3,6],
                  [1,4,6]])
h5f.create_dataset(str('data'), data=data1)
h5f.create_dataset(str('label'), data=data2)
class Dataset(Dataset):
    def __init__(self):
        h5f = h5py.File('train.h5', 'r')
        self.data = h5f['data']
        self.label = h5f['label']
    def __getitem__(self, index):
        data = torch.from_numpy(self.data[index])
        label = torch.from_numpy(self.label[index])
        return data, label

    def __len__(self):
        assert self.data.shape[0] == self.label.shape[0], "wrong data length"
        return self.data.shape[0] 

dataset_train = Dataset()
loader_train = DataLoader(dataset=dataset_train,
                           batch_size=2,
                           shuffle = True)

for i, data in enumerate(loader_train):
    train_data, label = data
    print(train_data)
 

pytorch DataLoader使用细节

背景:

我一开始是对数据扩增这一块有疑问, 只看到了数据变换(torchvisiom.transforms),但是没看到数据扩增, 后来搞明白了, 数据扩增在pytorch指的是torchvisiom.transforms + torch.utils.data.DataLoader+多个epoch共同作用下完成的,

数据变换共有以下内容

composed = transforms.Compose([transforms.Resize((448, 448)), #  resize
                               transforms.RandomCrop(300), # random crop
                               transforms.ToTensor(),
                               transforms.Normalize(mean=[0.5, 0.5, 0.5],  # normalize
                                                    std=[0.5, 0.5, 0.5])])

简单的数据读取类, 进返回PIL格式的image:

class MyDataset(data.Dataset):
    def __init__(self, labels_file, root_dir, transform=None):
        with open(labels_file) as csvfile:
            self.labels_file = list(csv.reader(csvfile))
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.labels_file)

    def __getitem__(self, idx):
        im_name = os.path.join(root_dir, self.labels_file[idx][0])
        im = Image.open(im_name)

        if self.transform:
            im = self.transform(im)

        return im

下面是主程序

labels_file = "F:/test_temp/labels.csv"
root_dir = "F:/test_temp"
dataset_transform = MyDataset(labels_file, root_dir, transform=composed)
dataloader = data.DataLoader(dataset_transform, batch_size=1, shuffle=False)
"""原始数据集共3张图片, 以batch_size=1, epoch为2 展示所有图片(共6张)  """
for eopch in range(2):
    plt.figure(figsize=(6, 6))
    for ind, i in enumerate(dataloader):
        a = i[0, :, :, :].numpy().transpose((1, 2, 0))
        plt.subplot(1, 3, ind+1)
        plt.imshow(a)

从上述图片总可以看到, 在每个eopch阶段实际上是对原始图片重新使用了transform, , 这就造就了数据的扩增

以上为个人经验,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • Pytorch在dataloader类中设置shuffle的随机数种子方式

    如题:Pytorch在dataloader类中设置shuffle的随机数种子方式 虽然实验结果差别不大,但是有时候也悬殊两个百分点 想要复现实验结果 发现用到随机数的地方就是dataloader类中封装的shuffle属性 查了半天没有关于这个的设置,最后在设置随机数种子里面找到了答案 以下方法即可: def setup_seed(seed): torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed

  • pytorch 实现多个Dataloader同时训练

    看代码吧~ 如果两个dataloader的长度不一样,那就加个: from itertools import cycle 仅使用zip,迭代器将在长度等于最小数据集的长度时耗尽. 但是,使用cycle时,我们将再次重复最小的数据集,除非迭代器查看最大数据集中的所有样本. 补充:pytorch技巧:自定义数据集 torch.utils.data.DataLoader 及Dataset的使用 本博客中有可直接运行的例子,便于直观的理解,在torch环境中运行即可. 1. 数据传递机制 在 pytor

  • 我对PyTorch dataloader里的shuffle=True的理解

    对shuffle=True的理解: 之前不了解shuffle的实际效果,假设有数据a,b,c,d,不知道batch_size=2后打乱,具体是如下哪一种情况: 1.先按顺序取batch,对batch内打乱,即先取a,b,a,b进行打乱: 2.先打乱,再取batch. 证明是第二种 shuffle (bool, optional): set to ``True`` to have the data reshuffled at every epoch (default: ``False``). if

  • 解决Pytorch dataloader时报错每个tensor维度不一样的问题

    使用pytorch的dataloader报错: RuntimeError: stack expects each tensor to be equal size, but got [2] at entry 0 and [1] at entry 1 1. 问题描述 报错定位:位于定义dataset的代码中 def __getitem__(self, index): ... return y #此处报错 报错内容 File "D:\python\lib\site-packages\torch\uti

  • pytorch锁死在dataloader(训练时卡死)

    1.问题描述 2.解决方案 (1)Dataloader里面不用cv2.imread进行读取图片,用cv2.imread还会带来一系列的不方便,比如不能结合torchvision进行数据增强,所以最好用PIL 里面的Image.open来读图片.(并不适用本例) (2)将DataLoader 里面的参变量num_workers设置为0,但会导致数据的读取很慢,拖慢整个模型的训练.(并不适用本例) (3)如果用了cv2.imread,不想改代码的,那就加两条语句,来关闭Opencv的多线程:cv2.

  • pytorch中DataLoader()过程中遇到的一些问题

    如下所示: RuntimeError: stack expects each tensor to be equal size, but got [3, 60, 32] at entry 0 and [3, 54, 32] at entry 2 train_dataset = datasets.ImageFolder( traindir, transforms.Compose([ transforms.Resize((224)) ### 原因是 transforms.Resize() 的参数设置问

  • Pytorch dataloader在加载最后一个batch时卡死的解决

    问题: 自己写了个dataloader,为了部署方便,用OpenCV的接口进行数据读取,而没有用PIL,代码大致如下: def __getitem__(self, idx): sample = self.samples[idx] img = cv2.imread(sample[0]) img = cv2.resize(img, tuple(self.input_size)) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # if not self.val

  • Pytorch DataLoader shuffle验证方式

    shuffle = False时,不打乱数据顺序 shuffle = True,随机打乱 import numpy as np import h5py import torch from torch.utils.data import DataLoader, Dataset h5f = h5py.File('train.h5', 'w'); data1 = np.array([[1,2,3], [2,5,6], [3,5,6], [4,5,6]]) data2 = np.array([[1,1,

  • 使用 PyTorch 实现 MLP 并在 MNIST 数据集上验证方式

    简介 这是深度学习课程的第一个实验,主要目的就是熟悉 Pytorch 框架.MLP 是多层感知器,我这次实现的是四层感知器,代码和思路参考了网上的很多文章.个人认为,感知器的代码大同小异,尤其是用 Pytorch 实现,除了层数和参数外,代码都很相似. Pytorch 写神经网络的主要步骤主要有以下几步: 1 构建网络结构 2 加载数据集 3 训练神经网络(包括优化器的选择和 Loss 的计算) 4 测试神经网络 下面将从这四个方面介绍 Pytorch 搭建 MLP 的过程. 项目代码地址:la

  • Pytorch DataLoader 变长数据处理方式

    关于Pytorch中怎么自定义Dataset数据集类.怎样使用DataLoader迭代加载数据,这篇官方文档已经说得很清楚了,这里就不在赘述. 现在的问题:有的时候,特别对于NLP任务来说,输入的数据可能不是定长的,比如多个句子的长度一般不会一致,这时候使用DataLoader加载数据时,不定长的句子会被胡乱切分,这肯定是不行的. 解决方法是重写DataLoader的collate_fn,具体方法如下: # 假如每一个样本为: sample = { # 一个句子中各个词的id 'token_li

  • Pytorch.nn.conv2d 过程验证方式(单,多通道卷积过程)

    今天在看文档的时候,发现pytorch 的conv操作不是很明白,于是有了一下记录 首先提出两个问题: 1.输入图片是单通道情况下的filters是如何操作的? 即一通道卷积核卷积过程 2.输入图片是多通道情况下的filters是如何操作的? 即多通道多个卷积核卷积过程 这里首先贴出官方文档: classtorch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1

  • pytorch dataloader 取batch_size时候出现bug的解决方式

    1. RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 342 and 281 in dimension 3 at /pytorch/aten/src/TH/generic/THTensorMoreMath.cpp:1333 2. RuntimeError: invalid argument 0: Sizes of tensors must match except i

  • 解决pytorch DataLoader num_workers出现的问题

    最近在学pytorch,在使用数据分批训练时在导入数据是使用了 DataLoader 在参数 num_workers的设置上使程序出现运行没有任何响应的结果 ,看看代码 import torch #导入模块 import torch.utils.data as Data BATCH_SIZE=8 #每一批的数据量 x=torch.linspace(1,10,10) #定义X为 1 到 10 等距离大小的数 y=torch.linspace(10,1,10) #转换成torch能识别的Datase

  • 详解ASP.NET七大身份验证方式以及解决方案

    在B/S系统开发中,经常需要使用"身份验证".因为web应用程序非常特殊,和传统的C/S程序不同,默认情况下(不采用任何身份验证方式和权限控制手段),当你的程序在互联网/局域网上公开后,任何人都能够访问你的web应用程序的资源,这样很难保障应用程序安全性.通俗点来说:对于大多数的内部系统.业务支撑平台等而言,用户必须登录,否则无法访问和操作任何页面.而对于互联网(网站)而言,又有些差异,因为通常网站的大部分页面和信息都是对外公开的,只有涉及到注册用户个人信息的操作,或者网站的后台管理等

  • SQL server 2008 更改登录验证方式的方法

    前言:之前在敲学生的时候也遇到过这个问题,但是当时没有能及时总结,导致这次遇到问题还要重新去查,所以今天就做个总结,方便自己也帮助他人! 如果在安装过程中选择"Windows 身份验证模式",则 sa 登录名将被禁用.如果稍后将身份验证模式更改为"SQL Server 和 Windows 身份验证模式",则 sa 登录名仍处于禁用状态.若要启用 sa 登录帐户,请使用 ALTER LOGIN 语句. 安全说明: sa 帐户是一个广为人知的 SQL Server 帐户

  • Mongodb常用的身份验证方式

    1. 介绍 不管数据库是在多安全的环境或者本地环境,给数据库建立一个安全的环境是很有必要的. Mongodb提供了一系列的 安全功能 ,这里介绍一种很常用的身份验证方式. 2. 开启验证 默认情况下,只要在启动数据库的时候没有加上 --auth 选项,就是没有身份验证功能的,所有客户端都可以进行所有权限的操作. 如果加上过后,我们就可以通过安全的身份验证连接数据库.如果要在数据库中进行身份验证,可以通过 db.auth(username, password) ,如果验证成功则返回1,反之. 3.

随机推荐