Pytorch自定义Dataset和DataLoader去除不存在和空数据的操作

【源码GitHub地址】:点击进入

1. 问题描述

之前写了一篇关于《pytorch Dataset, DataLoader产生自定义的训练数据》的博客,但存在一个问题,我们不能在Dataset做一些数据清理,如果我们传递给Dataset数据,本身存在问题,那么迭代过程肯定出错的。

比如我把很多图片路径都传递给Dataset,如果图片路径都是正确的,且图片都存在也没有损坏,那显然运行是没有问题的;

但倘若传递给Dataset的图片路径有些图片是不存在,这时你通过Dataset读取图片数据,然后再迭代返回,就会出现类似如下的错误:

File "D:\ProgramData\Anaconda3\envs\pytorch-py36\lib\site-packages\torch\utils\data\_utils\collate.py", line 68, in <listcomp> return [default_collate(samples) for samples in transposed]

File "D:\ProgramData\Anaconda3\envs\pytorch-py36\lib\site-packages\torch\utils\data\_utils\collate.py", line 70, in default_collate

raise TypeError((error_msg_fmt.format(type(batch[0])))) TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'NoneType'>

2. 一般的解决方法

一般的解决方法也很简单粗暴,就是在传递数据给Dataset前,就做数据清理,把不存在的图片,损坏的数据都提前清理掉。

是的,这个是最简单粗暴的。

3. 另一种解决方法:自定义返回数据的规则:collate_fn()校对函数

我们希望不管传递什么处理给Dataset,Dataset都进行处理,如果不存在或者异常,就返回None,而在DataLoader时,对于不存为None的数据,都去除掉。

这样就保证在迭代过程中,DataLoader获得batch数据都是正确的。

比如读取batch_size=5的图片数据,如果其中有1个(或者多个)图片是不存在,那么返回的batch应该把不存在的数据过滤掉,即返回5-1=4大小的batch的数据。

是的,我要实现的就是这个功能:返回的batch数据会自定清理掉不合法的数据。

3.1 Pytorch数据处理函数:Dataset和 DataLoader

Pytorch有两个数据处理函数:Dataset和 DataLoader

from torch.utils.data import Dataset, DataLoader

其中Dataset用于定义数据的读取和预处理操作,而DataLoader用于加载并产生批训练数据。

torch.utils.data.DataLoader参数说明:

DataLoader(object)可用参数:

1、dataset(Dataset) 传入的数据集

2、batch_size(int, optional) 每个batch有多少个样本

3、shuffle(bool, optional) 在每个epoch开始的时候,对数据进行重新排序

4、sampler(Sampler, optional) 自定义从数据集中取样本的策略,如果指定这个参数,那么shuffle必须为False

5、batch_sampler(Sampler, optional) 与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥——Mutually exclusive)

6、num_workers (int, optional) 这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0)

7、collate_fn (callable, optional) 将一个list的sample组成一个mini-batch的函数

8、pin_memory (bool, optional) 如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中.

9、drop_last (bool, optional) 如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了。 如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。

10、timeout(numeric, optional) 如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了。这个numeric应总是大于等于0。默认为0

11、worker_init_fn (callable, optional) 每个worker初始化函数 If not None, this will be called on eachworker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)

我们要用到的是collate_fn()回调函数

3.2 自定义collate_fn()函数:

torch.utils.data.DataLoader的collate_fn()用于设置batch数据拼接方式,默认是default_collate函数,但当batch中含有None等数据时,默认的default_collate校队方法会出现错误。因此,我们需要自定义collate_fn()函数:

方法也很简单:只需在原来的default_collate函数中添加下面几句代码:判断image是否为None,如果为None,则在原来的batch中清除掉,这样就可以在迭代中避免出错了。

 # 这里添加:判断image是否为None,如果为None,则在原来的batch中清除掉,这样就可以在迭代中避免出错了
 if isinstance(batch, list):
 batch = [(image, image_id) for (image, image_id) in batch if image is not None]
 if batch==[]:
 return (None,None)

dataset_collate.py:

# -*-coding: utf-8 -*-
"""
 @Project: pytorch-learning-tutorials
 @File : dataset_collate.py
 @Author : panjq
 @E-mail : pan_jinquan@163.com
 @Date : 2019-06-07 17:09:13
"""

r""""Contains definitions of the methods used by the _DataLoaderIter workers to
collate samples fetched from dataset into Tensor(s).
These **needs** to be in global scope since Py2 doesn't support serializing
static methods.
"""
import torch
import re
from torch._six import container_abcs, string_classes, int_classes
_use_shared_memory = False
r"""Whether to use shared memory in default_collate"""

np_str_obj_array_pattern = re.compile(r'[SaUO]')

error_msg_fmt = "batch must contain tensors, numbers, dicts or lists; found {}"

numpy_type_map = {
 'float64': torch.DoubleTensor,
 'float32': torch.FloatTensor,
 'float16': torch.HalfTensor,
 'int64': torch.LongTensor,
 'int32': torch.IntTensor,
 'int16': torch.ShortTensor,
 'int8': torch.CharTensor,
 'uint8': torch.ByteTensor,
}

def collate_fn(batch):
 '''
 collate_fn (callable, optional): merges a list of samples to form a mini-batch.
 该函数参考touch的default_collate函数,也是DataLoader的默认的校对方法,当batch中含有None等数据时,
 默认的default_collate校队方法会出现错误
 一种的解决方法是:
 判断batch中image是否为None,如果为None,则在原来的batch中清除掉,这样就可以在迭代中避免出错了
 :param batch:
 :return:
 '''
 r"""Puts each data field into a tensor with outer dimension batch size"""
 # 这里添加:判断image是否为None,如果为None,则在原来的batch中清除掉,这样就可以在迭代中避免出错了
 if isinstance(batch, list):
 batch = [(image, image_id) for (image, image_id) in batch if image is not None]
 if batch==[]:
 return (None,None)

 elem_type = type(batch[0])
 if isinstance(batch[0], torch.Tensor):
 out = None
 if _use_shared_memory:
  # If we're in a background process, concatenate directly into a
  # shared memory tensor to avoid an extra copy
  numel = sum([x.numel() for x in batch])
  storage = batch[0].storage()._new_shared(numel)
  out = batch[0].new(storage)
 return torch.stack(batch, 0, out=out)
 elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
  and elem_type.__name__ != 'string_':
 elem = batch[0]
 if elem_type.__name__ == 'ndarray':
  # array of string classes and object
  if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
  raise TypeError(error_msg_fmt.format(elem.dtype))

  return collate_fn([torch.from_numpy(b) for b in batch])
 if elem.shape == (): # scalars
  py_type = float if elem.dtype.name.startswith('float') else int
  return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
 elif isinstance(batch[0], float):
 return torch.tensor(batch, dtype=torch.float64)
 elif isinstance(batch[0], int_classes):
 return torch.tensor(batch)
 elif isinstance(batch[0], string_classes):
 return batch
 elif isinstance(batch[0], container_abcs.Mapping):
 return {key: collate_fn([d[key] for d in batch]) for key in batch[0]}
 elif isinstance(batch[0], tuple) and hasattr(batch[0], '_fields'): # namedtuple
 return type(batch[0])(*(collate_fn(samples) for samples in zip(*batch)))
 elif isinstance(batch[0], container_abcs.Sequence):
 transposed = zip(*batch)#ok
 return [collate_fn(samples) for samples in transposed]

 raise TypeError((error_msg_fmt.format(type(batch[0]))))

测试方法:

# -*-coding: utf-8 -*-
"""
 @Project: pytorch-learning-tutorials
 @File : dataset.py
 @Author : panjq
 @E-mail : pan_jinquan@163.com
 @Date : 2019-03-07 18:45:06
"""
import torch
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
from utils import dataset_collate
import os
import cv2
from PIL import Image
def read_image(path,mode='RGB'):
 '''
 :param path:
 :param mode: RGB or L
 :return:
 '''
 return Image.open(path).convert(mode)

class TorchDataset(Dataset):
 def __init__(self, image_id_list, image_dir, resize_height=256, resize_width=256, repeat=1, transform=None):
 '''
 :param filename: 数据文件TXT:格式:imge_name.jpg label1_id labe2_id
 :param image_dir: 图片路径:image_dir+imge_name.jpg构成图片的完整路径
 :param resize_height 为None时,不进行缩放
 :param resize_width 为None时,不进行缩放,
    PS:当参数resize_height或resize_width其中一个为None时,可实现等比例缩放
 :param repeat: 所有样本数据重复次数,默认循环一次,当repeat为None时,表示无限循环<sys.maxsize
 :param transform:预处理
 '''
 self.image_dir = image_dir
 self.image_id_list=image_id_list
 self.len = len(image_id_list)
 self.repeat = repeat
 self.resize_height = resize_height
 self.resize_width = resize_width
 self.transform= transform

 def __getitem__(self, i):
 index = i % self.len
 # print("i={},index={}".format(i, index))
 image_id = self.image_id_list[index]
 image_path = os.path.join(self.image_dir, image_id)
 img = self.load_data(image_path)

 if img is None:
  return None,image_id
 img = self.data_preproccess(img)
 return img,image_id

 def __len__(self):
 if self.repeat == None:
  data_len = 10000000
 else:
  data_len = len(self.image_id_list) * self.repeat
 return data_len

 def load_data(self, path):
 '''
 加载数据
 :param path:
 :param resize_height:
 :param resize_width:
 :param normalization: 是否归一化
 :return:
 '''
 try:
  image = read_image(path)
 except Exception as e:
  image=None
  print(e)
 # image = image_processing.read_image(path)#用opencv读取图像
 return image

 def data_preproccess(self, data):
 '''
 数据预处理
 :param data:
 :return:
 '''
 if self.transform is not None:
  data = self.transform(data)
 return data

if __name__=='__main__':

 resize_height = 224
 resize_width = 224
 image_id_list=["1.jpg","ddd.jpg","111.jpg","3.jpg","4.jpg","5.jpg","6.jpg","7.jpg","8.jpg","9.jpg"]
 image_dir="../dataset/test_images/images"
 # 相关预处理的初始化
 '''class torchvision.transforms.ToTensor把shape=(H,W,C)的像素值范围为[0, 255]的PIL.Image或者numpy.ndarray数据
 # 转换成shape=(C,H,W)的像素数据,并且被归一化到[0.0, 1.0]的torch.FloatTensor类型。
 '''
 train_transform = transforms.Compose([
 transforms.Resize(size=(resize_height, resize_width)),
 # transforms.RandomHorizontalFlip(),#随机翻转图像
 transforms.RandomCrop(size=(resize_height, resize_width), padding=4), # 随机裁剪
 transforms.ToTensor(), # 吧shape=(H,W,C)->换成shape=(C,H,W),并且归一化到[0.0, 1.0]的torch.FloatTensor类型
 # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))#给定均值(R,G,B) 方差(R,G,B),将会把Tensor正则化
 ])

 epoch_num=2 #总样本循环次数
 batch_size=5 #训练时的一组数据的大小
 train_data_nums=10
 max_iterate=int((train_data_nums+batch_size-1)/batch_size*epoch_num) #总迭代次数

 train_data = TorchDataset(image_id_list=image_id_list,
    image_dir=image_dir,
    resize_height=resize_height,
    resize_width=resize_width,
    repeat=1,
    transform=train_transform)
 # 使用默认的default_collate会报错
 # train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False)
 # 使用自定义的collate_fn
 train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False, collate_fn=dataset_collate.collate_fn)

 # [1]使用epoch方法迭代,TorchDataset的参数repeat=1
 for epoch in range(epoch_num):
 for step,(batch_image, batch_label) in enumerate(train_loader):
  if batch_image is None and batch_label is None:
  print("batch_image:{},batch_label:{}".format(batch_image, batch_label))
  continue
  image=batch_image[0,:]
  image=image.numpy()#image=np.array(image)
  image = image.transpose(1, 2, 0) # 通道由[c,h,w]->[h,w,c]
  cv2.imshow("image",image)
  cv2.waitKey(2000)
  print("batch_image.shape:{},batch_label:{}".format(batch_image.shape,batch_label))
  # batch_x, batch_y = Variable(batch_x), Variable(batch_y)

输出结果说明:

batch_size=5,输入图片列表image_id_list=["1.jpg","ddd.jpg","111.jpg","3.jpg","4.jpg","5.jpg","6.jpg","7.jpg","8.jpg","9.jpg"] ,其中"ddd.jpg","111.jpg"是不存在的,resize_width=224,正常情况下返回的数据应该是torch.Size([5, 3, 224, 224]),但由于"ddd.jpg","111.jpg"不存在,被过滤掉了,所以第一个batch的维度变为torch.Size([3, 3, 224, 224])

[Errno 2] No such file or directory: '../dataset/test_images/images\\ddd.jpg'

[Errno 2] No such file or directory: '../dataset/test_images/images\\111.jpg'

batch_image.shape:torch.Size([3, 3, 224, 224]),batch_label:('1.jpg', '3.jpg', '4.jpg')

batch_image.shape:torch.Size([5, 3, 224, 224]),batch_label:('5.jpg', '6.jpg', '7.jpg', '8.jpg', '9.jpg')

[Errno 2] No such file or directory: '../dataset/test_images/images\\ddd.jpg'

[Errno 2] No such file or directory: '../dataset/test_images/images\\111.jpg'

batch_image.shape:torch.Size([3, 3, 224, 224]),batch_label:('1.jpg', '3.jpg', '4.jpg')

batch_image.shape:torch.Size([5, 3, 224, 224]),batch_label:('5.jpg', '6.jpg', '7.jpg', '8.jpg', '9.jpg')

以上为个人经验,希望能给大家一个参考,也希望大家多多支持我们。如有错误或未考虑完全的地方,望不吝赐教。

(0)

相关推荐

  • Pytorch数据读取之Dataset和DataLoader知识总结

    一.前言 确保安装 scikit-image numpy 二.Dataset 一个例子: # 导入需要的包 import torch import torch.utils.data.dataset as Dataset import numpy as np # 编造数据 Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]]) Label = np.asarray([[0], [1], [0], [2]]) # 数据[1,2],对应的标签是[0],数据

  • pytorch Dataset,DataLoader产生自定义的训练数据案例

    1. torch.utils.data.Dataset datasets这是一个pytorch定义的dataset的源码集合.下面是一个自定义Datasets的基本框架,初始化放在__init__()中,其中__getitem__()和__len__()两个方法是必须重写的. __getitem__()返回训练数据,如图片和label,而__len__()返回数据长度. class CustomDataset(data.Dataset):#需要继承data.Dataset def __init_

  • 解决pytorch load huge dataset(大数据加载)

    问题 最近用pytorch做实验时,遇到加载大量数据的问题.实验数据大小在400Gb,而本身机器的memory只有256Gb,显然无法将数据一次全部load到memory. 解决方法 首先自定义一个MyDataset继承torch.utils.data.Dataset,然后将MyDataset的对象feed in torch.utils.data.DataLoader()即可. MyDataset在__init__中声明一个文件对象,然后在__getitem__中缓慢读取数据,这样就不会一次把所

  • pytorch中的dataset用法详解

    目录 1.torch.utils.data 里面的dataset使用方法 2.torchvision.datasets的使用方法 用法1:使用官方数据集 用法2:ImageFolder通用的自己数据集加载器 1.torch.utils.data 里面的dataset使用方法 当我们继承了一个 Dataset类之后,我们需要重写 len 方法,该方法提供了dataset的大小: getitem 方法, 该方法支持从 0 到 len(self)的索引 from torch.utils.data im

  • PyTorch实现重写/改写Dataset并载入Dataloader

    前言 众所周知,Dataset和Dataloder是pytorch中进行数据载入的部件.必须将数据载入后,再进行深度学习模型的训练.在pytorch的一些案例教学中,常使用torchvision.datasets自带的MNIST.CIFAR-10数据集,一般流程为: # 下载并存放数据集 train_dataset = torchvision.datasets.CIFAR10(root="数据集存放位置",download=True) # load数据 train_loader = t

  • Pytorch关于Dataset 的数据处理

    目录 Pytorch系列是了解与使用Pytorch编程来实现卷积神经网络. 学习如何对卷积神经网络编程:首先,需要了解Pytorch对数据的使用(也是在我们模型流程中对数据的预处理部分),其中有两个包Dataset,DataLoader.Dataset是Pytorch对于单个数据的处理类似于给一堆数据进行编号,(在有标签的图像处理中)对其有序地提取图像与标签, 而DataLoader则是一坨一坨的数据进行批次的处理. 此实验运用的数据是北邮邓伟洪老师的人脸表情包的数据集, 当然大家也可以自己手动

  • Pytorch自定义Dataset和DataLoader去除不存在和空数据的操作

    [源码GitHub地址]:点击进入 1. 问题描述 之前写了一篇关于<pytorch Dataset, DataLoader产生自定义的训练数据>的博客,但存在一个问题,我们不能在Dataset做一些数据清理,如果我们传递给Dataset数据,本身存在问题,那么迭代过程肯定出错的. 比如我把很多图片路径都传递给Dataset,如果图片路径都是正确的,且图片都存在也没有损坏,那显然运行是没有问题的: 但倘若传递给Dataset的图片路径有些图片是不存在,这时你通过Dataset读取图片数据,然后

  • PyTorch 解决Dataset和Dataloader遇到的问题

    今天在使用PyTorch中Dataset遇到了一个问题.先看代码 class psDataset(Dataset): def __init__(self, x, y, transforms = None): super(Dataset, self).__init__() self.x = x self.y = y if transforms == None: self.transforms = Compose([Resize((224, 224)), ToTensor()]) else: sel

  • PyTorch Dataset与DataLoader使用超详细讲解

    目录 一.Dataset 1. 在控制台进行操作 ①获取图片的基本信息 ②获取文件的基本信息 2. 编写一个继承Dataset 的类加载数据 ①定义 MyData类 ②创建类的实例并调用 二.DataLoader 一.Dataset Dataset 类提供一种方式去获取数据及其标签 主要有两个目的: 获取每一个数据及其标签 获取数据的总量大小 1. 在控制台进行操作 Hymenoptera (膜翅目昆虫)数据集下载地址: 链接: https://pan.baidu.com/s/1XKwXsAtE

  • pytorch 自定义数据集加载方法

    pytorch 官网给出的例子中都是使用了已经定义好的特殊数据集接口来加载数据,而且其使用的数据都是官方给出的数据.如果我们有自己收集的数据集,如何用来训练网络呢?此时需要我们自己定义好数据处理接口.幸运的是pytroch给出了一个数据集接口类(torch.utils.data.Dataset),可以方便我们继承并实现自己的数据集接口. torch.utils.data torch的这个文件包含了一些关于数据集处理的类. class torch.utils.data.Dataset: 一个抽象类

  • pytorch 实现多个Dataloader同时训练

    看代码吧~ 如果两个dataloader的长度不一样,那就加个: from itertools import cycle 仅使用zip,迭代器将在长度等于最小数据集的长度时耗尽. 但是,使用cycle时,我们将再次重复最小的数据集,除非迭代器查看最大数据集中的所有样本. 补充:pytorch技巧:自定义数据集 torch.utils.data.DataLoader 及Dataset的使用 本博客中有可直接运行的例子,便于直观的理解,在torch环境中运行即可. 1. 数据传递机制 在 pytor

  • Pytorch使用技巧之Dataloader中的collate_fn参数详析

    以MNIST为例 from torchvision import datasets mnist = datasets.MNIST(root='./data/', train=True, download=True) print(mnist[0]) 结果 (<PIL.Image.Image image mode=L size=28x28 at 0x196E3F1D898>, 5) MINIST数据集的dataset是由一张图片和一个label组成的元组 dataloader = torch.ut

  • python机器学习pytorch自定义数据加载器

    目录 正文 1. 加载数据集 2. 迭代和可视化数据集 3.创建自定义数据集 3.1 __init__ 3.2 __len__ 3.3 __getitem__ 4. 使用 DataLoaders 为训练准备数据 5.遍历 DataLoader 正文 处理数据样本的代码可能会逐渐变得混乱且难以维护:理想情况下,我们希望我们的数据集代码与我们的模型训练代码分离,以获得更好的可读性和模块化.PyTorch 提供了两个数据原语:torch.utils.data.DataLoader和torch.util

  • Pytorch自定义CNN网络实现猫狗分类详解过程

    目录 前言 一. 数据预处理 二. 定义网络 三. 训练模型 前言 数据集下载地址: 链接: https://pan.baidu.com/s/17aglKyKFvMvcug0xrOqJdQ?pwd=6i7m Dogs vs. Cats(猫狗大战)来源Kaggle上的一个竞赛题,任务为给定一个数据集,设计一种算法中的猫狗图片进行判别. 数据集包括25000张带标签的训练集图片,猫和狗各125000张,标签都是以cat or dog命名的.图像为RGB格式jpg图片,size不一样.截图如下: 一.

  • pytorch自定义初始化权重的方法

    在常见的pytorch代码中,我们见到的初始化方式都是调用init类对每层所有参数进行初始化.但是,有时我们有些特殊需求,比如用某一层的权重取优化其它层,或者手动指定某些权重的初始值. 核心思想就是构造和该层权重同一尺寸的矩阵去对该层权重赋值.但是,值得注意的是,pytorch中各层权重的数据类型是nn.Parameter,而不是Tensor或者Variable. import torch import torch.nn as nn import torch.optim as optim imp

随机推荐