Pytorch使用技巧之Dataloader中的collate_fn参数详析

以MNIST为例

from torchvision import datasets
mnist = datasets.MNIST(root='./data/', train=True, download=True)
print(mnist[0])

结果

(<PIL.Image.Image image mode=L size=28x28 at 0x196E3F1D898>, 5)

MINIST数据集的dataset是由一张图片和一个label组成的元组

dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=2, shuffle=True,collate_fn=lambda x:x)
for each in dataloader:
    print(each)
    break

结果

[(<PIL.Image.Image image mode=L size=28x28 at 0x2CB3B105630>, 0), (<PIL.Image.Image image mode=L size=28x28 at 0x2CB3B105668>, 2)]

collate_fn为lamda x:x时表示对传入进来的数据不做处理

下面自定义collate_fn看看什么效果

def collate(data):
    img = []
    label = []
    for each in data:
        img.append(each[0])
        label.append(each[1])
    return img,label
dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=2, shuffle=True,collate_fn=lambda x:collate(x))
for each in dataloader:
    print(each)
    break

结果

([<PIL.Image.Image image mode=L size=28x28 at 0x241433A36D8>, <PIL.Image.Image image mode=L size=28x28 at 0x241433A3710>], [9, 3])

说明:若不设置collate_fn参数则会使用默认处理函数

但必须保证传进来的数据都是tensor格式否则会报错

附:DataLoader完整的参数表如下:

class torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    sampler=None,
    batch_sampler=None,
    num_workers=0,
    collate_fn=<function default_collate>,
    pin_memory=False,
    drop_last=False,
    timeout=0,
    worker_init_fn=None)

DataLoader在数据集上提供单进程或多进程的迭代器

几个关键的参数意思:

- shuffle:设置为True的时候,每个世代都会打乱数据集

- collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能

- drop_last:告诉如何处理数据集长度除于batch_size余下的数据。True就抛弃,否则保留

总结

到此这篇关于Pytorch使用技巧之Dataloader中的collate_fn参数的文章就介绍到这了,更多相关Dataloader中的collate_fn参数内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • Pytorch技巧:DataLoader的collate_fn参数使用详解

    DataLoader完整的参数表如下: class torch.utils.data.DataLoader( dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None) D

  • Pytorch使用技巧之Dataloader中的collate_fn参数详析

    以MNIST为例 from torchvision import datasets mnist = datasets.MNIST(root='./data/', train=True, download=True) print(mnist[0]) 结果 (<PIL.Image.Image image mode=L size=28x28 at 0x196E3F1D898>, 5) MINIST数据集的dataset是由一张图片和一个label组成的元组 dataloader = torch.ut

  • Jquery中$.ajax()方法参数详解

    俗说好记性不如个烂笔头,下面是jquery中的ajax方法参数详解,这里整理了一些供大家参考. 1.url: 要求为String类型的参数,(默认为当前页地址)发送请求的地址. 2.type: 要求为String类型的参数,请求方式(post或get)默认为get.注意其他http请求方法,例如put和delete也可以使用,但仅部分浏览器支持. 3.timeout: 要求为Number类型的参数,设置请求超时时间(毫秒).此设置将覆盖$.ajaxSetup()方法的全局设置. 4.async:

  • Webpack中SplitChunksPlugin 配置参数详解

    代码分割本身和 webpack 没有什么关系,但是由于使用 webpack 可以非常轻松地实现代码分割,所以提到代码分割首先就会想到使用 webopack 实现. 在 webpack 中是使用 SplitChunksPlugin 来实现的,由于 SplitChunksPlugin 配置参数众多,接下来就来梳理一下这些配置参数. 官网上的默认配置参数如下: module.exports = { //... optimization: { splitChunks: { chunks: 'async'

  • OpenCV中findContours函数参数详解

    注: 这篇文章用的OpenCV版本是2.4.10, 3以上的OpenCV版本相关函数可能有改动 Opencv中通过使用findContours函数,简单几个的步骤就可以检测出物体的轮廓,很方便.这些准备继续探讨一下findContours方法中各参数的含义及用法,比如要求只检测最外层轮廓该怎么办?contours里边的数据结构是怎样的?hierarchy到底是什么鬼?Point()有什么用? 先从findContours函数原型看起: findContours( InputOutputArray

  • Python中的默认参数详解

    文章的主题 不要使用可变对象作为函数的默认参数例如 list,dict,因为def是一个可执行语句,只有def执行的时候才会计算默认默认参数的值,所以使用默认参数会造成函数执行的时候一直在使用同一个对象,引起bug. 基本原理 在 Python 源码中,我们使用def来定义函数或者方法.在其他语言中,类似的东西往往只是一一个语法声明关键字,但def却是一个可执行的指令.Python代码执行的时候先会使用 compile 将其编译成 PyCodeObject. PyCodeObject 本质上依然

  • Spring Boot 中starter的原理详析

    目录 1.springboot 的starter 的启动原理是什么 原理 来个例子 小结 2.springboot 是如何找到配置类的 3.springboot starter 的bean 是怎么加载到容器的 4.总结 前言: 今天介绍springboot ,也是写下springboot的插件机制,starter的原理,其实这个网上已经很多了,也是看了不少别人的文章,今天主要还是带着问题去记录下. 1.springboot 的starter 的启动原理是什么 原理 这个问题是很简单的,只要了解s

  • SQL注入技巧之显注与盲注中过滤逗号绕过详析

    前言 sql注入在很早很早以前是很常见的一个漏洞.后来随着安全水平的提高,sql注入已经很少能够看到了.但是就在今天,还有很多网站带着sql注入漏洞在运行.下面这篇文章主要介绍了关于SQL注入逗号绕过的相关内容,分享出来供大家参考学习,下面话不多说了,来一起看看详细的介绍吧 1.联合查询显注绕过逗号 在联合查询时使用 UNION SELECT 1,2,3,4,5,6,7..n 这样的格式爆显示位,语句中包含了多个逗号,如果有WAF拦截了逗号时,我们的联合查询不能用了. 绕过 在显示位上替换为常见

  • Java中支持可变参数详解

    意思就是:参数的个数可以根据需要写,你可以写1个.2个.3个....他们都被保存到一个参数的数组中. 但是这些参有一些约束:他们必须是同类型的,比如都是String字符串类型. 同时,可变参数的函数中的参数的写法也有约束:比如,可变参数的数组必须写在参数的最后,否则程序不知道你的参数到底有多少个. 例子:输出可变参数中的参数值 public class VariableArgument { public static void main(String[] args) { printArgumen

  • JQuery中$.ajax()方法参数详解及应用

    url: 要求为String类型的参数,(默认为当前页地址)发送请求的地址. type: 要求为String类型的参数,请求方式(post或get)默认为get.注意其他http请求方法,例如put和 delete也可以使用,但仅部分浏览器支持. timeout: 要求为Number类型的参数,设置请求超时时间(毫秒).此设置将覆盖$.ajaxSetup()方法的全局设 置. async:要求为Boolean类型的参数,默认设置为true,所有请求均为异步请求. 如果需要发送同步请求,请将此选项

  • Flash中 MC事件参数详解

    MC的事件说明如下:  Load:当MC 被装载进来的时候发生此事件  EnterFrame:当MC被卸载的时候发生此事件  MouseDown:当鼠标左键被按下时发生此事件  MouseUp:当鼠标左键被松开时发生此事件  MouseMove:当鼠标移动的时候发生此事件  Keydown:当键盘按钮被按下时发生此事件  Keyup:当键盘按钮被松开时发生此事件  Data:使用loadVariable或者loadMovie动作,当接受到数据时,发生此事件 MovieClip (对象)  MC对

随机推荐