PyTorch Dataset与DataLoader使用超详细讲解

目录
  • 一、Dataset
    • 1. 在控制台进行操作
      • ①获取图片的基本信息
      • ②获取文件的基本信息
    • 2. 编写一个继承Dataset 的类加载数据
      • ①定义 MyData类
      • ②创建类的实例并调用
  • 二、DataLoader

一、Dataset

Dataset 类提供一种方式去获取数据及其标签

主要有两个目的:

  • 获取每一个数据及其标签
  • 获取数据的总量大小

1. 在控制台进行操作

Hymenoptera (膜翅目昆虫)数据集下载地址:

链接: https://pan.baidu.com/s/1XKwXsAtE2yzZW2IsvBDpnw?pwd=8a5t

提取码: 8a5t

这是一个蚂蚁蜜蜂二分类的数据集,通常数据集有以下三种组织形式(上面的数据集属于第一种):

  • 不同的类别以文件夹的形式存在,文件夹中是该类别的图片
  • 图片与标签分别存储,图片在一个文件夹下,label信息在另一个文件夹下
  • label直接写在图片名称里

①获取图片的基本信息

在Pycharm 中,点击下方的PythonConsole进入控制台进行操作(通过控制台可以看到变量的详细信息)

首先加载图片,逐行输入下方代码:

from PIL import Image
img_path = "./dataset/hymenoptera_data/train/ants/0013035.jpg"
img = Image.open(img_path)

此时我们就可以在右侧看到相关变量的信息:

点击img变量,可以查看图片的详细信息。通过控制台执行程序能够直观地获取后续操作所需的数据:

最后可以通过img.show()打开图片查看:

②获取文件的基本信息

同样还是在控制台逐行输入以下代码:

dir_path = "dataset/hymenoptera_data/train/ants"
import os
img_path_list = os.listdir(dir_path)
img_path_list[0]

我们就可以获取到文件夹下的文件名称,由于是使用控制台,我们还可以在右侧查看列表的详细信息:

因此在控制台操作是有很大的优点的,我们可以在控制台逐行执行已经编写好的文件中的语句,通过查看右侧变量的值来判断程序写的是否有问题

2. 编写一个继承Dataset 的类加载数据

下面的代码也可以在控制台运行(可以多行复制粘贴)来检验程序是否有误

①定义 MyData类

导入所需头文件:

from torch.utils.data import Dataset
from PIL import Image
import os

定义MyData类:

  • __init__:初始化函数
  • __getitem__:返回指定下标的图片和标签
  • __len__:返回数据集的大小
class MyData(Dataset):
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path = os.listdir(self.path)
    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label
    def __len__(self):
        return len(self.img_path)

其中os.path.join()可以实现多个路径的合并且不出错

②创建类的实例并调用

创建 MyData 类的实例:

if __name__ == "__main__":
    root_dir = "../dataset/hymenoptera_data/train"
    ants_label_dir = "ants"
    bees_label_dir = "bees"
    ants_dataset = MyData(root_dir, ants_label_dir)
    bees_dataset = MyData(root_dir, bees_label_dir)

调用类中写好的函数:

    img, label = ants_dataset.__getitem__(3)
    print(ants_dataset.__len__(), label)
    img.show()

同时我们也可以通过下面这种方式用已有的数据集来创造数据集:

train_dataset = ants_dataset + bees_dataset

二、DataLoader

  • DataLoader 类是为后面的网络提供不同的数据形式
  • DataLoader 会根据batch_size的值对数据进行打包
  • 导入所需的包
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

加载数据:

test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)

测试:

img, target = test_data[0]
print(img.shape)
print(target)

进行日志记录,开始训练:

writer = SummaryWriter("dataloader")
for epoch in range(2):
    step = 0
    for data in test_loader:
        imgs, targets = data
        print(imgs.shape)
        print(targets)
        writer.add_images("Epoch: {}".format(epoch), imgs, step)
        step = step + 1
writer.close()

到此这篇关于PyTorch Dataset与DataLoader使用超详细讲解的文章就介绍到这了,更多相关PyTorch Dataset与DataLoader内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • Pytorch数据读取之Dataset和DataLoader知识总结

    一.前言 确保安装 scikit-image numpy 二.Dataset 一个例子: # 导入需要的包 import torch import torch.utils.data.dataset as Dataset import numpy as np # 编造数据 Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]]) Label = np.asarray([[0], [1], [0], [2]]) # 数据[1,2],对应的标签是[0],数据

  • 一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系

    以下内容都是针对Pytorch 1.0-1.1介绍. 很多文章都是从Dataset等对象自下往上进行介绍,但是对于初学者而言,其实这并不好理解,因为有的时候会不自觉地陷入到一些细枝末节中去,而不能把握重点,所以本文将会自上而下地对Pytorch数据读取方法进行介绍. 自上而下理解三者关系 首先我们看一下DataLoader.next的源代码长什么样,为方便理解我只选取了num_works为0的情况(num_works简单理解就是能够并行化地读取数据). class DataLoader(obje

  • Pytorch自定义Dataset和DataLoader去除不存在和空数据的操作

    [源码GitHub地址]:点击进入 1. 问题描述 之前写了一篇关于<pytorch Dataset, DataLoader产生自定义的训练数据>的博客,但存在一个问题,我们不能在Dataset做一些数据清理,如果我们传递给Dataset数据,本身存在问题,那么迭代过程肯定出错的. 比如我把很多图片路径都传递给Dataset,如果图片路径都是正确的,且图片都存在也没有损坏,那显然运行是没有问题的: 但倘若传递给Dataset的图片路径有些图片是不存在,这时你通过Dataset读取图片数据,然后

  • PyTorch实现重写/改写Dataset并载入Dataloader

    前言 众所周知,Dataset和Dataloder是pytorch中进行数据载入的部件.必须将数据载入后,再进行深度学习模型的训练.在pytorch的一些案例教学中,常使用torchvision.datasets自带的MNIST.CIFAR-10数据集,一般流程为: # 下载并存放数据集 train_dataset = torchvision.datasets.CIFAR10(root="数据集存放位置",download=True) # load数据 train_loader = t

  • PyTorch 解决Dataset和Dataloader遇到的问题

    今天在使用PyTorch中Dataset遇到了一个问题.先看代码 class psDataset(Dataset): def __init__(self, x, y, transforms = None): super(Dataset, self).__init__() self.x = x self.y = y if transforms == None: self.transforms = Compose([Resize((224, 224)), ToTensor()]) else: sel

  • pytorch Dataset,DataLoader产生自定义的训练数据案例

    1. torch.utils.data.Dataset datasets这是一个pytorch定义的dataset的源码集合.下面是一个自定义Datasets的基本框架,初始化放在__init__()中,其中__getitem__()和__len__()两个方法是必须重写的. __getitem__()返回训练数据,如图片和label,而__len__()返回数据长度. class CustomDataset(data.Dataset):#需要继承data.Dataset def __init_

  • PyTorch Dataset与DataLoader使用超详细讲解

    目录 一.Dataset 1. 在控制台进行操作 ①获取图片的基本信息 ②获取文件的基本信息 2. 编写一个继承Dataset 的类加载数据 ①定义 MyData类 ②创建类的实例并调用 二.DataLoader 一.Dataset Dataset 类提供一种方式去获取数据及其标签 主要有两个目的: 获取每一个数据及其标签 获取数据的总量大小 1. 在控制台进行操作 Hymenoptera (膜翅目昆虫)数据集下载地址: 链接: https://pan.baidu.com/s/1XKwXsAtE

  • java反射超详细讲解

    目录 Java反射超详解✌ 1.反射基础 1.1Class类 1.2类加载 2.反射的使用 2.1Class对象的获取 2.2Constructor类及其用法 2.4Method类及其用法 Java反射超详解✌ 1.反射基础 Java反射机制是在程序的运行过程中,对于任何一个类,都能够知道它的所有属性和方法:对于任意一个对象,都能够知道它的任意属性和方法,这种动态获取信息以及动态调用对象方法的功能称为Java语言的反射机制. Java反射机制主要提供以下这几个功能: 在运行时判断任意一个对象所属

  • 超详细讲解Linux C++多线程同步的方式

    目录 一.互斥锁 1.互斥锁的初始化 2.互斥锁的相关属性及分类 3,测试加锁函数 二.条件变量 1.条件变量的相关函数 1)初始化的销毁读写锁 2)以写的方式获取锁,以读的方式获取锁,释放读写锁 四.信号量 1)信号量初始化 2)信号量值的加减 3)对信号量进行清理 背景问题:在特定的应用场景下,多线程不进行同步会造成什么问题? 通过多线程模拟多窗口售票为例: #include <iostream> #include<pthread.h> #include<stdio.h&

  • 超详细讲解Linux DHCP服务

    目录 一.DHCP服务(动态主机配置协议) 1.背景 2.概述 3.优点 4.DHCP报文类型 5.DHCP 的分配方式 二.安装 DHCP 服务器 1.DHCP 服务软件 2.主配置文件 三.配置步骤 1.使用 DHCP 动态的给 PC 机分配 IP 地址 ① eNSP ②虚拟机 ③验证 ④进入命令行"ipconfig"测试 一.DHCP服务(动态主机配置协议) 1.背景 1.手动设置工作量大且容易冲突 2.用DHCP可以减少工作量和避免地址冲突 2.概述 作用:为局域网内的电脑分配

  • 超详细讲解python正则表达式

    目录 正则表达式 1.1 正则表达式字符串 1.1.1 元字符 1.1.2 字符转义 1.1.3 开始与结束字符 1.2 字符类 1.2.1 定义字符类 1.2.2 字符串取反 1.2.3 区间 1.2.4 预定义字符类 1.3 量词 1.3.1 量词的使用 1.3.2 贪婪量词和懒惰量词 1.4 分组 1.4.1 分组的使用 1.4.2 分组命名 1.4.3 反向引用分组 1.4.4 非捕获分组 1.5 re模块 1.5.1 search()和match()函数 1.5.2 findall()

  • 超详细讲解Java线程池

    目录 池化技术 池化思想介绍 池化技术的应用 如何设计一个线程池 Java线程池解析 ThreadPoolExecutor使用介绍 内置线程池使用 ThreadPoolExecutor解析 整体设计 线程池生命周期 任务管理解析 woker对象 Java线程池实践建议 不建议使用Exectuors 线程池大小设置 线程池监控 带着问题阅读 1.什么是池化,池化能带来什么好处 2.如何设计一个资源池 3.Java的线程池如何使用,Java提供了哪些内置线程池 4.线程池使用有哪些注意事项 池化技术

  • 超详细讲解Java异常

    目录 一.Java异常架构与异常关键字 Java异常简介 Java异常架构 1.Throwable 2.Error(错误) 3.Exception(异常) 4.受检异常与非受检异常 Java异常关键字 二.Java异常处理 声明异常 抛出异常 捕获异常 如何选择异常类型 常见异常处理方式 1.直接抛出异常 2.封装异常再抛出 3.捕获异常 4.自定义异常 5.try-catch-finally 6.try-with-resource 三.Java异常常见面试题 1.Error 和 Excepti

  • Python 数据可视化超详细讲解折线图的实现

    绘制简单的折线图 在使用matplotlib绘制简单的折线图之前首先需要安装matplotlib,直接在pycharm终端pip install matplotlib即可 使用matplotlib绘制简单的折线图,再对其进行定制,实现数据的可视化操作 import matplotlib.pyplot as plt # 导入pyplot模块并设置别名为plt squares = [1, 4, 9, 16, 25] plt.plot(squares) plt.show() # 打开matplotib

  • C++ 数据结构超详细讲解单链表

    目录 前言 一.链表是什么 链表的分类 二.链表的实现 总结 (❁´◡`❁) 单链表 前言 上篇顺序表结尾了解了顺序表的诸多缺点,链表的特性很好的解决了这些问题,本期我们来认识单链表. 一.链表是什么 链表是一种物理存储结构上非连续.非顺序的存储结构,数据元素的逻辑顺序是通过链表中的指针链接依次实现的. 由图,链式结构在逻辑上是连续的,但是物理上不一定连续 显示中结点一般是从堆上申请出来的 从堆上申请的空间,是按照一定的策略划分的,两次申请的空间,可能连续,可能不连续,见顺序表 链表的分类 链表

  • C++ 数据结构超详细讲解顺序表

    目录 前言 一.顺序表是什么 概念及结构 二.顺序表的实现 顺序表的缺点 几道练手题 总结 (●’◡’●) 前言 线性表是n个具有相同特性的数据元素的有限序列.线性表是一种在实际中广泛使用的数据结构,常见的线性表:顺序表.链表.栈.队列.字符串. 线性表在逻辑上是线性结构,也就是说连续的一条直线,但是在物理结构并不一定是连续的,线性表在物理上存储时,通常以数组和链式结构的形式存储. 本章我们来深度初体验顺序表 一.顺序表是什么 概念及结构 顺序表是一段物理地址连续的存储单元依次存储数据元素的线性

随机推荐