Pytorch DataLoader 变长数据处理方式

关于Pytorch中怎么自定义Dataset数据集类、怎样使用DataLoader迭代加载数据,这篇官方文档已经说得很清楚了,这里就不在赘述。

现在的问题:有的时候,特别对于NLP任务来说,输入的数据可能不是定长的,比如多个句子的长度一般不会一致,这时候使用DataLoader加载数据时,不定长的句子会被胡乱切分,这肯定是不行的。

解决方法是重写DataLoader的collate_fn,具体方法如下:

# 假如每一个样本为:
sample = {
	# 一个句子中各个词的id
	'token_list' : [5, 2, 4, 1, 9, 8],
	# 结果y
	'label' : 5,
}

# 重写collate_fn函数,其输入为一个batch的sample数据
def collate_fn(batch):
	# 因为token_list是一个变长的数据,所以需要用一个list来装这个batch的token_list
  token_lists = [item['token_list'] for item in batch]

  # 每个label是一个int,我们把这个batch中的label也全取出来,重新组装
  labels = [item['label'] for item in batch]
  # 把labels转换成Tensor
  labels = torch.Tensor(labels)
  return {
    'token_list': token_lists,
    'label': labels,
  }

# 在使用DataLoader加载数据时,注意collate_fn参数传入的是重写的函数
DataLoader(trainset, batch_size=4, shuffle=True, num_workers=4, collate_fn=collate_fn)

使用以上方法,可以保证DataLoader能Load出一个batch的数据,load出来的东西就是重写的collate_fn函数最后return出来的字典。

以上这篇Pytorch DataLoader 变长数据处理方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • Pytorch 数据加载与数据预处理方式

    数据加载分为加载torchvision.datasets中的数据集以及加载自己使用的数据集两种情况. torchvision.datasets中的数据集 torchvision.datasets中自带MNIST,Imagenet-12,CIFAR等数据集,所有的数据集都是torch.utils.data.Dataset的子类,都包含 _ _ len _ (获取数据集长度)和 _ getItem _ _ (获取数据集中每一项)两个子方法. Dataset源码如上,可以看到其中包含了两个没有实现的子

  • pytorch中的自定义数据处理详解

    pytorch在数据中采用Dataset的数据保存方式,需要继承data.Dataset类,如果需要自己处理数据的话,需要实现两个基本方法. :.getitem:返回一条数据或者一个样本,obj[index] = obj.getitem(index). :.len:返回样本的数量 . len(obj) = obj.len(). Dataset 在data里,调用的时候使用 from torch.utils import data import os from PIL import Image 数

  • pytorch中如何使用DataLoader对数据集进行批处理的方法

    最近搞了搞minist手写数据集的神经网络搭建,一个数据集里面很多个数据,不能一次喂入,所以需要分成一小块一小块喂入搭建好的网络. pytorch中有很方便的dataloader函数来方便我们进行批处理,做了简单的例子,过程很简单,就像把大象装进冰箱里一共需要几步? 第一步:打开冰箱门. 我们要创建torch能够识别的数据集类型(pytorch中也有很多现成的数据集类型,以后再说). 首先我们建立两个向量X和Y,一个作为输入的数据,一个作为正确的结果: 随后我们需要把X和Y组成一个完整的数据集,

  • Pytorch DataLoader 变长数据处理方式

    关于Pytorch中怎么自定义Dataset数据集类.怎样使用DataLoader迭代加载数据,这篇官方文档已经说得很清楚了,这里就不在赘述. 现在的问题:有的时候,特别对于NLP任务来说,输入的数据可能不是定长的,比如多个句子的长度一般不会一致,这时候使用DataLoader加载数据时,不定长的句子会被胡乱切分,这肯定是不行的. 解决方法是重写DataLoader的collate_fn,具体方法如下: # 假如每一个样本为: sample = { # 一个句子中各个词的id 'token_li

  • 解决pytorch rnn 变长输入序列的问题

    pytorch实现变长输入的rnn分类 输入数据是长度不固定的序列数据,主要讲解两个部分 1.Data.DataLoader的collate_fn用法,以及按batch进行padding数据 2.pack_padded_sequence和pad_packed_sequence来处理变长序列 collate_fn Dataloader的collate_fn参数,定义数据处理和合并成batch的方式. 由于pack_padded_sequence用到的tensor必须按照长度从大到小排过序的,所以在

  • pytorch dataloader 取batch_size时候出现bug的解决方式

    1. RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 342 and 281 in dimension 3 at /pytorch/aten/src/TH/generic/THTensorMoreMath.cpp:1333 2. RuntimeError: invalid argument 0: Sizes of tensors must match except i

  • Pytorch DataLoader shuffle验证方式

    shuffle = False时,不打乱数据顺序 shuffle = True,随机打乱 import numpy as np import h5py import torch from torch.utils.data import DataLoader, Dataset h5f = h5py.File('train.h5', 'w'); data1 = np.array([[1,2,3], [2,5,6], [3,5,6], [4,5,6]]) data2 = np.array([[1,1,

  • Python函数中*args和**kwargs来传递变长参数的用法

    单星号形式(*args)用来传递非命名键可变参数列表.双星号形式(**kwargs)用来传递键值可变参数列表. 下面的例子,传递了一个固定位置参数和两个变长参数. def test_var_args(farg, *args): print "formal arg:", farg for arg in args: print "another arg:", arg test_var_args(1, "two", 3) 结果如下: formal ar

  • C++中的变长参数深入理解

    前言 在吸进的一个项目中为了使用共享内存和自定义内存池,我们自己定义了MemNew函数,且在函数内部对于非pod类型自动执行构造函数.在需要的地方调用自定义的MemNew函数.这样就带来一个问题,使用stl的类都有默认构造函数,以及复制构造函数等.但使用共享内存和内存池的类可能没有默认构造函数,而是定义了多个参数的构造函数,于是如何将参数传入MemNew函数便成了问题. 一.变长参数函数 首先回顾一下较多使用的变长参数函数,最经典的便是printf. extern int printf(cons

  • 浅谈C++内存分配及变长数组的动态分配

    第一部分 C++内存分配 一.关于内存 1.内存分配方式 内存分配方式有三种: (1)从静态存储区域分配.内存在程序编译的时候就已经分配好,这块内存在程序的整个运行期间都存在 例如全局变量,static变量. (2)在栈上创建.在执行函数时,函数内局部变量的存储单元都可以在栈上创建,函数执行结束时这些存 储单元自动被释放.栈内存分配运算内置于处理器的指令集中,效率很高,但是分配的内存容量有限. (3) 从堆上分配,亦称动态内存分配.程序在运行的时候用malloc或new申请任意多少的内存,程序员

  • Pytorch实现神经网络的分类方式

    本文用于利用Pytorch实现神经网络的分类!!! 1.训练神经网络分类模型 import torch from torch.autograd import Variable import matplotlib.pyplot as plt import torch.nn.functional as F import torch.utils.data as Data torch.manual_seed(1)#设置随机种子,使得每次生成的随机数是确定的 BATCH_SIZE = 5#设置batch

  • 解决pytorch DataLoader num_workers出现的问题

    最近在学pytorch,在使用数据分批训练时在导入数据是使用了 DataLoader 在参数 num_workers的设置上使程序出现运行没有任何响应的结果 ,看看代码 import torch #导入模块 import torch.utils.data as Data BATCH_SIZE=8 #每一批的数据量 x=torch.linspace(1,10,10) #定义X为 1 到 10 等距离大小的数 y=torch.linspace(10,1,10) #转换成torch能识别的Datase

随机推荐