一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系

以下内容都是针对Pytorch 1.0-1.1介绍。

很多文章都是从Dataset等对象自下往上进行介绍,但是对于初学者而言,其实这并不好理解,因为有的时候会不自觉地陷入到一些细枝末节中去,而不能把握重点,所以本文将会自上而下地对Pytorch数据读取方法进行介绍。

自上而下理解三者关系

首先我们看一下DataLoader.next的源代码长什么样,为方便理解我只选取了num_works为0的情况(num_works简单理解就是能够并行化地读取数据)。

class DataLoader(object):
	...

 def __next__(self):
  if self.num_workers == 0:
   indices = next(self.sample_iter) # Sampler
   batch = self.collate_fn([self.dataset[i] for i in indices]) # Dataset
   if self.pin_memory:
    batch = _utils.pin_memory.pin_memory_batch(batch)
   return batch

在阅读上面代码前,我们可以假设我们的数据是一组图像,每一张图像对应一个index,那么如果我们要读取数据就只需要对应的index即可,即上面代码中的indices,而选取index的方式有多种,有按顺序的,也有乱序的,所以这个工作需要Sampler完成,现在你不需要具体的细节,后面会介绍,你只需要知道DataLoader和Sampler在这里产生关系。

那么Dataset和DataLoader在什么时候产生关系呢?没错就是下面一行。我们已经拿到了indices,那么下一步我们只需要根据index对数据进行读取即可了。

再下面的if语句的作用简单理解就是,如果pin_memory=True,那么Pytorch会采取一系列操作把数据拷贝到GPU,总之就是为了加速。

综上可以知道DataLoader,Sampler和Dataset三者关系如下:

在阅读后文的过程中,你始终需要将上面的关系记在心里,这样能帮助你更好地理解。

Sampler

参数传递

要更加细致地理解Sampler原理,我们需要先阅读一下DataLoader 的源代码,如下:

class DataLoader(object):
 def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
     batch_sampler=None, num_workers=0, collate_fn=default_collate,
     pin_memory=False, drop_last=False, timeout=0,
     worker_init_fn=None)

可以看到初始化参数里有两种sampler:samplerbatch_sampler,都默认为None。前者的作用是生成一系列的index,而batch_sampler则是将sampler生成的indices打包分组,得到一个又一个batch的index。例如下面示例中,BatchSamplerSequentialSampler生成的index按照指定的batch size分组。

>>>in : list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
>>>out: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

Pytorch中已经实现的Sampler有如下几种:

  • SequentialSampler
  • RandomSampler
  • WeightedSampler
  • SubsetRandomSampler

需要注意的是DataLoader的部分初始化参数之间存在互斥关系,这个你可以通过阅读源码更深地理解,这里只做总结:

  • 如果你自定义了batch_sampler,那么这些参数都必须使用默认值:batch_size, shuffle,sampler,drop_last.
  • 如果你自定义了sampler,那么shuffle需要设置为False
  • 如果sampler和batch_sampler都为None,那么batch_sampler使用Pytorch已经实现好的BatchSampler,而sampler分两种情况:
    • 若shuffle=True,则sampler=RandomSampler(dataset)
    • 若shuffle=False,则sampler=SequentialSampler(dataset)

如何自定义Sampler和BatchSampler?

仔细查看源代码其实可以发现,所有采样器其实都继承自同一个父类,即Sampler,其代码定义如下:

class Sampler(object):
 r"""Base class for all Samplers.
 Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
 way to iterate over indices of dataset elements, and a :meth:`__len__` method
 that returns the length of the returned iterators.
 .. note:: The :meth:`__len__` method isn't strictly required by
    :class:`~torch.utils.data.DataLoader`, but is expected in any
    calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
 """

 def __init__(self, data_source):
  pass

 def __iter__(self):
  raise NotImplementedError

 def __len__(self):
  return len(self.data_source)

所以你要做的就是定义好__iter__(self)函数,不过要注意的是该函数的返回值需要是可迭代的。例如SequentialSampler返回的是iter(range(len(self.data_source)))

另外BatchSampler与其他Sampler的主要区别是它需要将Sampler作为参数进行打包,进而每次迭代返回以batch size为大小的index列表。也就是说在后面的读取数据过程中使用的都是batch sampler。

Dataset

Dataset定义方式如下:

class Dataset(object):
	def __init__(self):
		...

	def __getitem__(self, index):
		return ...

	def __len__(self):
		return ...

上面三个方法是最基本的,其中__getitem__是最主要的方法,它规定了如何读取数据。但是它又不同于一般的方法,因为它是python built-in方法,其主要作用是能让该类可以像list一样通过索引值对数据进行访问。假如你定义好了一个dataset,那么你可以直接通过dataset[0]来访问第一个数据。在此之前我一直没弄清楚__getitem__是什么作用,所以一直不知道该怎么进入到这个函数进行调试。现在如果你想对__getitem__方法进行调试,你可以写一个for循环遍历dataset来进行调试了,而不用构建dataloader等一大堆东西了,建议学会使用ipdb这个库,非常实用!!!以后有时间再写一篇ipdb的使用教程。另外,其实我们通过最前面的Dataloader的__next__函数可以看到DataLoader对数据的读取其实就是用了for循环来遍历数据,不用往上翻了,我直接复制了一遍,如下:

class DataLoader(object):
 ... 

 def __next__(self):
  if self.num_workers == 0:
   indices = next(self.sample_iter)
   batch = self.collate_fn([self.dataset[i] for i in indices]) # this line
   if self.pin_memory:
    batch = _utils.pin_memory.pin_memory_batch(batch)
   return batch

我们仔细看可以发现,前面还有一个self.collate_fn方法,这个是干嘛用的呢?在介绍前我们需要知道每个参数的意义:

  • indices: 表示每一个iteration,sampler返回的indices,即一个batch size大小的索引列表
  • self.dataset[i]: 前面已经介绍了,这里就是对第i个数据进行读取操作,一般来说self.dataset[i]=(img, label)

看到这不难猜出collate_fn的作用就是将一个batch的数据进行合并操作。默认的collate_fn是将img和label分别合并成imgs和labels,所以如果你的__getitem__方法只是返回 img, label,那么你可以使用默认的collate_fn方法,但是如果你每次读取的数据有img, box, label等等,那么你就需要自定义collate_fn来将对应的数据合并成一个batch数据,这样方便后续的训练步骤。

到此这篇关于一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系的文章就介绍到这了,更多相关Pytorch DataLoader DataSet Sampler内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • 一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系

    以下内容都是针对Pytorch 1.0-1.1介绍. 很多文章都是从Dataset等对象自下往上进行介绍,但是对于初学者而言,其实这并不好理解,因为有的时候会不自觉地陷入到一些细枝末节中去,而不能把握重点,所以本文将会自上而下地对Pytorch数据读取方法进行介绍. 自上而下理解三者关系 首先我们看一下DataLoader.next的源代码长什么样,为方便理解我只选取了num_works为0的情况(num_works简单理解就是能够并行化地读取数据). class DataLoader(obje

  • 一文弄懂JavaScript的继承方式

    目录 JavaScript中的继承方式 问:JavaScript中有几种继承方式呢 问:每种继承方式是怎么实现的呢 盗用构造函数 组合继承 原型链式继承 寄生式继承 寄生时组合继承 JavaScript中的继承方式 问:JavaScript中有几种继承方式呢 emmm...六种?五种?还是四种来着... 这次记清楚了 一共有五种继承方式 盗用构造函数 (经典继承方式 ) 组合继承 原型链式继承 寄生式继承 寄生式组合继承 问:每种继承方式是怎么实现的呢 盗用构造函数 基本思路很简单:在子类构造函

  • pytorch中dataloader 的sampler 参数详解

    目录 1. dataloader() 初始化函数 2. shuffle 与sample 之间的关系 3. sample 的定义方法 3.1 sampler 参数的使用 4. batch 生成过程 1. dataloader() 初始化函数 def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_mem

  • 让你一文弄懂Pandas文本数据处理

    目录 前言 1. 文本数据类型 1.1. 类型简介 1.2. 类型差异 2. 字符串方法 2.1. 文本格式 2.2. 文本对齐 2.3. 计数与编码 2.4. 格式判断 3. 文本高级操作 3.1. 文本拆分 3.2. 文本替换 3.3. 文本拼接 3.4. 文本匹配 3.5. 文本提取 总结 前言 日常工作中我们经常接触到一些文本类信息,需要从文本中解析出数据信息,然后再进行数据分析操作. 而对文本类信息进行解析是一件比较头秃的事情,好巧,Pandas刚好对这类文本数据有比较好的处理方法,那

  • 一文弄懂Nginx的location匹配的实现

    由于团队在进行前后端分离,前端接管了 Nginx 和 node 层,在日常的工作中,跟 Nginx 打交道的时候挺多的.其中 location 是使用最多和改动最多的地方.之前对 location 的匹配规则是一知半解的.为了搞明白 location 是如何匹配的,特意花了点时间查了些资料,总结此文.希望能给大家带来帮助. 语法规则 location [ = | ~ | ~* | ^~ ] uri { ... } location @name { ... } 语法规则很简单,一个location

  • 一文弄懂C语言如何实现单链表

    目录 一.单链表与顺序表的区别: 一.顺序表: 二.链表 二.关于链表中的一些函数接口的作用及实现 1.头文件里的结构体和函数声明等等 2.创建接口空间 3.尾插尾删 4.头插头删 5.单链表查找 6.中间插入(在pos后面进行插入) 7.中间删除(在pos后面进行删除) 8.单独打印链表和从头到尾打印链表 9.test.c 总结 一.单链表与顺序表的区别: 一.顺序表: 1.内存中地址连续 2.长度可以实时变化 3.不支持随机查找 4.适用于访问大量元素的,而少量需要增添/删除的元素的程序 5

  • 一文弄懂MySQL中redo log与binlog的区别

    目录 前言 1. 什么是redo log? 1.1 redo日志文件名 1.2 影响redo log参数 1.3 redo log大小怎么设置? 2. 什么是binlog 2.1 binlog文件名 2.2 影响binlog的参数 2.3 查看binlog 3. redo log与binlog的区别 总结 前言 MySQL中有六种日志文件,分别是:重做日志(redo log).回滚日志(undo log).二进制日志(binlog).错误日志(errorlog).慢查询日志(slow query

  • 一文弄懂MySQL索引创建原则

    目录 一.适合创建索引 1.字段的数值有唯一性限制 2.频繁作为Where查询条件的字段 3.经常Groupby和Orderby的列 4.Update.Delete的where条件列 5.Distinct字段需要创建索引 6.多表Join连接操作时,创建索引注意事项 7.使用列的类型小的创建索引 8.使用字符串前缀创建索引 9.区分度高的列适合作为索引 10.使用最频繁的列放到联合索引的左侧 11.在多个字段都要创建索引的情况下,联合索引由于单值索引 二.不适合创建索引 1.在where中使用不

  • 一文弄懂MYSQL如何列转行

    目录 一.需求: 二.如何实现 1)首先看我们的静态SQL 2)那么就有人问了,如果我有100门课程不是要写100次名称,这也太麻烦了? 3)这样每次都写一长串sql也很麻烦? 总结 一.需求: 有三张表,学生表.成绩表和课程表,我们可以通过连表查询出学生姓名.课程及对应的成绩: 所需表sql -- ---------------------------- -- Table structure for student -- ---------------------------- DROP TA

  • 一文带你弄懂Java中线程池的原理

    目录 为什么要用线程池 线程池的原理 ThreadPoolExecutor提供的构造方法 ThreadPoolExecutor的策略 线程池主要的任务处理流程 ThreadPoolExecutor如何做到线程复用的 四种常见的线程池 newCachedThreadPool newFixedThreadPool newSingleThreadExecutor newScheduledThreadPool 小结 在工作中,我们经常使用线程池,但是你真的了解线程池的原理吗?同时,线程池工作原理和底层实

随机推荐