sklearn和keras的数据切分与交叉验证的实例详解

在训练深度学习模型的时候,通常将数据集切分为训练集和验证集.Keras提供了两种评估模型性能的方法:

使用自动切分的验证集

使用手动切分的验证集

一.自动切分

在Keras中,可以从数据集中切分出一部分作为验证集,并且在每次迭代(epoch)时在验证集中评估模型的性能.

具体地,调用model.fit()训练模型时,可通过validation_split参数来指定从数据集中切分出验证集的比例.

# MLP with automatic validation set
from keras.models import Sequential
from keras.layers import Dense
import numpy
# fix random seed for reproducibility
numpy.random.seed(7)
# load pima indians dataset
dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]
# create model
model = Sequential()
model.add(Dense(12, input_dim=8, activation='relu'))
model.add(Dense(8, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
# Compile model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# Fit the model
model.fit(X, Y, validation_split=0.33, epochs=150, batch_size=10)

validation_split:0~1之间的浮点数,用来指定训练集的一定比例数据作为验证集。验证集将不参与训练,并在每个epoch结束后测试的模型的指标,如损失函数、精确度等。

注意,validation_split的划分在shuffle之前,因此如果你的数据本身是有序的,需要先手工打乱再指定validation_split,否则可能会出现验证集样本不均匀。

二.手动切分

Keras允许在训练模型的时候手动指定验证集.

例如,用sklearn库中的train_test_split()函数将数据集进行切分,然后在keras的model.fit()的时候通过validation_data参数指定前面切分出来的验证集.

# MLP with manual validation set
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
import numpy
# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)
# load pima indians dataset
dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]
# split into 67% for train and 33% for test
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.33, random_state=seed)
# create model
model = Sequential()
model.add(Dense(12, input_dim=8, activation='relu'))
model.add(Dense(8, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
# Compile model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# Fit the model
model.fit(X_train, y_train, validation_data=(X_test,y_test), epochs=150, batch_size=10)

三.K折交叉验证(k-fold cross validation)

将数据集分成k份,每一轮用其中(k-1)份做训练而剩余1份做验证,以这种方式执行k轮,得到k个模型.将k次的性能取平均,作为该算法的整体性能.k一般取值为5或者10.

优点:能比较鲁棒性地评估模型在未知数据上的性能.

缺点:计算复杂度较大.因此,在数据集较大,模型复杂度较高,或者计算资源不是很充沛的情况下,可能不适用,尤其是在训练深度学习模型的时候.

sklearn.model_selection提供了KFold以及RepeatedKFold, LeaveOneOut, LeavePOut, ShuffleSplit, StratifiedKFold, GroupKFold, TimeSeriesSplit等变体.

下面的例子中用的StratifiedKFold采用的是分层抽样,它保证各类别的样本在切割后每一份小数据集中的比例都与原数据集中的比例相同.

# MLP for Pima Indians Dataset with 10-fold cross validation
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import StratifiedKFold
import numpy
# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)
# load pima indians dataset
dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]
# define 10-fold cross validation test harness
kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed)
cvscores = []
for train, test in kfold.split(X, Y):
 # create model
  model = Sequential()
  model.add(Dense(12, input_dim=8, activation='relu'))
  model.add(Dense(8, activation='relu'))
  model.add(Dense(1, activation='sigmoid'))
  # Compile model
  model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
  # Fit the model
  model.fit(X[train], Y[train], epochs=150, batch_size=10, verbose=0)
  # evaluate the model
  scores = model.evaluate(X[test], Y[test], verbose=0)
  print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))
  cvscores.append(scores[1] * 100)
print("%.2f%% (+/- %.2f%%)" % (numpy.mean(cvscores), numpy.std(cvscores)))

补充知识:训练集,验证集和测试集

训练集:通过最小化目标函数(损失函数 + 正则项),用来训练模型的参数。当目标函数最小化时,完成对模型的训练。

验证集:用来选择模型的阶数。目标函数最小的模型对应的阶数,为模型的最终选择的阶数。

注:

1. 验证集会在训练过程中,反复使用,机器学习中作为选择不同模型的评判标准,深度学习中作为选择网络层数和每层节点数的评判标准。

2. 验证集的使用并非必不可少,如果网络的层数和节点数已经确定,则不需要这一步操作。

测试集:评估模型的泛化能力。根据选择的已经训练好的模型,评估它的泛化能力。

注:

测试集评判的是最终训练好的模型的泛化能力,只进行一次评判。

以上这篇sklearn和keras的数据切分与交叉验证的实例详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • keras Lambda自定义层实现数据的切片方式,Lambda传参数

    1.代码如下: import numpy as np from keras.models import Sequential from keras.layers import Dense, Activation,Reshape from keras.layers import merge from keras.utils.visualize_util import plot from keras.layers import Input, Lambda from keras.models impo

  • 深入浅析Python 中的sklearn模型选择

    1.主要功能如下: 1.classification分类 2.Regression回归 3.Clustering聚类 4.Dimensionality reduction降维 5.Model selection模型选择 6.Preprocessing预处理 2.主要模块分类: 1.sklearn.base: Base classes and utility function基础实用函数 2.sklearn.cluster: Clustering聚类 3.sklearn.cluster.biclu

  • Python sklearn KFold 生成交叉验证数据集的方法

    源起: 1.我要做交叉验证,需要每个训练集和测试集都保持相同的样本分布比例,直接用sklearn提供的KFold并不能满足这个需求. 2.将生成的交叉验证数据集保存成CSV文件,而不是直接用sklearn训练分类模型. 3.在编码过程中有一的误区需要注意: 这个sklearn官方给出的文档 >>> import numpy as np >>> from sklearn.model_selection import KFold >>> X = [&quo

  • sklearn和keras的数据切分与交叉验证的实例详解

    在训练深度学习模型的时候,通常将数据集切分为训练集和验证集.Keras提供了两种评估模型性能的方法: 使用自动切分的验证集 使用手动切分的验证集 一.自动切分 在Keras中,可以从数据集中切分出一部分作为验证集,并且在每次迭代(epoch)时在验证集中评估模型的性能. 具体地,调用model.fit()训练模型时,可通过validation_split参数来指定从数据集中切分出验证集的比例. # MLP with automatic validation set from keras.mode

  • Spring boot jpa 删除数据和事务管理的问题实例详解

    今天我们介绍的是jpa删除和事务的一些坑,接下来看看具体内容. 业务场景(这是一个在线考试系统)和代码:根据问题的id删除答案 repository层: int deleteByQuestionId(Integer questionId); service 层: public void deleteChoiceAnswerByQuestionId(Integer questionId) { choiceAnswerRepository.deleteByQuestionId(questionId)

  • java 数据结构中栈和队列的实例详解

    java 数据结构中栈和队列的实例详解 栈和队列是两种重要的线性数据结构,都是在一个特定的范围的存储单元中的存储数据.与线性表相比,它们的插入和删除操作收到更多的约束和限定,又被称为限定性的线性表结构.栈是先进后出FILO,队列是先进先出FIFO,但是有的数据结构按照一定的条件排队数据的队列,这时候的队列属于特殊队列,不一定按照上面的原则. 实现栈:采用数组和链表两种方法来实现栈 链表方法: package com.cl.content01; /* * 使用链表来实现栈 */ public cl

  • mysql查找删除重复数据并只保留一条实例详解

    有这样一张表,表数据及结果如下: school_id school_name total_student test_takers 1239 Abraham Lincoln High School 55 50 1240 Abraham Lincoln High School 70 35 1241 Acalanes High School 120 89 1242 Academy Of The Canyons 30 30 1243 Agoura High School 89 40 1244 Agour

  • js中自定义react数据验证组件实例详解

    我们在做前端表单提交时,经常会遇到要对表单中的数据进行校验的问题.如果用户提交的数据不合法,例如格式不正确.非数字类型.超过最大长度.是否必填项.最大值和最小值等等,我们需要在相应的地方给出提示信息.如果用户修正了数据,我们还要将提示信息隐藏起来. 有一些现成的插件可以让你非常方便地实现这一功能,如果你使用的是knockout框架,那么你可以借助于Knockout-Validation这一插件.使用起来很简单,例如我下面的这一段代码: ko.validation.locale('zh-CN');

  • Yii2框架数据验证操作实例详解

    本文实例讲述了Yii2框架数据验证操作.分享给大家供大家参考,具体如下: 一.场景 什么情况下需要使用场景呢?当一个模型需要在不同情境中使用时,若不同情境下需要的数据表字段和数据验证规则有所不同,则需要定义多个场景来区分不同使用情境.例如,用户注册的时候需要填写email,登录的时候则不需要,这时就需要定义两个不同场景加以区分. 默认情况下模型的场景是由rules()方法申明的验证规则中使用到的场景决定的,也可以通过覆盖scenarios()方法来更具体地定义模型的所有场景,例如: public

  • SQL数据去重的3种方法实例详解

    目录 1.使用distinct去重 2.使用group by 3.使用ROW_NUMBER() OVER 或 GROUP BY 和 COLLECT_SET/COLLECT_LIST 3.1 ROW_NUMBER() OVER 3.2 GROUP BY 和 COLLECT_SET/COLLECT_LIST distinct与group by的去重方面的区别 使用去重distinct方法的示例详解 总结 1.使用distinct去重 distinct用来查询不重复记录的条数,用count(disti

  • PHP数据的提交与过滤基本操作实例详解

    本文实例讲述了PHP数据的提交与过滤基本操作.分享给大家供大家参考,具体如下: 1.php提交数据过滤的基本原则 1)提交变量进数据库时,我们必须使用addslashes()进行过滤,像我们的注入问题,一个addslashes()也就搞定了.其实在涉及到变量取值时,intval()函数对字符串的过滤也是个不错的选择. 2)在php.ini中开启magic_quotes_gpc和magic_quotes_runtime.magic_quotes_gpc可以把get,post,cookie里的引号变

  • Oracle数据泵的导入与导出实例详解

    前言 今天王子要分享的内容是关于Oracle的一个实战内容,Oracle的数据泵. 网上有很多关于此的内容,但很多都是复制粘贴别人的,导致很多小伙伴想要使用的时候不能直接上手,所以这篇文章一定能让你更清晰的理解数据泵. 开始之前王子先介绍一下自己的环境,这里使用的是比较常用的WIN10系统,Oracle数据库也是安装在本机上的,环境比较简单. 数据泵的导入 导入的数据文件可能是别人导出给你的,也可能是你自己导出的,王子这里就是别人导出的,文件名字是YD.DMP. 在进行操作之前,一定要问清楚表空

  • get post jsonp三种数据交互形式实例详解

    一.get请求 1.引入 vue.js 和 vue-resource.js , 准备一个按钮 <input type="button" value="按钮" @click="get()"/> //点击按钮请求数据函数get() 2.准备一个txt文件 welcome vue 3.编写js代码 <script> window.onload=function(){ new Vue({ el:'body', //主体为body,

随机推荐