Pytorch加载数据集的方式总结及补充

目录
  • 前言
  • 一、自己重写定义(Dataset、DataLoader)
  • 二、用Pytorch自带的类(ImageFolder、datasets、DataLoader)
    • 2.1 加载自己的数据集
      • 2.1.1 ImageFolder介绍
      • 2.2.2 ImageFolder加载数据集完整例子
    • 2.2 加载常见的数据集
  • 三、总结
  • 四、transforms变换讲解
  • 五、DataLoader的补充
  • 总结

前言

在用Pytorch加载数据集时,看GitHub上的代码经常会用到ImageFolder、DataLoader等一系列方法,而这些方法又是来自于torchvision、torch.utils.data。除加载数据集外,还有torchvision中的transforms对数据集预处理…等等等等。这个data,那个dataset…这一系列下来,不加注意的话实在有点打脑壳。看别人的代码加载数据集挺简单,但是自己用的时候,尤其是加载自己所制作的数据集的时候,就会茫然无措。别无他法,抱着硬啃的心态,查阅了其他博文,通过代码实验,终于是理清楚了思路。

Pytorch加载数据集可以分两种大的情况:一、自己重写定义; 二、用Pytorch自带的类。第二种里面又有多种不同的方法(datasets、 ImageFolder等),但这些方法都有相同的处理规律。我理解的,无论是哪种情况,加载数据集都需要构造数据加载器数据装载器(后者生成的是可迭代的数据)。现将这两种情况一一说明。

一、自己重写定义(Dataset、DataLoader)

目前我们有自己制作的数据以及数据标签,但是有时候感觉不太适合直接用Pytorch自带加载数据集的方法。我们可以自己来重写定义一个类,这个类继承于 torch.utils.data.Dataset,同时我们需要重写这个类里面的两个方法 _ getitem__ () 和__ len()__函数。

如下所示。这两种方法如何构造以及具体的细节可以查看其他的博客。len方法必须返回数据的长度,getitem方法必须返回数据以及标签。

import torch
import numpy as np

# 定义GetLoader类,继承Dataset方法,并重写__getitem__()和__len__()方法
class GetLoader(torch.utils.data.Dataset):
	# 初始化函数,得到数据
    def __init__(self, data_root, data_label):
        self.data = data_root
        self.label = data_label
    # index是根据batchsize划分数据后得到的索引,最后将data和对应的labels进行一起返回
    def __getitem__(self, index):
        data = self.data[index]
        labels = self.label[index]
        return data, labels
    # 该函数返回数据大小长度,目的是DataLoader方便划分,如果不知道大小,DataLoader会一脸懵逼
    def __len__(self):
        return len(self.data)

# 随机生成数据,大小为10 * 20列
source_data = np.random.rand(10, 20)
# 随机生成标签,大小为10 * 1列
source_label = np.random.randint(0,2,(10, 1))
# 通过GetLoader将数据进行加载,返回Dataset对象,包含data和labels
torch_data = GetLoader(source_data, source_label)

通过上述的程序,我们构造了一个数据加载器torch_data,但是还是不能直接传入网络中。接下来需要构造数据装载器,产生可迭代的数据,再传入网络中。DataLoader类完成这个工作。

torch.utils.data.DataLoader(dataset,batch_size,shuffle,drop_last,num_workers)

参数解释:

1.dataset     : 加载torch.utils.data.Dataset对象数据
2.batch_size  : 每个batch的大小,将我们的数据分批输入到网络中
3.shuffle     : 是否对数据进行打乱
4.drop_last   : 是否对无法整除的最后一个datasize进行丢弃
5.num_workers : 表示加载的时候子进程数

结合我们自己定义的加载数据集类,可以如下使用。后面将data和label传入我们定义的模型中。

...
torch_data = GetLoader(source_data, source_label)

from torch.utils.data import DataLoader
datas = DataLoader(torch_data, batch_size = 4, shuffle = True, drop_last = False, num_workers = 2)
for i, (data, label) in enumerate(datas):
	# i表示第几个batch, data表示batch_size个原始的数据,label代表batch_size个数据的标签
    print("第 {} 个Batch \n{}".format(i, data))

二、用Pytorch自带的类(ImageFolder、datasets、DataLoader)

2.1 加载自己的数据集

2.1.1 ImageFolder介绍

和第一种情况不一样,我们不需要在代码上自己定义数据集类了,而是将数据集按照一定的格式摆放,调用ImageFolder类即可。这种是在调用Pytorch内部的API,所以我们自己的数据集得需要按照API内部所规定的存放格式。torchvision.datasets.ImageFolder 要求数据集按照如下方式组织。根目录 root 下存储的是类别文件夹(如cat,dog),每个类别文件夹下存储相应类别的图像(如xxx.png)

A generic data loader where the images are arranged in this way:

root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

torchvision.datasets.ImageFolder有以下参数:

dataset=torchvision.datasets.ImageFolder(
                       root, transform=None,
                       target_transform=None,
                       loader=<function default_loader>,
                       is_valid_file=None)

参数解释:

1.root:根目录,在root目录下,应该有不同类别的子文件夹;
    |--data(root)
        |--train
            |--cat
            |--dog
        |--valid
            |--cat
            |--dog        
2.transform:对图片进行预处理的操作,原始图像作为一个输入,返回的是transform变换后的图片;
3.target_transform:对图片类别进行预处理的操作,输入为 target,输出对其的转换。 如果不传该参数,即对target不做任何转换,返回的顺序索引 0,1, 2…
4.loader:表示数据集加载方式,通常默认加载方式即可;
5.is_valid_file:获取图像文件的路径并检查该文件是否为有效文件的函数(用于检查损坏文件)

作为torchvision.datasets.ImageFolder的返回,会有以下三种属性:

(1)self.classes:用一个 list 保存类别名称

(2)self.class_to_idx:类别对应的索引,与不做任何转换返回的 target 对应

(3)self.imgs:保存(img_path, class) tuple的list

以猫狗类别举例,各属性输出如下所示:

print(dataset.classes)  #根据分的文件夹的名字来确定的类别
print(dataset.class_to_idx) #按顺序为这些类别定义索引为0,1...
print(dataset.imgs) #返回从所有文件夹中得到的图片的路径以及其类别
'''
输出:
['cat', 'dog']
{'cat': 0, 'dog': 1}
[('./data/train\\cat\\1.jpg', 0),
 ('./data/train\\cat\\2.jpg', 0),
 ('./data/train\\dog\\1.jpg', 1),
 ('./data/train\\dog\\2.jpg', 1)]
'''

2.2.2 ImageFolder加载数据集完整例子

# 5. 将文件夹数据导入
train_loader = torch.utils.data.DataLoader(dataset,
                                           batch_size = batch_size, shuffle=True,
                                           num_workers = 2)
# 6. 传入网络进行训练
for epoch in range(epochs):
    train_bar = tqdm(train_loader, file = sys.stdout)
    for step, data in enumerate(train_bar):
    ...

和第一种情况自己重写定义一样,上述的代码仅仅完成了数据加载器的定义。这样是不能直接传入网络中进行训练的,需要再构造一个可迭代的数据装载器。DataLoader类的使用方式上文中有详细介绍。

# 5. 将文件夹数据导入
train_loader = torch.utils.data.DataLoader(dataset,
                                           batch_size = batch_size, shuffle=True,
                                           num_workers = 2)
# 6. 传入网络进行训练
for epoch in range(epochs):
    train_bar = tqdm(train_loader, file = sys.stdout)
    for step, data in enumerate(train_bar):
    ...

2.2 加载常见的数据集

有些数据集是公共的,比如常见的MNIST,CIFAR10,SVHN等等。这些数据集在Pytorch中可以通过代码就可以下载、加载。如下代码所示。用torchvision中的datasets类下载数据集,并还是结合DataLoader来构建可直接传入网络的数据装载器。

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def dataloader(dataset, input_size, batch_size, split='train'):
    transform = transforms.Compose([
        					transforms.Resize((input_size, input_size)),
       					    transforms.ToTensor(),
        					transforms.Normalize(mean=[0.5], std=[0.5])
    ])
    if dataset == 'mnist':
        data_loader = DataLoader(
            datasets.MNIST('data/mnist', train=True, download=True, transform=transform),
            batch_size=batch_size, shuffle=True)
    elif dataset == 'fashion-mnist':
        data_loader = DataLoader(
            datasets.FashionMNIST('data/fashion-mnist', train=True, download=True, transform=transform),
            batch_size=batch_size, shuffle=True)
    elif dataset == 'cifar10':
        data_loader = DataLoader(
            datasets.CIFAR10('data/cifar10', train=True, download=True, transform=transform),
            batch_size=batch_size, shuffle=True)
    elif dataset == 'svhn':
        data_loader = DataLoader(
            datasets.SVHN('data/svhn', split=split, download=True, transform=transform),
            batch_size=batch_size, shuffle=True)
    elif dataset == 'stl10':
        data_loader = DataLoader(
            datasets.STL10('data/stl10', split=split, download=True, transform=transform),
            batch_size=batch_size, shuffle=True)
    elif dataset == 'lsun-bed':
        data_loader = DataLoader(
            datasets.LSUN('data/lsun', classes=['bedroom_train'], transform=transform),
            batch_size=batch_size, shuffle=True)

    return data_loader

三、总结

至于觉得加载数据集比较难的很大的原因,我感觉是Dataset、datasets、DataLoader以及torch.utils.data、torchvision种类太多,有点混乱。上面的梳理,我的理解是无论是哪种方式,终端还是需要DataLoader整合。作为加载数据集的前端,用自己定义的、用ImageFolder的、还是用datasets加载常用数据集,都是在构造数据加载器,而且构造起来也并不复杂。梳理清晰后,相信对Pytorch加载数据集有了更进一步的理解。

四、transforms变换讲解

torchvision.transforms是Pytorch中的图像预处理包。一般定义在加载数据集之前,用transforms中的Compose类把多个步骤整合到一起,而这些步骤是transforms中的函数。

transforms中的函数有这些:

函数 含义
transforms.Resize 把给定的图片resize到given size
transforms.Normalize 用均值和标准差归一化张量图像
transforms.Totensor 可以将PIL和numpy格式的数据从[0,255]范围转换到[0,1] ; <br /另外原始数据的shape是(H x W x C),通过transforms.ToTensor()后shape会变为(C x H x W)
transforms.RandomGrayscale 将图像以一定的概率转换为灰度图像
transforms.ColorJitter 随机改变图像的亮度对比度和饱和度
transforms.Centercrop 在图片的中间区域进行裁剪
transforms.RandomCrop 在一个随机的位置进行裁剪
transforms.FiceCrop 把图像裁剪为四个角和一个中心
transforms.RandomResizedCrop 将PIL图像裁剪成任意大小和纵横比
transforms.ToPILImage convert a tensor to PIL image
transforms.RandomHorizontalFlip 以0.5的概率水平翻转给定的PIL图像
transforms.RandomVerticalFlip 以0.5的概率竖直翻转给定的PIL图像
transforms.Grayscale 将图像转换为灰度图像

不同函数对应有不同的属性,用transforms.Compose将不同的操作整合在一起,如下所示。

transforms.Compose([transforms.RandomResizedCrop(224),
 		    	   transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

五、DataLoader的补充

数据加载器,结合了数据集和取样器,并且可以提供多个线程处理数据集。

在训练模型时使用到此函数,用来把训练数据分成多个小组,此函数每次抛出一组数据。直至把所有的数据都抛出。就是做一个数据的初始化。

用下面的例子测试:

"""
    批训练,把数据变成一小批一小批数据进行训练。
    DataLoader就是用来包装所使用的数据,每次抛出一批数据
"""
import torch
import torch.utils.data as Data

BATCH_SIZE = 5

x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)
# 把数据放在数据库中
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    # 从数据库中每次抽出batch size个样本
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

def show_batch():
    for epoch in range(3):
        for step, (batch_x, batch_y) in enumerate(loader):
            # training
            print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))

if __name__ == '__main__':
    show_batch()

结果如下所示。仔细观察:

每一个step,batch_x是不会重合的,batch_y里面的值也是不会重合的(第一个step中,batch_x:tensor([ 3., 10., 6., 2., 8.]);第二个step中batch_x:tensor([5., 9., 7., 4., 1.])),说明DataLoader将数据打乱后,每次选用其中的Batch_size个数据且不会重复;

其二,batch_x 和 batch_y对应的索引之和相等,这说明DataLoader对图像和标签打乱顺序时,同时按照某一规律打乱,并不会造成标签和图像出现不对应的情况。

其三,在不同的epoch之间,每次数据也是不同的,说明DataLoader每次被调用时,都会重新打乱一次。

steop:0, batch_x:tensor([ 3., 10.,  6.,  2.,  8.]), batch_y:tensor([8., 1., 5., 9., 3.])
steop:1, batch_x:tensor([5., 9., 7., 4., 1.]), batch_y:tensor([ 6.,  2.,  4.,  7., 10.])
steop:0, batch_x:tensor([8., 3., 1., 2., 9.]), batch_y:tensor([ 3.,  8., 10.,  9.,  2.])
steop:1, batch_x:tensor([10., 5.,  4.,  7.,  6.]), batch_y:tensor([1., 6., 7., 4., 5.])
steop:0, batch_x:tensor([5., 8., 4., 3., 7.]), batch_y:tensor([6., 3., 7., 8., 4.])
steop:1, batch_x:tensor([ 2., 10.,  6.,  9.,  1.]), batch_y:tensor([ 9.,  1.,  5.,  2., 10.])

总结

到此这篇关于Pytorch加载数据集的方式总结及补充的文章就介绍到这了,更多相关Pytorch加载数据集内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • PyTorch如何创建自己的数据集

    目录 PyTorch创建自己的数据集 pytorch常用数据集的使用 PyTorch创建自己的数据集 图片文件在同一的文件夹下 思路是继承 torch.utils.data.Dataset,并重点重写其 __getitem__方法,示例代码如下: class ImageFolder(Dataset):     def __init__(self, folder_path):         self.files = sorted(glob.glob('%s/*.*' % folder_path)

  • 使用pytorch读取数据集

    目录 pytorch读取数据集 第一种 第二种 第三种 pytorch学习记录 注意事项 pytorch读取数据集 使用pytorch读取数据集一般有三种情况 第一种 读取官方给的数据集,例如Imagenet,CIFAR10,MNIST等 这些库调用torchvision.datasets.XXXX()即可,例如想要读取MNIST数据集 import torch import torch.nn as nn import torch.utils.data as Data import torchv

  • pytorch加载自己的图片数据集的2种方法详解

    目录 ImageFolder 加载数据集 使用pytorch提供的Dataset类创建自己的数据集. Dataset加载数据集 总结 pytorch加载图片数据集有两种方法. 1.ImageFolder 适合于分类数据集,并且每一个类别的图片在同一个文件夹, ImageFolder加载的数据集, 训练数据为文件件下的图片, 训练标签是对应的文件夹, 每个文件夹为一个类别 导入ImageFolder()包 from torchvision.datasets import ImageFolder 在

  • pytorch加载自己的数据集源码分享

    目录 一.标准的数据集流程梳理 数据来源 二.实现加载自己的数据集 1. 保存在txt文件中(生成训练集和测试集,其实这里的训练集以及测试集也都是用文本文件的形式保存下来的) 2. 在继承dataset类LoadData的三个函数里调用train.txt以及test.txt实现相关功能 三.源码 一.标准的数据集流程梳理 分为几个步骤数据准备以及加载数据库–>数据加载器的调用或者设计–>批量调用进行训练或者其他作用 数据来源 直接读取了x和y的数据变量,对比后面的就从把对应的路径写进了文本文件

  • PyTorch手写数字数据集进行多分类

    目录 一.实现过程 0.导包 1.准备数据 2.设计模型 3.构造损失函数和优化器 4.训练和测试 二.参考文献 一.实现过程 本文对经典手写数字数据集进行多分类,损失函数采用交叉熵,激活函数采用ReLU,优化器采用带有动量的mini-batchSGD算法. 所有代码如下: 0.导包 import torch from torchvision import transforms,datasets from torch.utils.data import DataLoader import tor

  • Pytorch加载数据集的方式总结及补充

    目录 前言 一.自己重写定义(Dataset.DataLoader) 二.用Pytorch自带的类(ImageFolder.datasets.DataLoader) 2.1 加载自己的数据集 2.1.1 ImageFolder介绍 2.2.2 ImageFolder加载数据集完整例子 2.2 加载常见的数据集 三.总结 四.transforms变换讲解 五.DataLoader的补充 总结 前言 在用Pytorch加载数据集时,看GitHub上的代码经常会用到ImageFolder.DataLo

  • PyTorch加载数据集梯度下降优化

    目录 一.实现过程 1.准备数据 2.设计模型 3.构造损失函数和优化器 4.训练过程 5.结果展示 二.参考文献 一.实现过程 1.准备数据 与PyTorch实现多维度特征输入的逻辑回归的方法不同的是:本文使用DataLoader方法,并继承DataSet抽象类,可实现对数据集进行mini_batch梯度下降优化. 代码如下: import torch import numpy as np from torch.utils.data import Dataset,DataLoader clas

  • PyTorch加载自己的数据集实例详解

    数据预处理在解决深度学习问题的过程中,往往需要花费大量的时间和精力. 数据处理的质量对训练神经网络来说十分重要,良好的数据处理不仅会加速模型训练, 更会提高模型性能.为解决这一问题,PyTorch提供了几个高效便捷的工具, 以便使用者进行数据处理或增强等操作,同时可通过并行化加速数据加载. 数据集存放大致有以下两种方式: (1)所有数据集放在一个目录下,文件名上附有标签名,数据集存放格式如下: root/cat_dog/cat.01.jpg root/cat_dog/cat.02.jpg ...

  • pytorch加载语音类自定义数据集的方法教程

    前言 pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合 torch.utils.data.Dataset:所有继承他的子类都应该重写  __len()__  , __getitem()__ 这两个方法 __len()__ :返回数据集中数据的数量 __getitem()__ :返回支持下标索引方式获取的一个数据 torch.utils.data.DataLoad

  • PyTorch使用cpu加载模型运算方式

    没gpu没cuda支持的时候加载模型到cpu上计算 将 model = torch.load(path, map_location=lambda storage, loc: storage.cuda(device)) 改为 model = torch.load(path, map_location='cpu') 然后删掉所有变量后面的.cuda()方法 以上这篇PyTorch使用cpu加载模型运算方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们.

  • pytorch加载自己的图像数据集实例

    之前学习深度学习算法,都是使用网上现成的数据集,而且都有相应的代码.到了自己开始写论文做实验,用到自己的图像数据集的时候,才发现无从下手 ,相信很多新手都会遇到这样的问题. 参考文章https://www.jb51.net/article/177613.htm 下面代码实现了从文件夹内读取所有图片,进行归一化和标准化操作并将图片转化为tensor.最后读取第一张图片并显示. # 数据处理 import os import torch from torch.utils import data fr

  • 使用pytorch加载并读取COCO数据集的详细操作

    目录 环境配置 基础知识:元祖.字典.数组 利用PyTorch读取COCO数据集 利用PyTorch读取自己制作的数据集 如何使用pytorch加载并读取COCO数据集 环境配置基础知识:元祖.字典.数组利用PyTorch读取COCO数据集利用PyTorch读取自己制作的数据集 环境配置 看pytorch入门教程 基础知识:元祖.字典.数组 # 元祖 a = (1, 2) # 字典 b = {'username': 'peipeiwang', 'code': '111'} # 数组 c = [1

  • Pytorch加载部分预训练模型的参数实例

    前言 自从从深度学习框架caffe转到Pytorch之后,感觉Pytorch的优点妙不可言,各种设计简洁,方便研究网络结构修改,容易上手,比TensorFlow的臃肿好多了.对于深度学习的初学者,Pytorch值得推荐.今天主要主要谈谈Pytorch是如何加载预训练模型的参数以及代码的实现过程. 直接加载预选脸模型 如果我们使用的模型和预训练模型完全一样,那么我们就可以直接加载别人的模型,还有一种情况,我们在训练自己模型的过程中,突然中断了,但只要我们保存了之前的模型的参数也可以使用下面的代码直

随机推荐