keras使用Sequence类调用大规模数据集进行训练的实现

使用Keras如果要使用大规模数据集对网络进行训练,就没办法先加载进内存再从内存直接传到显存了,除了使用Sequence类以外,还可以使用迭代器去生成数据,但迭代器无法在fit_generation里开启多进程,会影响数据的读取和预处理效率,在本文中就不在叙述了,有需要的可以另外去百度。

下面是我所使用的代码

class SequenceData(Sequence):
  def __init__(self, path, batch_size=32):
    self.path = path
    self.batch_size = batch_size
    f = open(path)
    self.datas = f.readlines()
    self.L = len(self.datas)
    self.index = random.sample(range(self.L), self.L)
  #返回长度,通过len(<你的实例>)调用
  def __len__(self):
    return self.L - self.batch_size
  #即通过索引获取a[0],a[1]这种
  def __getitem__(self, idx):
    batch_indexs = self.index[idx:(idx+self.batch_size)]
    batch_datas = [self.datas[k] for k in batch_indexs]
    img1s,img2s,audios,labels = self.data_generation(batch_datas)
    return ({'face1_input_1': img1s, 'face2_input_2': img2s, 'input_3':audios},{'activation_7':labels})

  def data_generation(self, batch_datas):
    #预处理操作
    return img1s,img2s,audios,labels

然后在代码里通过fit_generation函数调用并训练

这里要注意,use_multiprocessing参数是是否开启多进程,由于python的多线程不是真的多线程,所以多进程还是会获得比较客观的加速,但不支持windows,windows下python无法使用多进程。

D = SequenceData('train.csv')
model_train.fit_generator(generator=D,steps_per_epoch=int(len(D)),
          epochs=2, workers=20, #callbacks=[checkpoint],
          use_multiprocessing=True, validation_data=SequenceData('vali.csv'),validation_steps=int(20000/32)) 

同样的,也可以在测试的时候使用

model.evaluate_generator(generator=SequenceData('face_test.csv'),steps=int(125100/32),workers=32)

补充知识:keras数据自动生成器,继承keras.utils.Sequence,结合fit_generator实现节约内存训练

我就废话不多说了,大家还是直接看代码吧~

#coding=utf-8
'''
Created on 2018-7-10
'''
import keras
import math
import os
import cv2
import numpy as np
from keras.models import Sequential
from keras.layers import Dense

class DataGenerator(keras.utils.Sequence):

  def __init__(self, datas, batch_size=1, shuffle=True):
    self.batch_size = batch_size
    self.datas = datas
    self.indexes = np.arange(len(self.datas))
    self.shuffle = shuffle

  def __len__(self):
    #计算每一个epoch的迭代次数
    return math.ceil(len(self.datas) / float(self.batch_size))

  def __getitem__(self, index):
    #生成每个batch数据,这里就根据自己对数据的读取方式进行发挥了
    # 生成batch_size个索引
    batch_indexs = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
    # 根据索引获取datas集合中的数据
    batch_datas = [self.datas[k] for k in batch_indexs]

    # 生成数据
    X, y = self.data_generation(batch_datas)

    return X, y

  def on_epoch_end(self):
    #在每一次epoch结束是否需要进行一次随机,重新随机一下index
    if self.shuffle == True:
      np.random.shuffle(self.indexes)

  def data_generation(self, batch_datas):
    images = []
    labels = []

    # 生成数据
    for i, data in enumerate(batch_datas):
      #x_train数据
      image = cv2.imread(data)
      image = list(image)
      images.append(image)
      #y_train数据
      right = data.rfind("\\",0)
      left = data.rfind("\\",0,right)+1
      class_name = data[left:right]
      if class_name=="dog":
        labels.append([0,1])
      else:
        labels.append([1,0])
    #如果为多输出模型,Y的格式要变一下,外层list格式包裹numpy格式是list[numpy_out1,numpy_out2,numpy_out3]
    return np.array(images), np.array(labels)

# 读取样本名称,然后根据样本名称去读取数据
class_num = 0
train_datas = []
for file in os.listdir("D:/xxx"):
  file_path = os.path.join("D:/xxx", file)
  if os.path.isdir(file_path):
    class_num = class_num + 1
    for sub_file in os.listdir(file_path):
      train_datas.append(os.path.join(file_path, sub_file))

# 数据生成器
training_generator = DataGenerator(train_datas)

#构建网络
model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=784))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy',
       optimizer='sgd',
       metrics=['accuracy'])
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit_generator(training_generator, epochs=50,max_queue_size=10,workers=1)

以上这篇keras使用Sequence类调用大规模数据集进行训练的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • 浅谈keras通过model.fit_generator训练模型(节省内存)

    前言 前段时间在训练模型的时候,发现当训练集的数量过大,并且输入的图片维度过大时,很容易就超内存了,举个简单例子,如果我们有20000个样本,输入图片的维度是224x224x3,用float32存储,那么如果我们一次性将全部数据载入内存的话,总共就需要20000x224x224x3x32bit/8=11.2GB 这么大的内存,所以如果一次性要加载全部数据集的话是需要很大内存的. 如果我们直接用keras的fit函数来训练模型的话,是需要传入全部训练数据,但是好在提供了fit_generator,

  • Keras之fit_generator与train_on_batch用法

    关于Keras中,当数据比较大时,不能全部载入内存,在训练的时候就需要利用train_on_batch或fit_generator进行训练了. 两者均是利用生成器,每次载入一个batch-size的数据进行训练. 那么fit_generator与train_on_batch该用哪一个呢? train_on_batch(self, x, y, class_weight=None, sample_weight=None) fit_generator(self, generator, samples_

  • 浅谈keras2 predict和fit_generator的坑

    1.使用predict时,必须设置batch_size,否则效率奇低. 查看keras文档中,predict函数原型: predict(self, x, batch_size=32, verbose=0) 说明: 只使用batch_size=32,也就是说每次将batch_size=32的数据通过PCI总线传到GPU,然后进行预测.在一些问题中,batch_size=32明显是非常小的.而通过PCI传数据是非常耗时的. 所以,使用的时候会发现预测数据时效率奇低,其原因就是batch_size太小

  • keras中模型训练class_weight,sample_weight区别说明

    keras 中fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.0, validation_data=None, shuffle=True, class_weight=None, sample_weight=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None) 官方文档

  • keras使用Sequence类调用大规模数据集进行训练的实现

    使用Keras如果要使用大规模数据集对网络进行训练,就没办法先加载进内存再从内存直接传到显存了,除了使用Sequence类以外,还可以使用迭代器去生成数据,但迭代器无法在fit_generation里开启多进程,会影响数据的读取和预处理效率,在本文中就不在叙述了,有需要的可以另外去百度. 下面是我所使用的代码 class SequenceData(Sequence): def __init__(self, path, batch_size=32): self.path = path self.b

  • php mailer类调用远程SMTP服务器发送邮件实现方法

    本文实例讲述了php mailer类调用远程SMTP服务器发送邮件实现方法.分享给大家供大家参考,具体如下: php mailer 是一款很好用的php电子邮件发送类模块,可以调用本地的smtp发送电子邮件,也可以调用远程的smtp发送电子邮件,但是使用时需要注意一些事项,否则就会造成发送失败,或者根本不能调用的情况,本文就我在使用这个类时,遇到的问题和解决办法进行展开,简要说明一下php mailer的用法,及注意事项. 首先下载phpmailer类库文件,在这里下载,只需一个资源分. 下载地

  • C#中派生类调用基类构造函数用法分析

    本文实例讲述了C#中派生类调用基类构造函数用法.分享给大家供大家参考.具体分析如下: 这里的默认构造函数是指在没有编写构造函数的情况下系统默认的无参构造函数 1.当基类中没有自己编写构造函数时,派生类默认的调用基类的默认构造函数 例如: public class MyBaseClass { } public class MyDerivedClass : MyBaseClass { public MyDerivedClass() { Console.WriteLine("我是子类无参构造函数&qu

  • Java如何基于ProcessBuilder类调用外部程序

    这篇文章主要介绍了Java如何基于ProcessBuilder类调用外部程序,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下 1. demo1 @Test public void testProcessBuilder() { ProcessBuilder processBuilder = new ProcessBuilder(); // processBuilder.command("ping","127.0.0.1"

  • javaweb如何使用华为云短信通知公共类调用

    javaweb华为云短信通知公共类调用 情景:公司业务需求,短信从阿里云切换到华为云,参照华为云短信调用的相关文档遇到不少坑,在此记录一下. 开发环境:JDK1.8 系统环境:SpringBoot 1.华为云短信配置信息在application.yml中配置 sms: huawei: url: https://rtcsms.cn-north-1.myhuaweicloud.com:10743/sms/batchSendSms/v1 appKey: ****** appSecret: ******

  • SpringBoot实现其他普通类调用Spring管理的Service,dao等bean

    目录 普通类调用Spring管理的Service.dao等bean 举个使用情景 下面来看我给出的解决办法 普通类中使用service.dao层中的类,只需三步 1.写一个工具类 SpringUtil 2.在Application启动类中将工具类导入 3.在ApplicationTests测试类中调用 普通类调用Spring管理的Service.dao等bean 在springboot的使用中,有时需要在其他的普通类中调用托管给spring的dao或者service,从而去操作数据库.网上大多数

  • SpringBoot 普通类调用Bean对象的一种方式推荐

    目录 SpringBoot 普通类调用Bean对象 SpringBoot 中bean的使用 SpringBoot 普通类调用Bean对象 有时我们有一些特殊的需要,可能要在一个不被Spring管理的普通类中去调用Spring管理的bean对象的一些方法,比如一般SpringMVC工程在controller中通过 @Autowired private TestService testService; 注入TestService 接口就可以调用此接口实现类的实现的方法. 但在一般类中显然不可以这么做

  • Z-Order加速Hudi大规模数据集方案分析

    目录 1. 背景 2. Z-Order介绍 3. 具体实现 3.1 z-value的生成和排序 3.1.1 基于映射策略的z值生成方法 3.1.2 基于RangeBounds的z-value生成策略 3.2 与Hudi结合 3.2.1 表数据的Z排序重组 3.2.2 收集保存统计信息 3.2.3 应用到Spark查询 4. 测试结果 1. 背景 多维分析是大数据分析的一个典型场景,这种分析一般带有过滤条件.对于此类查询,尤其是在高基字段的过滤查询,理论上只我们对原始数据做合理的布局,结合相关过滤

  • 聊聊基于pytorch实现Resnet对本地数据集的训练问题

    目录 1.dataset.py(先看代码的总体流程再看介绍) 2.network.py 3.train.py 4.结果与总结 本文是使用pycharm下的pytorch框架编写一个训练本地数据集的Resnet深度学习模型,其一共有两百行代码左右,分成mian.py.network.py.dataset.py以及train.py文件,功能是对本地的数据集进行分类.本文介绍逻辑是总分形式,即首先对总流程进行一个概括,然后分别介绍每个流程中的实现过程(代码+流程图+文字的介绍). 对于整个项目的流程首

  • darknet框架中YOLOv3对数据集进行训练和预测详解

    目录 1. 下载darknet源码 2. 修改darknet的Makefile文件 3. 准备数据集 4. 修改voc_label.py 5. 下载预训练模型 6. 修改./darknet/cfg/voc.data文件 7. 修改./darknet/data/voc.name文件 8. 修改./darknet/cfg/yolov3-voc.cfg文件 9. 开始训练 10.训练终止后继续训练方法 1. 下载darknet源码 在命令窗口(terminal)中进入你想存放darknet源码的路径,

随机推荐