python机器学习pytorch自定义数据加载器

目录
  • 正文
  • 1. 加载数据集
  • 2. 迭代和可视化数据集
  • 3.创建自定义数据集
    • 3.1 __init__
    • 3.2 __len__
    • 3.3 __getitem__
  • 4. 使用 DataLoaders 为训练准备数据
  • 5.遍历 DataLoader

正文

处理数据样本的代码可能会逐渐变得混乱且难以维护;理想情况下,我们希望我们的数据集代码与我们的模型训练代码分离,以获得更好的可读性和模块化。PyTorch 提供了两个数据原语:torch.utils.data.DataLoadertorch.utils.data.Dataset 允许我们使用预加载的数据集以及自定义数据。 Dataset存储样本及其对应的标签,DataLoader封装了一个迭代器用于遍历Dataset,以便轻松访问样本数据。

PyTorch 领域库提供了许多预加载的数据集(例如 FashionMNIST),这些数据集继承自torch.utils.data.Dataset并实现了特定于特定数据的功能。它们可用于对您的模型进行原型设计和基准测试。你可以在这里找到它们:图像数据集、 文本数据集和 音频数据集

1. 加载数据集

下面是如何从 TorchVision 加载Fashion-MNIST数据集的示例。Fashion-MNIST 是 Zalando 文章图像的数据集,由 60,000 个训练示例和 10,000 个测试示例组成。每个示例都包含 28×28 灰度图像和来自 10 个类别之一的相关标签。

我们使用以下参数加载FashionMNIST 数据集:

  • root是存储训练/测试数据的路径,
  • train指定训练或测试数据集,
  • download=True如果数据不可用,则从 Internet 下载数据root
  • transformtarget_transform指定特征和标签转换
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
  0%|          | 0/26421880 [00:00<?, ?it/s]
  0%|          | 32768/26421880 [00:00<01:26, 303914.51it/s]
  0%|          | 65536/26421880 [00:00<01:27, 301769.74it/s]
  0%|          | 131072/26421880 [00:00<01:00, 437795.76it/s]
  1%|          | 229376/26421880 [00:00<00:42, 621347.43it/s]
  2%|1         | 491520/26421880 [00:00<00:20, 1259673.64it/s]
  4%|3         | 950272/26421880 [00:00<00:11, 2264911.11it/s]
  7%|7         | 1933312/26421880 [00:00<00:05, 4467299.81it/s]
 15%|#4        | 3833856/26421880 [00:00<00:02, 8587616.55it/s]
 26%|##6       | 6881280/26421880 [00:00<00:01, 14633777.99it/s]
 37%|###7      | 9830400/26421880 [00:01<00:00, 18150145.01it/s]
 49%|####8     | 12910592/26421880 [00:01<00:00, 21161097.17it/s]
 61%|######    | 16023552/26421880 [00:01<00:00, 23366004.89it/s]
 72%|#######2  | 19136512/26421880 [00:01<00:00, 24967488.10it/s]
 84%|########4 | 22249472/26421880 [00:01<00:00, 26016258.24it/s]
 95%|#########5| 25231360/26421880 [00:01<00:00, 26218488.24it/s]
100%|##########| 26421880/26421880 [00:01<00:00, 15984902.80it/s]
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
  0%|          | 0/29515 [00:00<?, ?it/s]
100%|##########| 29515/29515 [00:00<00:00, 268356.24it/s]
100%|##########| 29515/29515 [00:00<00:00, 266767.69it/s]
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
  0%|          | 0/4422102 [00:00<?, ?it/s]
  1%|          | 32768/4422102 [00:00<00:14, 302027.13it/s]
  1%|1         | 65536/4422102 [00:00<00:14, 300501.69it/s]
  3%|2         | 131072/4422102 [00:00<00:09, 436941.45it/s]
  5%|5         | 229376/4422102 [00:00<00:06, 619517.19it/s]
 10%|9         | 425984/4422102 [00:00<00:03, 1044158.55it/s]
 20%|##        | 884736/4422102 [00:00<00:01, 2114396.73it/s]
 40%|####      | 1769472/4422102 [00:00<00:00, 4067080.68it/s]
 80%|########  | 3538944/4422102 [00:00<00:00, 7919346.09it/s]
100%|##########| 4422102/4422102 [00:00<00:00, 5036535.17it/s]
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
  0%|          | 0/5148 [00:00<?, ?it/s]
100%|##########| 5148/5148 [00:00<00:00, 22168662.21it/s]
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

2. 迭代和可视化数据集

我们可以像python 列表一样索引Datasets,比如:

training_data[index].

我们用matplotlib来可视化训练数据中的一些样本。

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

3.创建自定义数据集

自定义 Dataset 类必须实现三个函数:initlen__和__getitem

比如: FashionMNIST 图像存储在一个目录img_dir中,它们的标签分别存储在一个 CSV 文件annotations_file中。

在接下来的部分中,我们将分析每个函数中发生的事情。

import os
import pandas as pd
from torchvision.io import read_image
class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
    def __len__(self):
        return len(self.img_labels)
    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

3.1 __init__

init 函数在实例化 Dataset 对象时运行一次。我们初始化包含图像、注释文件和两种转换的目录(在下一节中更详细地介绍)。

labels.csv 文件如下所示:

tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    self.img_labels = pd.read_csv(annotations_file)
    self.img_dir = img_dir
    self.transform = transform
    self.target_transform = target_transform

3.2 __len__

len 函数返回我们数据集中的样本数。

例子:

def __len__(self):
    return len(self.img_labels)

3.3 __getitem__

getitem 函数从给定索引处的数据集中加载并返回一个样本idx。基于索引,它识别图像在磁盘上的位置,使用 将其转换为张量read_image,从 csv 数据中检索相应的标签self.img_labels,调用它们的转换函数(如果适用),并返回张量图像和相应的标签一个元组。

def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    image = read_image(img_path)
    label = self.img_labels.iloc[idx, 1]
    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    return image, label

4. 使用 DataLoaders 为训练准备数据

Dataset一次加载一个样本数据和其对应的label。在训练模型时,我们通常希望以minibatches“小批量”的形式传递样本,在每个 epoch 重新洗牌以减少模型过拟合,并使用 Pythonmultiprocessing加速数据检索。

DataLoader是一个可迭代对象,它封装了复杂性并暴漏了简单的API。

from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

5.遍历 DataLoader

我们已将该数据集加载到 DataLoader中,并且可以根据需要遍历数据集。下面的每次迭代都会返回一批train_featurestrain_labels(分别包含batch_size=64特征和标签)。因为我们指定shuffle=True了 ,所以在我们遍历所有批次之后,数据被打乱(为了更细粒度地控制数据加载顺序,请查看Samplers)。

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 4

以上就是python机器学习pytorch自定义数据加载器的详细内容,更多关于python pytorch自定义数据加载器的资料请关注我们其它相关文章!

(0)

相关推荐

  • Python基于Pytorch的特征图提取实例

    目录 简述 单个图片的提取 神经网络的构建 特征图的提取 可视化展示 完整代码 总结 简述 为了方便理解卷积神经网络的运行过程,需要对卷积神经网络的运行结果进行可视化的展示. 大致可分为如下步骤: 单个图片的提取 神经网络的构建 特征图的提取 可视化展示 单个图片的提取 根据目标要求,需要对单个图片进行卷积运算,但是Pytorch中读取数据主要用到torch.utils.data.DataLoader类,因此我们需要编写单个图片的读取程序 def get_picture(picture_dir,

  • 深入学习PyTorch中LSTM的输入和输出

    目录 LSTM参数 Inputs Outputs 案例 LSTM参数 官方文档给出的解释为: 总共有七个参数,其中只有前三个是必须的.由于大家普遍使用PyTorch的DataLoader来形成批量数据,因此batch_first也比较重要.LSTM的两个常见的应用场景为文本处理和时序预测,因此下面对每个参数我都会从这两个方面来进行具体解释. input_size:在文本处理中,由于一个单词没法参与运算,因此我们得通过Word2Vec来对单词进行嵌入表示,将每一个单词表示成一个向量,此时input

  • Python Pytorch学习之图像检索实践

    目录 背景 图像表现 搜索 随着电子商务和在线网站的出现,图像检索在我们的日常生活中的应用一直在增加. 亚马逊.阿里巴巴.Myntra等公司一直在大量利用图像检索技术.当然,只有当通常的信息检索技术失败时,图像检索才会开始工作. 背景 图像检索的基本本质是根据查询图像的特征从集合或数据库中查找图像. 大多数情况下,这种特征是图像之间简单的视觉相似性.在一个复杂的问题中,这种特征可能是两幅图像在风格上的相似性,甚至是互补性. 由于原始形式的图像不会在基于像素的数据中反映这些特征,因此我们需要将这些

  • python神经网络学习利用PyTorch进行回归运算

    目录 学习前言 PyTorch中的重要基础函数 1.class Net(torch.nn.Module)神经网络的构建: 2.optimizer优化器 3.loss损失函数定义 4.训练过程 全部代码 学习前言 我发现不仅有很多的Keras模型,还有很多的PyTorch模型,还是学学Pytorch吧,我也想了解以下tensor到底是个啥. PyTorch中的重要基础函数 1.class Net(torch.nn.Module)神经网络的构建: PyTorch中神经网络的构建和Tensorflow

  • python神经网络pytorch中BN运算操作自实现

    BN 想必大家都很熟悉,来自论文: <Batch Normalization Accelerating Deep Network Training by Reducing Internal Covariate Shift> 也是面试常考察的内容,虽然一行代码就能搞定,但是还是很有必要用代码自己实现一下,也可以加深一下对其内部机制的理解. 通用公式: 直奔代码: 首先是定义一个函数,实现BN的运算操作: def batch_norm(is_training, x, gamma, beta, mo

  • python神经网络Pytorch中Tensorboard函数使用

    目录 所需库的安装 常用函数功能 1.SummaryWriter() 2.writer.add_graph() 3.writer.add_scalar() 4.tensorboard --logdir= 示例代码 所需库的安装 很多人问Pytorch要怎么可视化,于是决定搞一篇. tensorboardX==2.0 tensorflow==1.13.2 由于tensorboard原本是在tensorflow里面用的,所以需要装一个tensorflow.会自带一个tensorboard. 也可以不

  • python机器学习pytorch自定义数据加载器

    目录 正文 1. 加载数据集 2. 迭代和可视化数据集 3.创建自定义数据集 3.1 __init__ 3.2 __len__ 3.3 __getitem__ 4. 使用 DataLoaders 为训练准备数据 5.遍历 DataLoader 正文 处理数据样本的代码可能会逐渐变得混乱且难以维护:理想情况下,我们希望我们的数据集代码与我们的模型训练代码分离,以获得更好的可读性和模块化.PyTorch 提供了两个数据原语:torch.utils.data.DataLoader和torch.util

  • pytorch 自定义数据集加载方法

    pytorch 官网给出的例子中都是使用了已经定义好的特殊数据集接口来加载数据,而且其使用的数据都是官方给出的数据.如果我们有自己收集的数据集,如何用来训练网络呢?此时需要我们自己定义好数据处理接口.幸运的是pytroch给出了一个数据集接口类(torch.utils.data.Dataset),可以方便我们继承并实现自己的数据集接口. torch.utils.data torch的这个文件包含了一些关于数据集处理的类. class torch.utils.data.Dataset: 一个抽象类

  • Python实现实时增量数据加载工具的解决方案

    目录 创建增量ID记录表 数据库连接类 增量数据服务客户端 结果测试 本次主要分享结合单例模式实际应用案例:实现实时增量数据加载工具的解决方案.最关键的是实现一个可进行添加.修改.删除等操作的增量ID记录表. 单例模式:提供全局访问点,确保类有且只有一个特定类型的对象.通常用于以下场景:日志记录或数据库操作等,避免对用一资源请求冲突. 创建增量ID记录表 import sqlite3 import datetime import pymssql import pandas as pd impor

  • 编写自定义的Django模板加载器的简单示例

    Djangos 内置的模板加载器(在先前的模板加载内幕章节有叙述)通常会满足你的所有的模板加载需求,但是如果你有特殊的加载需求的话,编写自己的模板加载器也会相当简单. 比如:你可以从数据库中,或者利用Python的绑定直接从Subversion库中,更或者从一个ZIP文档中加载模板. 模板加载器,也就是 TEMPLATE_LOADERS 中的每一项,都要能被下面这个接口调用: load_template_source(template_name, template_dirs=None) 参数 t

  • python用pandas数据加载、存储与文件格式的实例

    数据加载.存储与文件格式 pandas提供了一些用于将表格型数据读取为DataFrame对象的函数.其中read_csv和read_talbe用得最多 pandas中的解析函数: 函数 说明 read_csv 从文件.URL.文件型对象中加载带分隔符的数据,默认分隔符为逗号 read_table 从文件.URL.文件型对象中加载带分隔符的数据.默认分隔符为制表符("\t") read_fwf 读取定宽列格式数据(也就是说,没有分隔符) read_clipboard 读取剪贴板中的数据,

  • Pytorch 数据加载与数据预处理方式

    数据加载分为加载torchvision.datasets中的数据集以及加载自己使用的数据集两种情况. torchvision.datasets中的数据集 torchvision.datasets中自带MNIST,Imagenet-12,CIFAR等数据集,所有的数据集都是torch.utils.data.Dataset的子类,都包含 _ _ len _ (获取数据集长度)和 _ getItem _ _ (获取数据集中每一项)两个子方法. Dataset源码如上,可以看到其中包含了两个没有实现的子

  • 解决pytorch load huge dataset(大数据加载)

    问题 最近用pytorch做实验时,遇到加载大量数据的问题.实验数据大小在400Gb,而本身机器的memory只有256Gb,显然无法将数据一次全部load到memory. 解决方法 首先自定义一个MyDataset继承torch.utils.data.Dataset,然后将MyDataset的对象feed in torch.utils.data.DataLoader()即可. MyDataset在__init__中声明一个文件对象,然后在__getitem__中缓慢读取数据,这样就不会一次把所

  • Python实现从文件中加载数据的方法详解

    前几篇都是手动录入或随机函数产生的数据.实际有许多类型的文件,以及许多方法,用它们从文件中提取数据来图形化. 比如之前python基础(12)介绍打开文件的方式,可直接读取文件中的数据,扩大了我们的数据来源.下面,将展示几种方法. 我们将使用内置的 csv 模块加载CSV文件 CSV文件是一种特殊的文本文件,文件中的数据以逗号作为分隔符,很适合进行数据的解析.先用excle建立如下表格和数据,另存为csv格式文件,放到代码目录下. 包含在Python标准库中自带CSV 模块,我们只需要impor

  • Android自定义加载控件实现数据加载动画

    本文实例为大家分享了Android自定义加载控件,第一次小人跑动的加载效果眼前一亮,相比传统的PrograssBar高大上不止一点,于是走起,自定义了控件LoadingView去实现动态效果,可直接在xml中使用,具体实现如下 package com.*****.*****.widget; import android.content.Context; import android.graphics.drawable.AnimationDrawable; import android.util.

  • pytorch 数据加载性能对比分析

    传统方式需要10s,dat方式需要0.6s import os import time import torch import random from common.coco_dataset import COCODataset def gen_data(batch_size,data_path,target_path): os.makedirs(target_path,exist_ok=True) dataloader = torch.utils.data.DataLoader(COCODat

随机推荐