基于keras中训练数据的几种方式对比(fit和fit_generator)

一、train_on_batch

model.train_on_batch(batchX, batchY)

train_on_batch函数接受单批数据,执行反向传播,然后更新模型参数,该批数据的大小可以是任意的,即,它不需要提供明确的批量大小,属于精细化控制训练模型,大部分情况下我们不需要这么精细,99%情况下使用fit_generator训练方式即可,下面会介绍。

二、fit

model.fit(x_train, y_train, batch_size=32, epochs=10)

fit的方式是一次把训练数据全部加载到内存中,然后每次批处理batch_size个数据来更新模型参数,epochs就不用多介绍了。这种训练方式只适合训练数据量比较小的情况下使用。

三、fit_generator

利用Python的生成器,逐个生成数据的batch并进行训练,不占用大量内存,同时生成器与模型将并行执行以提高效率。例如,该函数允许我们在CPU上进行实时的数据提升,同时在GPU上进行模型训练

接口如下:

fit_generator(self, generator, steps_per_epoch, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_q_size=10, workers=1, pickle_safe=False, initial_epoch=0)

generator:生成器函数

steps_per_epoch:整数,当生成器返回steps_per_epoch次数据时,计一个epoch结束,执行下一个epoch。也就是一个epoch下执行多少次batch_size。

epochs:整数,控制数据迭代的轮数,到了就结束训练。

callbacks=None, list,list中的元素为keras.callbacks.Callback对象,在训练过程中会调用list中的回调函数

举例:

def generate_arrays_from_file(path):
            while True:
                with open(path) as f:
                    for line in f:
                        # create numpy arrays of input data
                        # and labels, from each line in the file
                        x1, x2, y = process_line(line)
                        yield ({'input_1': x1, 'input_2': x2}, {'output': y})

model.fit_generator(generate_arrays_from_file('./my_folder'),
                            steps_per_epoch=10000, epochs=10)

补充:keras.fit_generator()属性及取值

如下所示:

fit_generator(self, generator,
                    steps_per_epoch=None,
                    epochs=1,
                    verbose=1,
                    callbacks=None,
                    validation_data=None,
                    validation_steps=None,
                    class_weight=None,
                    max_queue_size=10,
                    workers=1,
                    use_multiprocessing=False,
                    shuffle=True,
                    initial_epoch=0)

通过Python generator产生一批批的数据用于训练模型。generator可以和模型并行运行,例如,可以使用CPU生成批数据同时在GPU上训练模型。

参数:

generator:一个generator或Sequence实例,为了避免在使用multiprocessing时直接复制数据。

steps_per_epoch:从generator产生的步骤的总数(样本批次总数)。通常情况下,应该等于数据集的样本数量除以批量的大小。

epochs:整数,在数据集上迭代的总数。

works:在使用基于进程的线程时,最多需要启动的进程数量。

use_multiprocessing:布尔值。当为True时,使用基于基于过程的线程。

例如:

datagen = ImageDataGenator(...)
model.fit_generator(datagen.flow(x_train, y_train,
                                 batch_size=batch_size),
                    epochs=epochs,
                    validation_data=(x_test, y_test),
                    workers=4)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • 浅谈keras通过model.fit_generator训练模型(节省内存)

    前言 前段时间在训练模型的时候,发现当训练集的数量过大,并且输入的图片维度过大时,很容易就超内存了,举个简单例子,如果我们有20000个样本,输入图片的维度是224x224x3,用float32存储,那么如果我们一次性将全部数据载入内存的话,总共就需要20000x224x224x3x32bit/8=11.2GB 这么大的内存,所以如果一次性要加载全部数据集的话是需要很大内存的. 如果我们直接用keras的fit函数来训练模型的话,是需要传入全部训练数据,但是好在提供了fit_generator,

  • 在keras中model.fit_generator()和model.fit()的区别说明

    首先Keras中的fit()函数传入的x_train和y_train是被完整的加载进内存的,当然用起来很方便,但是如果我们数据量很大,那么是不可能将所有数据载入内存的,必将导致内存泄漏,这时候我们可以用fit_generator函数来进行训练. keras中文文档 fit fit(x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.0, validation_data=N

  • 浅谈keras2 predict和fit_generator的坑

    1.使用predict时,必须设置batch_size,否则效率奇低. 查看keras文档中,predict函数原型: predict(self, x, batch_size=32, verbose=0) 说明: 只使用batch_size=32,也就是说每次将batch_size=32的数据通过PCI总线传到GPU,然后进行预测.在一些问题中,batch_size=32明显是非常小的.而通过PCI传数据是非常耗时的. 所以,使用的时候会发现预测数据时效率奇低,其原因就是batch_size太小

  • keras 两种训练模型方式详解fit和fit_generator(节省内存)

    第一种,fit import keras from keras.models import Sequential from keras.layers import Dense import numpy as np from sklearn.preprocessing import LabelEncoder from sklearn.preprocessing import OneHotEncoder from sklearn.model_selection import train_test_s

  • 基于keras中训练数据的几种方式对比(fit和fit_generator)

    一.train_on_batch model.train_on_batch(batchX, batchY) train_on_batch函数接受单批数据,执行反向传播,然后更新模型参数,该批数据的大小可以是任意的,即,它不需要提供明确的批量大小,属于精细化控制训练模型,大部分情况下我们不需要这么精细,99%情况下使用fit_generator训练方式即可,下面会介绍. 二.fit model.fit(x_train, y_train, batch_size=32, epochs=10) fit的

  • 基于TensorFlow中自定义梯度的2种方式

    前言 在深度学习中,有时候我们需要对某些节点的梯度进行一些定制,特别是该节点操作不可导(比如阶梯除法如 ),如果实在需要对这个节点进行操作,而且希望其可以反向传播,那么就需要对其进行自定义反向传播时的梯度.在有些场景,如[2]中介绍到的梯度反转(gradient inverse)中,就必须在某层节点对反向传播的梯度进行反转,也就是需要更改正常的梯度传播过程,如下图的 所示. 在tensorflow中有若干可以实现定制梯度的方法,这里介绍两种. 1. 重写梯度法 重写梯度法指的是通过tensorf

  • Java中获取时间戳的三种方式对比实现

    Java中获取时间戳 三种方式对比 最近项目开发过程中发现了项目中获取时间戳的业务.而获取时间戳有以下三种方式,首先先声明推荐使用System类来获取时间戳,下面一起看一看三种方式. 1.System.currentTimeMillis() System类中的currentTimeMillis()方法是三种方式中效率最好的,运行时间最短.开发中如果设计到效率问题,推荐使用此种方式获取. System.currentTimeMillis() 2.new Date().getTime() 除了Sys

  • vue-cli中设置publicPath的几种方式对比

    目录 设置publicPath的几种方式对比 publicPath打包设置 vue.config.js publicPath "./" npm run build无效 设置publicPath的几种方式对比 publicPath打包设置 1. 不设置(默认为 publicPath: ‘/’) 或者设置 publicPath: '/' // vue.config.js module.exports = {   publicPath: '/', } html中被打包的css和js路径如下

  • MongoDB中优雅删除大量数据的三种方式

    目录 为什么要"瘦身"呢? MongoDB中删除数据的三种方式 三种方式的执行效率对比 1. remove 2. deleteMany 3. bulkWrite 通过 Write Concern 规避主从延迟 删除过程中遇到的Bug 总结 删除大量数据,无论是在哪种数据库中,都是一个普遍性的需求.除了正常的业务需求,我们需要通过这种方式来为数据库"瘦身". 为什么要"瘦身"呢? 1.表的数据量到达一定量级后,数据量越大,表的查询性能会越差. 毕竟

  • mysql清空表数据的两种方式和区别解析

    在MySQL中删除数据有两种方式: truncate(截短)属于粗暴型的清空 delete属于精细化的删除 删除操作 如果你需要清空表里的所有数据,下面两种均可: delete from tablename; truncate table tablename; 而如果你只是删除一部分数据,就只能使用delete: delete from tablename where case1 and case2; 区别 在精细化的删除部分数据时,只能使用delete. 而清空所有表数据时,两者均可,此时这两

  • 基于keras输出中间层结果的2种实现方式

    1.使用函数模型API,新建一个model,将输入和输出定义为原来的model的输入和想要的那一层的输出,然后重新进行predict. #coding=utf-8 import seaborn as sbn import pylab as plt import theano from keras.models import Sequential from keras.layers import Dense,Activation from keras.models import Model mod

  • 基于keras中的回调函数用法说明

    keras训练 fit( self, x, y, batch_size=32, nb_epoch=10, verbose=1, callbacks=[], validation_split=0.0, validation_data=None, shuffle=True, class_weight=None, sample_weight=None ) 1. x:输入数据.如果模型只有一个输入,那么x的类型是numpy array,如果模型有多个输入,那么x的类型应当为list,list的元素是对应

  • 基于$.ajax()方法从服务器获取json数据的几种方式总结

    一.什么是json json是一种取代xml的数据结构,和xml相比,它更小巧但描述能力却很强,网络传输数据使用流量更少,速度更快. json就是一串字符串,使用下面的符号标注. {键值对} : json对象 [{},{},{}] :json数组 "" :双引号内是属性或值 : :冒号前为键,后为值(这个值可以是基本数据类型的值,也可以是数组或对象),所以 {"age": 18} 可以理解为是一个包含age为18的json对象,而[{"age":

  • 基于Keras中Conv1D和Conv2D的区别说明

    如有错误,欢迎斧正. 我的答案是,在Conv2D输入通道为1的情况下,二者是没有区别或者说是可以相互转化的.首先,二者调用的最后的代码都是后端代码(以TensorFlow为例,在tensorflow_backend.py里面可以找到): x = tf.nn.convolution( input=x, filter=kernel, dilation_rate=(dilation_rate,), strides=(strides,), padding=padding, data_format=tf_

随机推荐