Pytorch DataLoader 变长数据处理方式
关于Pytorch中怎么自定义Dataset数据集类、怎样使用DataLoader迭代加载数据,这篇官方文档已经说得很清楚了,这里就不在赘述。
现在的问题:有的时候,特别对于NLP任务来说,输入的数据可能不是定长的,比如多个句子的长度一般不会一致,这时候使用DataLoader加载数据时,不定长的句子会被胡乱切分,这肯定是不行的。
解决方法是重写DataLoader的collate_fn,具体方法如下:
# 假如每一个样本为: sample = { # 一个句子中各个词的id 'token_list' : [5, 2, 4, 1, 9, 8], # 结果y 'label' : 5, } # 重写collate_fn函数,其输入为一个batch的sample数据 def collate_fn(batch): # 因为token_list是一个变长的数据,所以需要用一个list来装这个batch的token_list token_lists = [item['token_list'] for item in batch] # 每个label是一个int,我们把这个batch中的label也全取出来,重新组装 labels = [item['label'] for item in batch] # 把labels转换成Tensor labels = torch.Tensor(labels) return { 'token_list': token_lists, 'label': labels, } # 在使用DataLoader加载数据时,注意collate_fn参数传入的是重写的函数 DataLoader(trainset, batch_size=4, shuffle=True, num_workers=4, collate_fn=collate_fn)
使用以上方法,可以保证DataLoader能Load出一个batch的数据,load出来的东西就是重写的collate_fn函数最后return出来的字典。
以上这篇Pytorch DataLoader 变长数据处理方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。
相关推荐
-
pytorch中如何使用DataLoader对数据集进行批处理的方法
最近搞了搞minist手写数据集的神经网络搭建,一个数据集里面很多个数据,不能一次喂入,所以需要分成一小块一小块喂入搭建好的网络. pytorch中有很方便的dataloader函数来方便我们进行批处理,做了简单的例子,过程很简单,就像把大象装进冰箱里一共需要几步? 第一步:打开冰箱门. 我们要创建torch能够识别的数据集类型(pytorch中也有很多现成的数据集类型,以后再说). 首先我们建立两个向量X和Y,一个作为输入的数据,一个作为正确的结果: 随后我们需要把X和Y组成一个完整的数据集,
-
Pytorch 数据加载与数据预处理方式
数据加载分为加载torchvision.datasets中的数据集以及加载自己使用的数据集两种情况. torchvision.datasets中的数据集 torchvision.datasets中自带MNIST,Imagenet-12,CIFAR等数据集,所有的数据集都是torch.utils.data.Dataset的子类,都包含 _ _ len _ (获取数据集长度)和 _ getItem _ _ (获取数据集中每一项)两个子方法. Dataset源码如上,可以看到其中包含了两个没有实现的子
-
pytorch中的自定义数据处理详解
pytorch在数据中采用Dataset的数据保存方式,需要继承data.Dataset类,如果需要自己处理数据的话,需要实现两个基本方法. :.getitem:返回一条数据或者一个样本,obj[index] = obj.getitem(index). :.len:返回样本的数量 . len(obj) = obj.len(). Dataset 在data里,调用的时候使用 from torch.utils import data import os from PIL import Image 数
-
Pytorch DataLoader 变长数据处理方式
关于Pytorch中怎么自定义Dataset数据集类.怎样使用DataLoader迭代加载数据,这篇官方文档已经说得很清楚了,这里就不在赘述. 现在的问题:有的时候,特别对于NLP任务来说,输入的数据可能不是定长的,比如多个句子的长度一般不会一致,这时候使用DataLoader加载数据时,不定长的句子会被胡乱切分,这肯定是不行的. 解决方法是重写DataLoader的collate_fn,具体方法如下: # 假如每一个样本为: sample = { # 一个句子中各个词的id 'token_li
-
解决pytorch rnn 变长输入序列的问题
pytorch实现变长输入的rnn分类 输入数据是长度不固定的序列数据,主要讲解两个部分 1.Data.DataLoader的collate_fn用法,以及按batch进行padding数据 2.pack_padded_sequence和pad_packed_sequence来处理变长序列 collate_fn Dataloader的collate_fn参数,定义数据处理和合并成batch的方式. 由于pack_padded_sequence用到的tensor必须按照长度从大到小排过序的,所以在
-
pytorch dataloader 取batch_size时候出现bug的解决方式
1. RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 342 and 281 in dimension 3 at /pytorch/aten/src/TH/generic/THTensorMoreMath.cpp:1333 2. RuntimeError: invalid argument 0: Sizes of tensors must match except i
-
Pytorch DataLoader shuffle验证方式
shuffle = False时,不打乱数据顺序 shuffle = True,随机打乱 import numpy as np import h5py import torch from torch.utils.data import DataLoader, Dataset h5f = h5py.File('train.h5', 'w'); data1 = np.array([[1,2,3], [2,5,6], [3,5,6], [4,5,6]]) data2 = np.array([[1,1,
-
Python函数中*args和**kwargs来传递变长参数的用法
单星号形式(*args)用来传递非命名键可变参数列表.双星号形式(**kwargs)用来传递键值可变参数列表. 下面的例子,传递了一个固定位置参数和两个变长参数. def test_var_args(farg, *args): print "formal arg:", farg for arg in args: print "another arg:", arg test_var_args(1, "two", 3) 结果如下: formal ar
-
C++中的变长参数深入理解
前言 在吸进的一个项目中为了使用共享内存和自定义内存池,我们自己定义了MemNew函数,且在函数内部对于非pod类型自动执行构造函数.在需要的地方调用自定义的MemNew函数.这样就带来一个问题,使用stl的类都有默认构造函数,以及复制构造函数等.但使用共享内存和内存池的类可能没有默认构造函数,而是定义了多个参数的构造函数,于是如何将参数传入MemNew函数便成了问题. 一.变长参数函数 首先回顾一下较多使用的变长参数函数,最经典的便是printf. extern int printf(cons
-
浅谈C++内存分配及变长数组的动态分配
第一部分 C++内存分配 一.关于内存 1.内存分配方式 内存分配方式有三种: (1)从静态存储区域分配.内存在程序编译的时候就已经分配好,这块内存在程序的整个运行期间都存在 例如全局变量,static变量. (2)在栈上创建.在执行函数时,函数内局部变量的存储单元都可以在栈上创建,函数执行结束时这些存 储单元自动被释放.栈内存分配运算内置于处理器的指令集中,效率很高,但是分配的内存容量有限. (3) 从堆上分配,亦称动态内存分配.程序在运行的时候用malloc或new申请任意多少的内存,程序员
-
Pytorch实现神经网络的分类方式
本文用于利用Pytorch实现神经网络的分类!!! 1.训练神经网络分类模型 import torch from torch.autograd import Variable import matplotlib.pyplot as plt import torch.nn.functional as F import torch.utils.data as Data torch.manual_seed(1)#设置随机种子,使得每次生成的随机数是确定的 BATCH_SIZE = 5#设置batch
-
解决pytorch DataLoader num_workers出现的问题
最近在学pytorch,在使用数据分批训练时在导入数据是使用了 DataLoader 在参数 num_workers的设置上使程序出现运行没有任何响应的结果 ,看看代码 import torch #导入模块 import torch.utils.data as Data BATCH_SIZE=8 #每一批的数据量 x=torch.linspace(1,10,10) #定义X为 1 到 10 等距离大小的数 y=torch.linspace(10,1,10) #转换成torch能识别的Datase
随机推荐
- node.js中的fs.ftruncate方法使用说明
- 15分钟提醒一次,珍惜时间啊
- .NET程序集引用COM组件MSScriptControl遇到问题的解决方法
- 用bat和 reg实现关闭局域网共享
- Java中使用JavaMail多发邮件及邮件的验证和附件实现
- win2003服务器.NET+IIS环境常见问题排障总结
- JS实现鼠标经过好友列表中的好友头像时显示资料卡的效果
- 将json对象转换为字符串的方法
- Zend Framework入门教程之Zend_Db数据库操作详解
- 使用mysqldump实现mysql备份
- Git 命令详解及常用命令整理
- Java中关于int和Integer的区别详解
- jQuery fancybox在ie浏览器下无法显示关闭按钮的解决办法
- JavaScript 原型学习总结
- Java配置JDK开发环境及环境变量
- c#消息提示框messagebox的详解及使用
- 发展海外业务 海外邮件重点出击
- Spring Data MongoDB中实现自定义级联的方法详解
- nodejs简单读写excel内容的方法示例
- PHP后台备份MySQL数据库的源码实例