keras 两种训练模型方式详解fit和fit_generator(节省内存)

第一种,fit

import keras
from keras.models import Sequential
from keras.layers import Dense
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split

#读取数据
x_train = np.load("D:\\machineTest\\testmulPE_win7\\data_sprase.npy")[()]
y_train = np.load("D:\\machineTest\\testmulPE_win7\\lable_sprase.npy")

# 获取分类类别总数
classes = len(np.unique(y_train))

#对label进行one-hot编码,必须的
label_encoder = LabelEncoder()
integer_encoded = label_encoder.fit_transform(y_train)
onehot_encoder = OneHotEncoder(sparse=False)
integer_encoded = integer_encoded.reshape(len(integer_encoded), 1)
y_train = onehot_encoder.fit_transform(integer_encoded)

#shuffle
X_train, X_test, y_train, y_test = train_test_split(x_train, y_train, test_size=0.3, random_state=0)

model = Sequential()
model.add(Dense(units=1000, activation='relu', input_dim=784))
model.add(Dense(units=classes, activation='softmax'))
model.compile(loss='categorical_crossentropy',
    optimizer='sgd',
    metrics=['accuracy'])
model.fit(X_train, y_train, epochs=50, batch_size=128)
score = model.evaluate(X_test, y_test, batch_size=128)
# #fit参数详情
# keras.models.fit(
# self,
# x=None, #训练数据
# y=None, #训练数据label标签
# batch_size=None, #每经过多少个sample更新一次权重,defult 32
# epochs=1, #训练的轮数epochs
# verbose=1, #0为不在标准输出流输出日志信息,1为输出进度条记录,2为每个epoch输出一行记录
# callbacks=None,#list,list中的元素为keras.callbacks.Callback对象,在训练过程中会调用list中的回调函数
# validation_split=0., #浮点数0-1,将训练集中的一部分比例作为验证集,然后下面的验证集validation_data将不会起到作用
# validation_data=None, #验证集
# shuffle=True, #布尔值和字符串,如果为布尔值,表示是否在每一次epoch训练前随机打乱输入样本的顺序,如果为"batch",为处理HDF5数据
# class_weight=None, #dict,分类问题的时候,有的类别可能需要额外关注,分错的时候给的惩罚会比较大,所以权重会调高,体现在损失函数上面
# sample_weight=None, #array,和输入样本对等长度,对输入的每个特征+个权值,如果是时序的数据,则采用(samples,sequence_length)的矩阵
# initial_epoch=0, #如果之前做了训练,则可以从指定的epoch开始训练
# steps_per_epoch=None, #将一个epoch分为多少个steps,也就是划分一个batch_size多大,比如steps_per_epoch=10,则就是将训练集分为10份,不能和batch_size共同使用
# validation_steps=None, #当steps_per_epoch被启用的时候才有用,验证集的batch_size
# **kwargs #用于和后端交互
# )
#
# 返回的是一个History对象,可以通过History.history来查看训练过程,loss值等等

第二种,fit_generator(节省内存)

# 第二种,可以节省内存
'''
Created on 2018-4-11
fit_generate.txt,后面两列为lable,已经one-hot编码
1 2 0 1
2 3 1 0
1 3 0 1
1 4 0 1
2 4 1 0
2 5 1 0

'''
import keras
from keras.models import Sequential
from keras.layers import Dense
import numpy as np
from sklearn.model_selection import train_test_split

count =1
def generate_arrays_from_file(path):
 global count
 while 1:
  datas = np.loadtxt(path,delimiter=' ',dtype="int")
  x = datas[:,:2]
  y = datas[:,2:]
  print("count:"+str(count))
  count = count+1
  yield (x,y)
x_valid = np.array([[1,2],[2,3]])
y_valid = np.array([[0,1],[1,0]])
model = Sequential()
model.add(Dense(units=1000, activation='relu', input_dim=2))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy',
    optimizer='sgd',
    metrics=['accuracy'])

model.fit_generator(generate_arrays_from_file("D:\\fit_generate.txt"),steps_per_epoch=10, epochs=2,max_queue_size=1,validation_data=(x_valid, y_valid),workers=1)
# steps_per_epoch 每执行一次steps,就去执行一次生产函数generate_arrays_from_file
# max_queue_size 从生产函数中出来的数据时可以缓存在queue队列中
# 输出如下:
# Epoch 1/2
# count:1
# count:2
#
# 1/10 [==>...........................] - ETA: 2s - loss: 0.7145 - acc: 0.3333count:3
# count:4
# count:5
# count:6
# count:7
#
# 7/10 [====================>.........] - ETA: 0s - loss: 0.7001 - acc: 0.4286count:8
# count:9
# count:10
# count:11
#
# 10/10 [==============================] - 0s 36ms/step - loss: 0.6960 - acc: 0.4500 - val_loss: 0.6794 - val_acc: 0.5000
# Epoch 2/2
#
# 1/10 [==>...........................] - ETA: 0s - loss: 0.6829 - acc: 0.5000count:12
# count:13
# count:14
# count:15
#
# 5/10 [==============>...............] - ETA: 0s - loss: 0.6800 - acc: 0.5000count:16
# count:17
# count:18
# count:19
# count:20
#
# 10/10 [==============================] - 0s 11ms/step - loss: 0.6766 - acc: 0.5000 - val_loss: 0.6662 - val_acc: 0.5000

补充知识:

自动生成数据还可以继承keras.utils.Sequence,然后写自己的生成数据类:

keras数据自动生成器,继承keras.utils.Sequence,结合fit_generator实现节约内存训练

#coding=utf-8
'''
Created on 2018-7-10
'''
import keras
import math
import os
import cv2
import numpy as np
from keras.models import Sequential
from keras.layers import Dense

class DataGenerator(keras.utils.Sequence):

 def __init__(self, datas, batch_size=1, shuffle=True):
  self.batch_size = batch_size
  self.datas = datas
  self.indexes = np.arange(len(self.datas))
  self.shuffle = shuffle

 def __len__(self):
  #计算每一个epoch的迭代次数
  return math.ceil(len(self.datas) / float(self.batch_size))

 def __getitem__(self, index):
  #生成每个batch数据,这里就根据自己对数据的读取方式进行发挥了
  # 生成batch_size个索引
  batch_indexs = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
  # 根据索引获取datas集合中的数据
  batch_datas = [self.datas[k] for k in batch_indexs]

  # 生成数据
  X, y = self.data_generation(batch_datas)

  return X, y

 def on_epoch_end(self):
  #在每一次epoch结束是否需要进行一次随机,重新随机一下index
  if self.shuffle == True:
   np.random.shuffle(self.indexes)

 def data_generation(self, batch_datas):
  images = []
  labels = []

  # 生成数据
  for i, data in enumerate(batch_datas):
   #x_train数据
   image = cv2.imread(data)
   image = list(image)
   images.append(image)
   #y_train数据
   right = data.rfind("\\",0)
   left = data.rfind("\\",0,right)+1
   class_name = data[left:right]
   if class_name=="dog":
    labels.append([0,1])
   else:
    labels.append([1,0])
  #如果为多输出模型,Y的格式要变一下,外层list格式包裹numpy格式是list[numpy_out1,numpy_out2,numpy_out3]
  return np.array(images), np.array(labels)

# 读取样本名称,然后根据样本名称去读取数据
class_num = 0
train_datas = []
for file in os.listdir("D:/xxx"):
 file_path = os.path.join("D:/xxx", file)
 if os.path.isdir(file_path):
  class_num = class_num + 1
  for sub_file in os.listdir(file_path):
   train_datas.append(os.path.join(file_path, sub_file))

# 数据生成器
training_generator = DataGenerator(train_datas)

#构建网络
model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=784))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy',
    optimizer='sgd',
    metrics=['accuracy'])
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])

model.fit_generator(training_generator, epochs=50,max_queue_size=10,workers=1)

以上这篇keras 两种训练模型方式详解fit和fit_generator(节省内存)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • Keras之fit_generator与train_on_batch用法

    关于Keras中,当数据比较大时,不能全部载入内存,在训练的时候就需要利用train_on_batch或fit_generator进行训练了. 两者均是利用生成器,每次载入一个batch-size的数据进行训练. 那么fit_generator与train_on_batch该用哪一个呢? train_on_batch(self, x, y, class_weight=None, sample_weight=None) fit_generator(self, generator, samples_

  • 在keras中model.fit_generator()和model.fit()的区别说明

    首先Keras中的fit()函数传入的x_train和y_train是被完整的加载进内存的,当然用起来很方便,但是如果我们数据量很大,那么是不可能将所有数据载入内存的,必将导致内存泄漏,这时候我们可以用fit_generator函数来进行训练. keras中文文档 fit fit(x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.0, validation_data=N

  • 浅谈keras2 predict和fit_generator的坑

    1.使用predict时,必须设置batch_size,否则效率奇低. 查看keras文档中,predict函数原型: predict(self, x, batch_size=32, verbose=0) 说明: 只使用batch_size=32,也就是说每次将batch_size=32的数据通过PCI总线传到GPU,然后进行预测.在一些问题中,batch_size=32明显是非常小的.而通过PCI传数据是非常耗时的. 所以,使用的时候会发现预测数据时效率奇低,其原因就是batch_size太小

  • 基于Keras 循环训练模型跑数据时内存泄漏的解决方式

    在使用完模型之后,添加这两行代码即可清空之前model占用的内存: import tensorflow as tf from keras import backend as K K.clear_session() tf.reset_default_graph() 补充知识:keras 多个模型测试阶段速度越来越慢问题的解决方法 问题描述 在实际应用或比赛中,经常会用到交叉验证(10倍或5倍)来提高泛化能力,这样在预测时需要加载多个模型.常用的方法为 mods = [] from keras.ut

  • 浅谈keras通过model.fit_generator训练模型(节省内存)

    前言 前段时间在训练模型的时候,发现当训练集的数量过大,并且输入的图片维度过大时,很容易就超内存了,举个简单例子,如果我们有20000个样本,输入图片的维度是224x224x3,用float32存储,那么如果我们一次性将全部数据载入内存的话,总共就需要20000x224x224x3x32bit/8=11.2GB 这么大的内存,所以如果一次性要加载全部数据集的话是需要很大内存的. 如果我们直接用keras的fit函数来训练模型的话,是需要传入全部训练数据,但是好在提供了fit_generator,

  • keras 两种训练模型方式详解fit和fit_generator(节省内存)

    第一种,fit import keras from keras.models import Sequential from keras.layers import Dense import numpy as np from sklearn.preprocessing import LabelEncoder from sklearn.preprocessing import OneHotEncoder from sklearn.model_selection import train_test_s

  • 基于String变量的两种创建方式(详解)

    在java中,有两种创建String类型变量的方式: String str01="abc";//第一种方式 String str02=new String("abc")://第二种方式 第一种方式创建String变量时,首先查找JVM方法区的字符串常量池是否存在存放"abc"的地址,如果存在,则将该变量指向这个地址,不存在,则在方法区创建一个存放字面值"abc"的地址. 第二种方式创建String变量时,在堆中创建一个存放&q

  • Thinkphp事件机制两种实现方式详解

    目录 一.通过监听 二.通过订阅 1.创建订阅类 2.配置监听 3.触发监听 4.处理监听逻辑 4.1 自动绑定 4.2 手动绑定 总结 事件机制的实现有两种途径:通过监听.通过订阅 一.通过监听 1.创建监听类:在命令行模式下进入框架根目录执行 php think make:listener <自定义的类名> 例如: php think make:listener UserListener 执行之后将在<框架根目录>\app\listener\下生成UserListener这个类

  • Java动态代理的两种实现方式详解【附相关jar文件下载】

    本文实例讲述了Java动态代理的两种实现方式.分享给大家供大家参考,具体如下: 一说到动态代理,我们第一个想到肯定是大名鼎鼎的Spring AOP了.在AOP的源码中用到了两种动态代理来实现拦截切入功能:jdk动态代理和cglib动态代理.两种方法同时存在,各有优劣.jdk动态代理是由java内部的反射机制来实现的,cglib动态代理是通过继承来实现的,底层则是借助asm(Java 字节码操控框架)来实现的(采用字节码的方式,给A类创建一个子类B,子类B使用方法拦截的技术拦截所以父类的方法调用)

  • C语言中字符串的两种定义方式详解

    目录 方式1 方式2 总结 我们知道C语言中是没有字符串这种数据类型的,我们只能依靠数组进行存储,即字符数组,而我们定义并且初始化数组有两种方式.下面将给大家介绍这两种方式并且介绍这两种方式的区别: 方式1 前两种是正确的定义方式,第一种之所以没有指定字符数组长度的原因是编译器能够自己推断出其长度,无需程序员自己设定,这也是我们比较推荐的一种定义方式,但注意内存长度编译器一经判定就无法再次更改,接下来我们分析一下第三种编译器为什么会出现乱码. 相信大家都知道,字符串是以'\0'字符为结束标志的,

  • jQuery中JSONP的两种实现方式详解

    前台代码如下: 后台Action代码如下: 运行后就可以看到结果了.我追踪了下后台ProcessCallback代码,如下图: 可以看到jsonCallback的值为"jQuery17104721....",它是前端传给远程服务器后台Action的.这里 jQuery171..表示的是jQuery的版本,可以简单地将这个理解为JSONP类型请求回调函数,jQuery在我们每次指定Ajax请求方式为 JSONP时都会生成这么一个JSONP回调函数.虽然jQuery会自动帮我们生成一个回调

  • 关于react-router的几种配置方式详解

    本文介绍关于react-router的几种配置方式详解,分享给大家,具体如下: 路由的概念 路由的作用就是将url和函数进行映射,在单页面应用中路由是必不可少的部分,路由配置就是一组指令,用来告诉router如何匹配url,以及对应的函数映射,即执行对应的代码. react-router 每一门JS框架都会有自己定制的router框架,react-router就是react开发应用御用的路由框架,目前它的最新的官方版本为4.1.2.本文给大家介绍的是react-router相比于其他router

  • Python selenium 三种等待方式详解(必会)

    很多人在群里问,这个下拉框定位不到.那个弹出框定位不到-各种定位不到,其实大多数情况下就是两种问题:1 有frame,2 没有加等待.殊不知,你的代码运行速度是什么量级的,而浏览器加载渲染速度又是什么量级的,就好比闪电侠和凹凸曼约好去打怪兽,然后闪电侠打完回来之后问凹凸曼你为啥还在穿鞋没出门?凹凸曼分分中内心一万只羊驼飞过,欺负哥速度慢,哥不跟你玩了,抛个异常撂挑子了. 那么怎么才能照顾到凹凸曼缓慢的加载速度呢?只有一个办法,那就是等喽.说到等,又有三种等法,且听博主一一道来: 1. 强制等待

  • Android开发之基本控件和四种布局方式详解

    Android中的控件的使用方式和iOS中控件的使用方式基本相同,都是事件驱动.给控件添加事件也有接口回调和委托代理的方式.今天这篇博客就总结一下Android中常用的基本控件以及布局方式.说到布局方式Android和iOS还是区别挺大的,在iOS中有Frame绝对布局和AutoLayout相对布局.而在Android中的布局方式就比较丰富了,今天博客中会介绍四种常用的布局方式.先总结一下控件,然后再搞一搞基本方式,开发环境还是用的Mac下的Android Studio.开始今天的正题, 虽然A

  • Spring Data Jpa的四种查询方式详解

    这篇文章主要介绍了Spring Data Jpa的四种查询方式详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下 一.调用接口的方式 1.基本介绍 通过调用接口里的方法查询,需要我们自定义的接口继承Spring Data Jpa规定的接口 public interface UserDao extends JpaRepository<User, Integer>, JpaSpecificationExecutor<User> 使用这

随机推荐