Keras 利用sklearn的ROC-AUC建立评价函数详解

我就废话不多说了,大家还是直接看代码吧!

# 利用sklearn自建评价函数
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from keras.callbacks import Callback

class RocAucEvaluation(Callback):
 def __init__(self, validation_data=(), interval=1):
 super(Callback, self).__init__()
 self.interval = interval
 self.x_val,self.y_val = validation_data
 def on_epoch_end(self, epoch, log={}):
 if epoch % self.interval == 0:
  y_pred = self.model.predict(self.x_val, verbose=0)
  score = roc_auc_score(self.y_val, y_pred)
  print('\n ROC_AUC - epoch:%d - score:%.6f \n' % (epoch+1, score))

x_train,y_train,x_label,y_label = train_test_split(train_feature, train_label, train_size=0.95, random_state=233)
RocAuc = RocAucEvaluation(validation_data=(y_train,y_label), interval=1)

hist = model.fit(x_train, x_label, batch_size=batch_size, epochs=epochs, validation_data=(y_train, y_label), callbacks=[RocAuc], verbose=2)

补充知识:keras用auc做metrics以及早停

我就废话不多说了,大家还是直接看代码吧!

import tensorflow as tf
from sklearn.metrics import roc_auc_score

def auroc(y_true, y_pred):
 return tf.py_func(roc_auc_score, (y_true, y_pred), tf.double)
# Build Model...
model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy', auroc])

完整例子:

def auc(y_true, y_pred):
 auc = tf.metrics.auc(y_true, y_pred)[1]
 K.get_session().run(tf.local_variables_initializer())
 return auc

def create_model_nn(in_dim,layer_size=200):
 model = Sequential()
 model.add(Dense(layer_size,input_dim=in_dim, kernel_initializer='normal'))
 model.add(BatchNormalization())
 model.add(Activation('relu'))
 model.add(Dropout(0.3))
 for i in range(2):
 model.add(Dense(layer_size))
 model.add(BatchNormalization())
 model.add(Activation('relu'))
 model.add(Dropout(0.3))
 model.add(Dense(1, activation='sigmoid'))
 adam = optimizers.Adam(lr=0.01)
 model.compile(optimizer=adam,loss='binary_crossentropy',metrics = [auc])
 return model
####cv train
folds = StratifiedKFold(n_splits=5, shuffle=False, random_state=15)
oof = np.zeros(len(df_train))
predictions = np.zeros(len(df_test))
for fold_, (trn_idx, val_idx) in enumerate(folds.split(df_train.values, target2.values)):
 print("fold n°{}".format(fold_))
 X_train = df_train.iloc[trn_idx][features]
 y_train = target2.iloc[trn_idx]
 X_valid = df_train.iloc[val_idx][features]
 y_valid = target2.iloc[val_idx]
 model_nn = create_model_nn(X_train.shape[1])
 callback = EarlyStopping(monitor="val_auc", patience=50, verbose=0, mode='max')
 history = model_nn.fit(X_train, y_train, validation_data = (X_valid ,y_valid),epochs=1000,batch_size=64,verbose=0,callbacks=[callback])
 print('\n Validation Max score : {}'.format(np.max(history.history['val_auc'])))
 predictions += model_nn.predict(df_test[features]).ravel()/folds.n_splits

以上这篇Keras 利用sklearn的ROC-AUC建立评价函数详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • 升级keras解决load_weights()中的未定义skip_mismatch关键字问题

    1.问题描述 在用yolov3训练自己的数据集时,尝试加载预训练的权重,在冻结前154层的基础上,利用自己的数据集finetune. 出现如下错误: load_weights(),got an unexpected keyword argument skip_mismatch 2.解决方法 因为keras旧版本没有这一定义,在新的版本中有这一关键字的定义,因此,更新keras版本至2.1.5即可解决. source activate env pip uninstall keras pip ins

  • 浅谈cv2.imread()和keras.preprocessing中的image.load_img()区别

    1.image.load_img() from keras.preprocessing import image img_keras = image.load_img('./original/dog/880.jpg') print(img_keras) img_keras = image.img_to_array(img_keras) print(img_keras[:,1,1]) 效果如下: <PIL.JpegImagePlugin.JpegImageFile image mode=RGB s

  • keras 读取多标签图像数据方式

    我所接触的多标签数据,主要包括两类: 1.一张图片属于多个标签,比如,data:一件蓝色的上衣图片.jpg,label:蓝色,上衣.其中label包括两类标签,label1第一类:上衣,裤子,外套.label2第二类,蓝色,黑色,红色.这样两个输出label1,label2都是是分类,我们可以直接把label1和label2整合为一个label,直接编码,比如[蓝色,上衣]编码为[011011].这样模型的输出也只需要一个输出.实现了多分类. 2.一张图片属于多个标签,但是几个标签不全是分类.比

  • Python实现Keras搭建神经网络训练分类模型教程

    我就废话不多说了,大家还是直接看代码吧~ 注释讲解版: # Classifier example import numpy as np # for reproducibility np.random.seed(1337) # from keras.datasets import mnist from keras.utils import np_utils from keras.models import Sequential from keras.layers import Dense, Act

  • Keras 利用sklearn的ROC-AUC建立评价函数详解

    我就废话不多说了,大家还是直接看代码吧! # 利用sklearn自建评价函数 from sklearn.model_selection import train_test_split from sklearn.metrics import roc_auc_score from keras.callbacks import Callback class RocAucEvaluation(Callback): def __init__(self, validation_data=(), interv

  • Java 利用dom方式读取、创建xml详解及实例代码

    Java 利用dom方式读取.创建xml详解 1.创建一个接口 XmlInterface.Java public interface XmlInterface { /** * 建立XML文档 * @param fileName 文件全路径名称 */ public void createXml(String fileName); /** * 解析XML文档 * @param fileName 文件全路径名称 */ public void parserXml(String fileName); }

  • 与Django结合利用模型对上传图片预测的实例详解

    1 预处理 (1)对上传的图片进行预处理成100*100大小 def prepicture(picname): img = Image.open('./media/pic/' + picname) new_img = img.resize((100, 100), Image.BILINEAR) new_img.save(os.path.join('./media/pic/', os.path.basename(picname))) (2)将图片转化成数组 def read_image2(file

  • 利用Python创建位置生成器的示例详解

    目录 介绍 开始 步骤 创建训练数据集 创建测试数据集 将合成图像转换回坐标 放在一起 结论 介绍 在这篇文章中,我们将探索如何在美国各地城市的地图数据和公共电动自行车订阅源上训练一个快速生成的对抗网络(GAN)模型. 然后,我们可以通过为包括东京在内的世界各地城市创建合成数据集来测试该模型的学习和概括能力. git clone https://github.com/gretelai/GAN-location-generator.git 在之前的一篇博客中,我们根据电子自行车订阅源中的精确位置数

  • Qt利用QDrag实现拖拽拼图功能详解

    目录 一.项目介绍 二.项目基本配置 三.UI界面设置 四.主程序实现 4.1 main.cpp 4.1 mainwindow.h头文件 4.2 mainwindow.cpp源文件 4.3 PiecesList类 4.4 PuzzleWidget类 五.效果演示 一.项目介绍 本文介绍利用QDrag类实现拖拽拼图功能.左边是打散的图,拖动到右边进行复现,此外程序还支持手动拖入原图片. 二.项目基本配置 新建一个Qt案例,项目名称为“puzzle”,基类选择“QMainWindow”,取消选中创建

  • java 与testng利用XML做数据源的数据驱动示例详解

    java 与testng利用XML做数据源的数据驱动示例详解 testng的功能很强大,利用@DataProvider可以做数据驱动,数据源文件可以是EXCEL,XML,YAML,甚至可以是TXT文本.在这以XML为例: 备注:@DataProvider的返回值类型只能是Object[][]与Iterator<Object>[] TestData.xml: <?xml version="1.0" encoding="UTF-8"?> <

  • Keras官方中文文档:性能评估Metrices详解

    能评估 使用方法 性能评估模块提供了一系列用于模型性能评估的函数,这些函数在模型编译时由metrics关键字设置 性能评估函数类似与目标函数, 只不过该性能的评估结果讲不会用于训练. 可以通过字符串来使用域定义的性能评估函数 model.compile(loss='mean_squared_error', optimizer='sgd', metrics=['mae', 'acc']) 也可以自定义一个Theano/TensorFlow函数并使用之 from keras import metri

  • C#利用ASP.NET Core开发学生管理系统详解

    目录 涉及知识点 创建项目 登录模块 1. 创建控制器--LoginController 2. 创建登录视图 3. 创建用户模型 4. 创建数据库操作DataContext 5. 创建数据库和表并构造数据 6. 添加数据库连接配置 7. 添加注入信息 8. 运行测试 随着技术的进步,跨平台开发已经成为了标配,在此大背景下,ASP.NET Core也应运而生.本文主要利用ASP.NET Core开发一个学生管理系统为例,简述ASP.NET Core开发的常见知识点,仅供学习分享使用,如有不足之处,

  • 利用OpenCV实现YOLO对象检测方法详解

    目录 前言 什么是YOLO物体检测器? 项目结构 检测图像 检测视频 前言 本文将教你如何使用YOLOV3对象检测器.OpenCV和Python实现对图像和视频流的检测.用到的文件有yolov3.weights.yolov3.cfg.coco.names,这三个文件的github链接如下: GitHub - pjreddie/darknet: Convolutional Neural Networks https://pjreddie.com/media/files/yolov3.weights

  • Java8利用Stream实现列表去重的方法详解

    目录 一. Stream 的distinct()方法 1.1 对于 String 列表的去重 1.2 对于实体类列表的去重 二. 根据 List<Object> 中 Object 某个属性去重 2.1 新建一个列表出来 2.2 通过 filter() 方法 一. Stream 的distinct()方法 distinct()是Java 8 中 Stream 提供的方法,返回的是由该流中不同元素组成的流.distinct()使用 hashCode() 和 eqauls() 方法来获取不同的元素.

随机推荐