浅谈keras2 predict和fit_generator的坑

1、使用predict时,必须设置batch_size,否则效率奇低。

查看keras文档中,predict函数原型:

predict(self, x, batch_size=32, verbose=0)

说明:

只使用batch_size=32,也就是说每次将batch_size=32的数据通过PCI总线传到GPU,然后进行预测。在一些问题中,batch_size=32明显是非常小的。而通过PCI传数据是非常耗时的。

所以,使用的时候会发现预测数据时效率奇低,其原因就是batch_size太小了。

经验:

使用predict时,必须人为设置好batch_size,否则PCI总线之间的数据传输次数过多,性能会非常低下。

2、fit_generator

说明:keras 中 fit_generator参数steps_per_epoch已经改变含义了,目前的含义是一个epoch分成多少个batch_size。旧版的含义是一个epoch的样本数目。

如果说训练样本树N=1000,steps_per_epoch = 10,那么相当于一个batch_size=100,如果还是按照旧版来设置,那么相当于

batch_size = 1,会性能非常低。

经验:

必须明确fit_generator参数steps_per_epoch

补充知识:Keras:创建自己的generator(适用于model.fit_generator),解决内存问题

为什么要使用model.fit_generator?

在现实的机器学习中,训练一个model往往需要数量巨大的数据,如果使用fit进行数据训练,很有可能导致内存不够,无法进行训练。

fit_generator的定义如下:

fit_generator(generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)

其中各项的具体解释,请参考Keras中文文档

我们重点关注的是generator参数:

generator: 一个生成器,或者一个 Sequence (keras.utils.Sequence) 对象的实例, 以在使用多进程时避免数据的重复。 生成器的输出应该为以下之一:

一个 (inputs, targets) 元组

一个 (inputs, targets, sample_weights) 元组。

那么,问题来了,如何构建这个generator呢?有以下几种办法:

自己创建一个generator生成器

自己定义一个 Sequence (keras.utils.Sequence) 对象

使用Keras自带的ImageDataGenerator和.flow/.flow_from_dataframe/.flow_from_directory来生成一个generator

1.自己创建一个generator生成器

使用Keras自带的ImageDataGenerator和.flow/.flow_from_dataframe/.flow_from_directory 灵活度不高,只有当数据集满足一定格式(例如,按照分类文件夹存放)或者具备一定条件时,使用才使用才较为方便。

此时,自己创建一个generator就很重要了,关于python的generator是什么原理,怎么使用,就不加赘述,可以查看python的基本语法。

此处,我们用yield来返回数据组,标签组,从而使fit_generator可以调用我们的generator来成批处理数据。

具体实现如下:

  def myGenerator(batch_size):
    # loading data
    X_train,Y_train=load_data(...)

    # data processing
    # ................

    total_size=X_train.size
    #batch_size means how many data you want to train one step

    while 1:
      for i in range(total_size//batch_size):
        yield x_train[i*batch_size:(i+1)*batch_size], y[i*batch_size:(i+1)*batch_size]
  return myGenerator

接着你可以调用该生成器:

self._model.fit_generator(myGenerator(batch_size),steps_per_epoch=total_size//batch_size, epochs=epoch_num)

以上这篇浅谈keras2 predict和fit_generator的坑就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • keras 如何保存最佳的训练模型

    1.只保存最佳的训练模型 2.保存有所有有提升的模型 3.加载模型 4.参数说明 只保存最佳的训练模型 from keras.callbacks import ModelCheckpoint filepath='weights.best.hdf5' # 有一次提升, 则覆盖一次. checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1,save_best_only=True,mode='max',period=2)

  • keras-siamese用自己的数据集实现详解

    Siamese网络不做过多介绍,思想并不难,输入两个图像,输出这两张图像的相似度,两个输入的网络结构是相同的,参数共享. 主要发现很多代码都是基于mnist数据集的,下面说一下怎么用自己的数据集实现siamese网络. 首先,先整理数据集,相同的类放到同一个文件夹下,如下图所示: 接下来,将pairs及对应的label写到csv中,代码如下: import os import random import csv #图片所在的路径 path = '/Users/mac/Desktop/wxd/fl

  • Kears 使用:通过回调函数保存最佳准确率下的模型操作

    1:首先,我给我的MixTest文件夹里面分好了类的图片进行重命名(因为分类的时候没有注意导致命名有点不好) def load_data(path): Rename the picture [a tool] for eachone in os.listdir(path): newname = eachone[7:] os.rename(path+"\\"+eachone,path+"\\"+newname) 但是需要注意的是:我们按照类重命名了以后,系统其实会按照图

  • keras多显卡训练方式

    使用keras进行训练,默认使用单显卡,即使设置了os.environ['CUDA_VISIBLE_DEVICES']为两张显卡,也只是占满了显存,再设置tf.GPUOptions(allow_growth=True)之后可以清楚看到,只占用了第一张显卡,第二张显卡完全没用. 要使用多张显卡,需要按如下步骤: (1)import multi_gpu_model函数:from keras.utils import multi_gpu_model (2)在定义好model之后,使用multi_gpu

  • 浅谈keras2 predict和fit_generator的坑

    1.使用predict时,必须设置batch_size,否则效率奇低. 查看keras文档中,predict函数原型: predict(self, x, batch_size=32, verbose=0) 说明: 只使用batch_size=32,也就是说每次将batch_size=32的数据通过PCI总线传到GPU,然后进行预测.在一些问题中,batch_size=32明显是非常小的.而通过PCI传数据是非常耗时的. 所以,使用的时候会发现预测数据时效率奇低,其原因就是batch_size太小

  • 浅谈angularjs module返回对象的坑(推荐)

    通过将module中不同的部件拆分到不同的js文件中,在组装的时候发现module存在一个奇怪的问题,不知道是不是AngularJS的bug.至今没有找到解释. 问题是这样的,按照理解,angular.module('app.main', []);这样一句话相当于从app.main命名空间返回出一个app对象.那么,不论在任何js文件中,我通过该方法获得的app变量所储存的指针都应该指向唯一的一个堆内存,而这个内存中存储的就是这个app对象.这种操作在module的js文件,和controlle

  • 浅谈React Native打包apk的坑

    RN的打包,大家可以根据官网一步一步来,但这里有几个地方注意,一下简单介绍: 生成一个签名密钥 在项目的目录下打开cmd命令窗口输入一下命令运行: keytool -genkey -v -keystore my-release-key.keystore -alias my-key-alias -keyalg RSA -keysize 2048 -validity 10000 这条命令会要求你输入密钥库(keystore)和对应密钥的密码,然后设置一些发行相关的信息.最后它会生成一个叫做my-re

  • 浅谈keras通过model.fit_generator训练模型(节省内存)

    前言 前段时间在训练模型的时候,发现当训练集的数量过大,并且输入的图片维度过大时,很容易就超内存了,举个简单例子,如果我们有20000个样本,输入图片的维度是224x224x3,用float32存储,那么如果我们一次性将全部数据载入内存的话,总共就需要20000x224x224x3x32bit/8=11.2GB 这么大的内存,所以如果一次性要加载全部数据集的话是需要很大内存的. 如果我们直接用keras的fit函数来训练模型的话,是需要传入全部训练数据,但是好在提供了fit_generator,

  • 浅谈MyBatis原生批量插入的坑与解决方案

    目录 原生批量插入的"坑" 解决方案 分片 Demo 实战 原生批量插入分片实现 总结 前面的文章咱们讲了 MyBatis 批量插入的 3 种方法:循环单次插入.MyBatis Plus 批量插入.MyBatis 原生批量插入,详情请点击<MyBatis 批量插入数据的 3 种方法!> 但之前的文章也有不完美之处,原因在于:使用 「循环单次插入」的性能太低,使用「MyBatis Plus 批量插入」性能还行,但要额外的引入 MyBatis Plus 框架,使用「MyBati

  • 浅谈golang的json.Unmarshal的坑

    最近在golang业务开发时,遇到一个坑. 我们有个服务,会接收通用的interface对象,然后去给用户发消息.因此会涉及到把各个业务方传递过来的字符串,转成interface对象. 但是因为我的字符串里有一个数字,比如下面demo里的{"number":1234567},而且数字是7位数,在经过json.Unmarshal后,被转成了科学计数法的形式,导致私信发出的链接出现异常,结果报错了. package main import ( "encoding/json&quo

  • 浅谈MultipartFile中transferTo方法的坑

    前言:最近用SpringBoot写文件上传功能,使用参数绑定之后确实是非常的方便了. 但是,项目部署就出现了问题,搞得我一脸懵逼. 后来,才发现是因为我使用了相对路径导致的,这个绝对是一个坑人的地方,不过也说明需要学习的东西还有很多! 案例再现 @PostMapping("/uploadFile") public String uploadImg(@RequestParam("file") MultipartFile file, @RequestParam(&quo

  • 浅谈mysql8.0新特性的坑和解决办法(小结)

    一.创建用户和授权 在mysql8.0创建用户和授权和之前不太一样了,其实严格上来讲,也不能说是不一样,只能说是更严格,mysql8.0需要先创建用户和设置密码,然后才能授权. #先创建一个用户 create user 'hong'@'%' identified by '123123'; #再进行授权 grant all privileges on *.* to 'hong'@'%' with grant option; 如果还是用原来5.7的那种方式,会报错误: grant all privi

  • 浅谈sklearn中predict与predict_proba区别

    predict_proba 返回的是一个 n 行 k 列的数组,列是标签(有排序), 第 i 行 第 j 列上的数值是模型预测 第 i 个预测样本为某个标签的概率,并且每一行的概率和为1. predict 直接返回的是预测 的标签. 具体见下面示例: # conding :utf-8 from sklearn.linear_model import LogisticRegression import numpy as np x_train = np.array([[1,2,3], [1,3,4]

  • 浅谈angular懒加载的一些坑

    写在前面 最近在工作中接触到angular模块化打包加载的一些内容,感觉中间踩了一些坑,在此标记一下. 项目背景: 项目主要用到angularJs作为前端框架,项目之前发布的时候会把所有的前端脚本打包压缩到一个文件中,在页面初次访问的时候加载,造成页面初始载入缓慢,在此基础上,提出按需加载,即只有当用户访问某个模块的时候,该模块的脚本才会加载. 工具类: 项目使用grunt打包根据AMD规范,使用grunt-contrib-requirejs来压缩合并模块,同时用ocLazyLoad来完成ang

随机推荐