PyTorch加载自己的数据集实例详解

数据预处理在解决深度学习问题的过程中,往往需要花费大量的时间和精力。 数据处理的质量对训练神经网络来说十分重要,良好的数据处理不仅会加速模型训练, 更会提高模型性能。为解决这一问题,PyTorch提供了几个高效便捷的工具, 以便使用者进行数据处理或增强等操作,同时可通过并行化加速数据加载。

数据集存放大致有以下两种方式:

(1)所有数据集放在一个目录下,文件名上附有标签名,数据集存放格式如下: root/cat_dog/cat.01.jpg

root/cat_dog/cat.02.jpg

........................

root/cat_dog/dog.01.jpg

root/cat_dog/dog.02.jpg

......................

(2)不同类别的数据集放在不同目录下,目录名就是标签,数据集存放格式如下:

root/ants/xxx.png

root/ants/xxy.jpeg

root/ants/xxz.png

................

root/bees/123.jpg

root/bees/nsdf3.png

root/bees/asd932_.png

..................

1.1 对第1种数据集的处理步骤

(1)生成包含各文件名的列表(List)

(2)定义Dataset的一个子类,该子类需要继承Dataset类,查看Dataset类的源码

(3)重写父类Dataset中的两个魔法方法: 一个是: __lent__(self),其功能是len(Dataset),返回Dataset的样本数。 另一个是__getitem__(self,index),其功能假设索引为i,使Dataset[i]返回第i个样本。

(4)使用torch.utils.data.DataLoader加载数据集Dataset.

1.2 实例详解

以下以cat-dog数据集为例,说明如何实现自定义数据集的加载。

1.2.1 数据集结构

所有数据集在cat-dog目录下:

.\cat_dog\cat.01.jpg

.\cat_dog\cat.02.jpg

.\cat_dog\cat.03.jpg

....................

.\cat_dog\dog.01.jpg

.\cat_dog\dog.02.jpg

....................

1.2.2 导入需要用到的模块

from torch.utils.data import DataLoader,Dataset
from skimage import io,transform
import matplotlib.pyplot as plt
import oimport torch
from torchvision import transforms, utils
from PIL import Image
import pandas as pd
import numpy as np
#过滤警告信息
import warnings
warnings.filterwarnings("ignore")

1.2.3定义加载自定义数据的类

class MyDataset(Dataset): #继承Dataset
 def __init__(self, path_dir, transform=None): #初始化一些属性
  self.path_dir = path_dir #文件路径,如'.\data\cat-dog'
  self.transform = transform #对图形进行处理,如标准化、截取、转换等
  self.images = os.listdir(self.path_dir)#把路径下的所有文件放在一个列表中

 def __len__(self):#返回整个数据集的大小
  return len(self.images)

 def __getitem__(self,index):#根据索引index返回图像及标签
  image_index = self.images[index]#根据索引获取图像文件名称
  img_path = os.path.join(self.path_dir, image_index)#获取图像的路径或目录
  img = Image.open(img_path).convert('RGB')# 读取图像

  # 根据目录名称获取图像标签(cat或dog)
  label = img_path.split('\\')[-1].split('.')[0]
  #把字符转换为数字cat-0,dog-1
  label = 1 if 'dog' in label else 0

  if self.transform is not None:
   img = self.transform(img)
  return img,label

1.2.4 实例化类

dataset = MyDataset('.\data\cat-dog',transform=None)
img, label = dataset[0] #将启动魔法方法__getitem__(0)
print(type(img))
<class 'PIL.Image.Image'>

1.2.5 查看图像形状

i=1
for img, label in dataset:
    if i
img的形状(500, 374),label的值0

img的形状(300, 280),label的值0

img的形状(489, 499),label的值0

img的形状(431, 410),label的值0

img的形状(300, 224),label的值0

从上面返回样本的形状来看:

(1)每张图片的大小不一样,如果需要取batch训练的神经网络来说很不友好。

(2)返回样本的数值较大,未归一化至[-1, 1]

为此需要对img进行转换,如何转换?只要使用torchvision中的transforms即可

1.2.6 对图像数据进行处理

这里使用torchvision中的transforms模块

from torchvision import transforms as T
transform = T.Compose([
 T.Resize(224), # 缩放图片(Image),保持长宽比不变,最短边为224像素
 T.CenterCrop(224), # 从图片中间切出224*224的图片
 T.ToTensor(), # 将图片(Image)转成Tensor,归一化至[0, 1]
 T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) # 标准化至[-1, 1],规定均值和标准差
])

1.2.7查看处理后的数据

dataset = MyDataset('.\data\cat-dog',transform=transform)
for img, label in dataset:
 print("图像img的形状{},标签label的值{}".format(img.shape, label))
 print("图像数据预处理后:\n",img)
 break

图像img的形状torch.Size([3, 224, 224]),标签label的值0

图像数据预处理后:

tensor([[[ 0.9059, 0.9137, 0.9137, ..., 0.9451, 0.9451, 0.9451],

[ 0.9059, 0.9137, 0.9137, ..., 0.9451, 0.9451, 0.9451],

[ 0.9059, 0.9137, 0.9137, ..., 0.9529, 0.9529, 0.9529],

...,

[-0.4824, -0.5294, -0.5373, ..., -0.9216, -0.9294, -0.9451],

[-0.4980, -0.5529, -0.5608, ..., -0.9294, -0.9373, -0.9529],

[-0.4980, -0.5529, -0.5686, ..., -0.9529, -0.9608, -0.9608]],

[[ 0.5686, 0.5765, 0.5765, ..., 0.7961, 0.7882, 0.7882],

[ 0.5686, 0.5765, 0.5765, ..., 0.7961, 0.7882, 0.7882],

[ 0.5686, 0.5765, 0.5765, ..., 0.8039, 0.7961, 0.7961],

...,

[-0.6078, -0.6471, -0.6549, ..., -0.9137, -0.9216, -0.9373],

[-0.6157, -0.6706, -0.6784, ..., -0.9216, -0.9294, -0.9451],

[-0.6157, -0.6706, -0.6863, ..., -0.9451, -0.9529, -0.9529]],

[[-0.0510, -0.0431, -0.0431, ..., 0.2078, 0.2157, 0.2157],

[-0.0510, -0.0431, -0.0431, ..., 0.2078, 0.2157, 0.2157],

[-0.0510, -0.0431, -0.0431, ..., 0.2157, 0.2235, 0.2235],

...,

[-0.9529, -0.9843, -0.9922, ..., -0.9529, -0.9608, -0.9765],

[-0.9686, -0.9922, -1.0000, ..., -0.9608, -0.9686, -0.9843],

[-0.9686, -0.9922, -1.0000, ..., -0.9843, -0.9922, -0.9922]]])

由此可知,数据已标准化、规范化。

1.2.8对数据集进行批量加载

使用DataLoader模块,对数据集dataset进行批量加载

#使用DataLoader加载数据
dataloader = DataLoader(dataset,batch_size=4,shuffle=True)
for batch_datas, batch_labels in dataloader:
 print(batch_datas.size(),batch_labels.size())
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([2, 3, 224, 224]) torch.Size([2])

1.2.9随机查看一个批次的图像

import torchvision
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
# 显示图像
def imshow(img):
 img = img / 2 + 0.5  # unnormalize
 npimg = img.numpy()
 plt.imshow(np.transpose(npimg, (1, 2, 0)))
 plt.show()
# 随机获取部分训练数据
dataiter = iter(dataloader)
images, labels = dataiter.next()
# 显示图像
imshow(torchvision.utils.make_grid(images))
# 打印标签
print(' '.join('%s' % ["小狗" if labels[j].item()==1 else "小猫" for j in range(4)]))

2 对第2种数据集的处理

处理这种情况比较简单,可分为2步:

(1)使用datasets.ImageFolder读取、处理图像。

(2)使用.data.DataLoader批量加载数据集,示例如下:

import torch
from torchvision import transforms, datasets
data_transform = transforms.Compose([
  transforms.RandomSizedCrop(224),
  transforms.RandomHorizontalFlip(),
  transforms.ToTensor(),
  transforms.Normalize(mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])
 ])
hymenoptera_dataset = datasets.ImageFolder(root='.\catdog\train',
           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset, 

总结

到此这篇关于PyTorch加载自己的数据集实例详解的文章就介绍到这了,更多相关PyTorch加载 数据集内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • pytorch 数据集图片显示方法

    图片显示 pytorch 载入的数据集是元组tuple 形式,里面包括了数据及标签(train_data,label),其中的train_data数据可以转换为torch.Tensor形式,方便后面计算使用. 同样给一些刚入门的同学在使用载入的数据显示图片的时候带来一些难以理解的地方,这里主要是将Tensor与numpy转换的过程,理解了这些就可以就行转换了 CIAFA10数据集 首先载入数据集,这里做了一些数据处理,包括图片尺寸.数据归一化等 import torch from torch.a

  • pytorch 自定义数据集加载方法

    pytorch 官网给出的例子中都是使用了已经定义好的特殊数据集接口来加载数据,而且其使用的数据都是官方给出的数据.如果我们有自己收集的数据集,如何用来训练网络呢?此时需要我们自己定义好数据处理接口.幸运的是pytroch给出了一个数据集接口类(torch.utils.data.Dataset),可以方便我们继承并实现自己的数据集接口. torch.utils.data torch的这个文件包含了一些关于数据集处理的类. class torch.utils.data.Dataset: 一个抽象类

  • pytorch 把MNIST数据集转换成图片和txt的方法

    本文介绍了pytorch 把MNIST数据集转换成图片和txt的方法,分享给大家,具体如下: 1.下载Mnist 数据集 import os # third-party library import torch import torch.nn as nn from torch.autograd import Variable import torch.utils.data as Data import torchvision import matplotlib.pyplot as plt # t

  • pytorch + visdom CNN处理自建图片数据集的方法

    环境 系统:win10 cpu:i7-6700HQ gpu:gtx965m python : 3.6 pytorch :0.3 数据下载 来源自Sasank Chilamkurthy 的教程: 数据:下载链接. 下载后解压放到项目根目录: 数据集为用来分类 蚂蚁和蜜蜂.有大约120个训练图像,每个类有75个验证图像. 数据导入 可以使用 torchvision.datasets.ImageFolder(root,transforms) 模块 可以将 图片转换为 tensor. 先定义transf

  • PyTorch读取Cifar数据集并显示图片的实例讲解

    首先了解一下需要的几个类所在的package from torchvision import transforms, datasets as ds from torch.utils.data import DataLoader import matplotlib.pyplot as plt import numpy as np #transform = transforms.Compose是把一系列图片操作组合起来,比如减去像素均值等. #DataLoader读入的数据类型是PIL.Image

  • pytorch中如何使用DataLoader对数据集进行批处理的方法

    最近搞了搞minist手写数据集的神经网络搭建,一个数据集里面很多个数据,不能一次喂入,所以需要分成一小块一小块喂入搭建好的网络. pytorch中有很方便的dataloader函数来方便我们进行批处理,做了简单的例子,过程很简单,就像把大象装进冰箱里一共需要几步? 第一步:打开冰箱门. 我们要创建torch能够识别的数据集类型(pytorch中也有很多现成的数据集类型,以后再说). 首先我们建立两个向量X和Y,一个作为输入的数据,一个作为正确的结果: 随后我们需要把X和Y组成一个完整的数据集,

  • 详解PyTorch手写数字识别(MNIST数据集)

    MNIST 手写数字识别是一个比较简单的入门项目,相当于深度学习中的 Hello World,可以让我们快速了解构建神经网络的大致过程.虽然网上的案例比较多,但还是要自己实现一遍.代码采用 PyTorch 1.0 编写并运行. 导入相关库 import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, t

  • PyTorch加载自己的数据集实例详解

    数据预处理在解决深度学习问题的过程中,往往需要花费大量的时间和精力. 数据处理的质量对训练神经网络来说十分重要,良好的数据处理不仅会加速模型训练, 更会提高模型性能.为解决这一问题,PyTorch提供了几个高效便捷的工具, 以便使用者进行数据处理或增强等操作,同时可通过并行化加速数据加载. 数据集存放大致有以下两种方式: (1)所有数据集放在一个目录下,文件名上附有标签名,数据集存放格式如下: root/cat_dog/cat.01.jpg root/cat_dog/cat.02.jpg ...

  • 微信小程序JS加载esmap地图的实例详解

    一.在微信小程序里显示室内三维地图 需要满足的两个条件 调用ESMap室内地图需要用到小程序web-view组件,想要通过 web-view 调用ESMap室内地图需要满足以下 2 个条件: 1. 小程序是企业主体,微信 web-view 组件不对个人类型的小程序开放. 2. 您需要有一个自己的域名,在嵌入网页的时候需要在微信后台验证域名(只有自己域名下的网页才能被正确地显示哦,不能随便找一个公开链接). 二.具体实现步骤 1.域名验证: 由于微信平台的规定,web-view 指向的地址,必须是

  • pytorch加载自己的数据集源码分享

    目录 一.标准的数据集流程梳理 数据来源 二.实现加载自己的数据集 1. 保存在txt文件中(生成训练集和测试集,其实这里的训练集以及测试集也都是用文本文件的形式保存下来的) 2. 在继承dataset类LoadData的三个函数里调用train.txt以及test.txt实现相关功能 三.源码 一.标准的数据集流程梳理 分为几个步骤数据准备以及加载数据库–>数据加载器的调用或者设计–>批量调用进行训练或者其他作用 数据来源 直接读取了x和y的数据变量,对比后面的就从把对应的路径写进了文本文件

  • SQLserver中cube:多维数据集实例详解

    1.cube:生成多维数据集,包含各维度可能组合的交叉表格,使用with 关键字连接 with cube 根据需要使用union all 拼接 判断 某一列的null值来自源数据还是 cube 使用GROUPING关键字 GROUPING([档案号]) = 1 : null值来自cube(代表所有的档案号) GROUPING([档案号]) = 0 : null值来自源数据 举例: SELECT * INTO ##GET FROM (SELECT * FROM ( SELECT CASE WHEN

  • Base64加解密的实现方式实例详解

    Base64加解密的实现方式实例详解 本实现方式基于JDK 1.8 实现: import java.util.Base64; import java.util.Base64.Decoder; import java.util.Base64.Encoder; public class Main { static String src = "hello,sahadev"; public static void main(String[] args) { // 获取加密对象 Encoder

  • Three.js加载外部模型的教程详解

    1.  首先我们要在官网: https://threejs.org/ 下载我们three.js压缩包,并将其中的build文件夹下的three.js通过script标签对的src属性导入到我们的页面中 2.  创建three.js核心对象 Scene(场景) Camera(相机) Light(光源) Mesh(模型) Renderer(渲染器) 最后一步就是渲染显示在我们的页面上了renderer.render(scene,camera) 3.  OBJ模型的导入 <script type=&quo

  • JS 加载性能Tree Shaking优化详解

    目录 正文 什么是 Tree Shaking 寻找 Tree Shaking 的机会 防止 Babel 将 ES6 模块转换为 CommonJS 模块 留意 side effects 只导入你需要的 更复杂的情况 总结 正文 随着 web 应用复杂性增加,JS 代码文件的大小也在不断的攀升,截住 2021年9月,在 httparchive 上有统计显示——在移动设备上 JS 传输大小大约为 447 KB,桌面端 JS 传输大小大约为 495 KB,注意这仅仅是在网络中传输的 JS 文件大小,JS

  • vue+webpack实现异步加载三种用法示例详解

    1.第一例 const Home = resolve => { import("@/components/home/home.vue").then( module => { resolve(module) } } 注:(上面import的时候可以不写后缀) export default [{ path: '/home', name:'home', component: Home, meta: { requireAuth: true, // 添加该属性可以判断出该页面是否需要

  • 如何利用预加载优化Laravel Model查询详解

    前言 本文主要给大家介绍了关于利用预加载优化Laravel Model查询的相关内容,分享出来供大家参考学习,话不多说了,来一起看看详细的介绍: 介绍 对象关系映射(ORM)使数据库的工作变得非常简单. 在以面向对象的方式定义数据库关系时,可以轻松查询相关的模型数据,开发人员可能不会注意底层数据库调用. 下面将通过一些例子,进一步帮助您了解如何优化查询. 假设您从数据库收到了100个对象,并且每个记录都有1个关联模型(即belongsTo). 默认使用ORM将产生101个查询; 如下所示: //

  • JavaScript前端图片加载管理器imagepool使用详解

    前言 imagepool是一款管理图片加载的JS工具,通过imagepool可以控制图片并发加载个数. 对于图片加载,最原始的方式就是直接写个img标签,比如:<img src="图片url" />. 经过不断优化,出现了图片延迟加载方案,这回图片的URL不直接写在src属性中,而是写在某个属性中,比如:<img src="" data-src="图片url" />.这样浏览器就不会自动加载图片,等到一个恰当的时机需要加载

随机推荐