Python利用Seaborn绘制多标签的混淆矩阵

Seaborn - 绘制多标签的混淆矩阵、召回、精准、F1

导入seaborn\matplotlib\scipy\sklearn等包:

import seaborn as sns
from matplotlib import pyplot as plt
from scipy.special import softmax
from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score, f1_score

sns.set_theme(color_codes=True)

从dataframe中,获取y_true(真实标签)和y_pred(预测标签):

y_true = df["target"]
y_pred = df['prediction']

计算验证数据整体的准确率acc、精准率precision、召回率recall、F1,使用加权模式average=‘weighted’:

# 准确率acc,精准precision,召回recall,F1
acc = accuracy_score(df["target"], df['prediction'])
precision = precision_score(y_true, y_pred, average='weighted')
recall = recall_score(y_true, y_pred, average='weighted')
f1 = f1_score(y_true, y_pred, average='weighted')
print(f'[Info] acc: {acc}, precision: {precision}, recall: {recall}, f1: {f1}')

计算混淆矩阵:

# 横坐标是真实类别数,纵坐标是预测类别数
cf_matrix = confusion_matrix(y_true, y_pred)

5类矩阵的绘制方案,混淆矩阵、百分比的混淆矩阵、召回矩阵、精准矩阵、F1矩阵:

  • 混淆矩阵是计数,百分比的混淆矩阵是占比
  • 召回矩阵是,每行的和是1,每行代表真实类别数,占比就是召回
  • 精准矩阵是,每列的和是1,每列代表预测列表数,占比就是精准
  • F1矩阵是按照 2PR/(P+R),注意为0的情况,需要补0,使用np.divide(a, b, out=np.zeros_like(a), where=(b != 0))

代码如下:

# 横坐标是真实类别数,纵坐标是预测类别数
cf_matrix = confusion_matrix(y_true, y_pred)

figure, axes = plt.subplots(2, 2, figsize=(16*1.25, 16))

# 混淆矩阵
ax = sns.heatmap(cf_matrix, annot=True, fmt='g', ax=axes[0][0], cmap='Blues')
ax.title.set_text("Confusion Matrix")
ax.set_xlabel("y_pred")
ax.set_ylabel("y_true")
# plt.savefig(csv_path.replace(".csv", "_cf_matrix.png"))
# plt.show()

# 混淆矩阵 - 百分比
cf_matrix = confusion_matrix(y_true, y_pred)
ax = sns.heatmap(cf_matrix / np.sum(cf_matrix), annot=True, ax=axes[0][1], fmt='.2%', cmap='Blues')
ax.title.set_text("Confusion Matrix (percent)")
ax.set_xlabel("y_pred")
ax.set_ylabel("y_true")
# plt.savefig(csv_path.replace(".csv", "_cf_matrix_p.png"))
# plt.show()

# 召回矩阵,行和为1
sum_true = np.expand_dims(np.sum(cf_matrix, axis=1), axis=1)
precision_matrix = cf_matrix / sum_true
ax = sns.heatmap(precision_matrix, annot=True, fmt='.2%', ax=axes[1][0], cmap='Blues')
ax.title.set_text("Precision Matrix")
ax.set_xlabel("y_pred")
ax.set_ylabel("y_true")
# plt.savefig(csv_path.replace(".csv", "_recall.png"))
# plt.show()

# 精准矩阵,列和为1
sum_pred = np.expand_dims(np.sum(cf_matrix, axis=0), axis=0)
recall_matrix = cf_matrix / sum_pred
ax = sns.heatmap(recall_matrix, annot=True, fmt='.2%', ax=axes[1][1], cmap='Blues')
ax.title.set_text("Recall Matrix")
ax.set_xlabel("y_pred")
ax.set_ylabel("y_true")
# plt.savefig(csv_path.replace(".csv", "_precision.png"))
# plt.show()

# 绘制4张图
plt.autoscale(enable=False)
plt.savefig(csv_path.replace(".csv", "_all.png"), bbox_inches='tight', pad_inches=0.2)
plt.show()

# F1矩阵
a = 2 * precision_matrix * recall_matrix
b = precision_matrix + recall_matrix
f1_matrix = np.divide(a, b, out=np.zeros_like(a), where=(b != 0))
ax = sns.heatmap(f1_matrix, annot=True, fmt='.2%', cmap='Blues')
ax.title.set_text("F1 Matrix")
ax.set_xlabel("y_pred")
ax.set_ylabel("y_true")
plt.savefig(csv_path.replace(".csv", "_f1.png"))
plt.show()

输出混淆矩阵、混淆矩阵(百分比)、召回矩阵、精准矩阵:

F1 Score:

到此这篇关于Python利用Seaborn绘制多标签的混淆矩阵的文章就介绍到这了,更多相关Python Seaborn混淆矩阵内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • 利用python中的matplotlib打印混淆矩阵实例

    前面说过混淆矩阵是我们在处理分类问题时,很重要的指标,那么如何更好的把混淆矩阵给打印出来呢,直接做表或者是前端可视化,小编曾经就尝试过用前端(D5)做出来,然后截图,显得不那么好看.. 代码: import itertools import matplotlib.pyplot as plt import numpy as np def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cma

  • 使用Python和scikit-learn创建混淆矩阵的示例详解

    目录 一.混淆矩阵概述 1.示例1 2.示例2 二.使用Scikit-learn 创建混淆矩阵 1.相应软件包 2.生成示例数据集 3.训练一个SVM 4.生成混淆矩阵 5.可视化边界 一.混淆矩阵概述 在训练了有监督的机器学习模型(例如分类器)之后,您想知道它的工作情况. 这通常是通过将一小部分称为测试集的数据分开来完成的,该数据用作模型以前从未见过的数据. 如果它在此数据集上表现良好,那么该模型很可能在其他数据上也表现良好 - 当然,如果它是从与您的测试集相同的分布中采样的. 现在,当您测试

  • Python实现两种多分类混淆矩阵

    目录 1.什么是混淆矩阵 2.分类模型评价指标 3.两种多分类混淆矩阵 3.1直接打印出每一个类别的分类准确率. 3.2打印具体的分类结果的数值 4.总结 1.什么是混淆矩阵 深度学习中,混淆矩阵是ROC曲线绘制的基础,同时它也是衡量分类型模型准确度中最基本,最直观,计算最简单的方法.它可以直观地了解分类模型在每一类样本里面表现,常作为模型评估的一部分.它可以非常容易的表明多个类别是否有混淆(也就是一个class被预测成另一个class). 首先要明确几个概念: T或者F:该样本 是否被正确分类

  • 详解使用python绘制混淆矩阵(confusion_matrix)

    Summary 涉及到分类问题,我们经常需要通过可视化混淆矩阵来分析实验结果进而得出调参思路,本文介绍如何利用python绘制混淆矩阵(confusion_matrix),本文只提供代码,给出必要注释. Code​ # -*-coding:utf-8-*- from sklearn.metrics import confusion_matrix import matplotlib.pyplot as plt import numpy as np #labels表示你不同类别的代号,比如这里的de

  • python sklearn包——混淆矩阵、分类报告等自动生成方式

    preface:做着最近的任务,对数据处理,做些简单的提特征,用机器学习算法跑下程序得出结果,看看哪些特征的组合较好,这一系列流程必然要用到很多函数,故将自己常用函数记录上.应该说这些函数基本上都会用到,像是数据预处理,处理完了后特征提取.降维.训练预测.通过混淆矩阵看分类效果,得出报告. 1.输入 从数据集开始,提取特征转化为有标签的数据集,转为向量.拆分成训练集和测试集,这里不多讲,在上一篇博客中谈到用StratifiedKFold()函数即可.在训练集中有data和target开始. 2.

  • Python利用Seaborn绘制多标签的混淆矩阵

    Seaborn - 绘制多标签的混淆矩阵.召回.精准.F1 导入seaborn\matplotlib\scipy\sklearn等包: import seaborn as sns from matplotlib import pyplot as plt from scipy.special import softmax from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_sco

  • Python利用matplotlib绘制折线图的新手教程

    前言 matplotlib是Python中的一个第三方库.主要用于开发2D图表,以渐进式.交互式的方式实现数据可视化,可以更直观的呈现数据,使数据更具说服力. 一.安装matplotlib pip install matplotlib -i https://pypi.tuna.tsinghua.edu.cn/simple 二.matplotlib图像简介 matplotlib的图像分为三层,容器层.辅助显示层和图像层. 1. 容器层主要由Canvas.Figure.Axes组成. Canvas位

  • Python利用matplotlib绘制散点图的新手教程

    前言 上篇文章介绍了使用matplotlib绘制折线图,参考:https://www.jb51.net/article/198991.htm,本篇文章继续介绍使用matplotlib绘制散点图. 一.matplotlib绘制散点图 # coding=utf-8 import matplotlib.pyplot as plt years = [2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019] turnovers =

  • Python利用D3Blocks绘制可动态交互的图表

    目录 热力图 粒子图 时间序列图 桑基图 小提琴图 散点图 弦图 网络图 今天小编给大家来介绍一款十分好用的可视化模块,D3Blocks,不仅可以用来绘制可动态交互的图表,并且导出的图表可以是HTML格式,方便在浏览器上面呈现. 热力图 热力图是一种通过对色块着色来显示数据的统计图表.绘图时需要指定颜色映射的规则.例如较大的值由较深的颜色表示,而较小的值由较浅的颜色表示等等.热力图适用于查看总体的情况,发现异常值.显示多个变量之间的差异,以及检测它们之间是否存在任何相关性. 我们这里来尝试绘制一

  • Python利用matplotlib绘制约数个数统计图示例

    本文实例讲述了Python利用matplotlib绘制约数个数统计图.分享给大家供大家参考,具体如下: 利用Python计算1000以内自然数的约数个数,然后通过matplotlib绘制统计图. 下图为约数个数的散点图及其分布情况的条形图. Python代码: import collections import matplotlib.pyplot as plt def countDivisors(num): ans = 1 x = 2 while x * x <= num: cnt = 1 wh

  • Python利用 matplotlib 绘制直方图

    目录 1. 直方图概述 1.1什么是直方图? 1.2直方图使用场景 1.3直方图绘制步骤 1.4案例展示 2. 直方图属性 2.1设置颜色 2.2设置长条形数目 2.3设置透明度 2.4设置样式 3. 添加折线直方图 4. 堆叠直方图 5. 不等距直方图 6. 多类直方图 复习回顾: 经过前面对 matplotlib 模块从底层架构.基本绘制步骤等学习,我们已经学习了折线图.柱状图的绘制方法. matplotlib 模块基础:对matplotlib 模块常用方法进行学习 matplotlib 模

  • Python利用Matplotlib绘制图表详解

    目录 前言 折线图绘制与显示 绘制数学函数图像 散点图绘制 绘制柱状图 绘制直方图 饼图 前言 Matplotlib 是 Python 中类似 MATLAB 的绘图工具,如果您熟悉 MATLAB,那么可以很快的熟悉它. Matplotlib 提供了一套面向对象绘图的 API,它可以轻松地配合 Python GUI 工具包(比如 PyQt,WxPython.Tkinter)在应用程序中嵌入图形.与此同时,它也支持以脚本的形式在 Python.IPython Shell.Jupyter Notebo

  • Python利用Turtle绘制虎年图像

    目录 导语 一.代码展示 二.效果展示 导语 2022年是农历壬寅虎年,在自然界中,虎有“百兽之王”之称 它的王者之风与勇猛,被作为威仪和权势的象征,千百年来,人们崇虎.刻虎.画虎.剪虎……形成了极具特色的中国虎文化,而今天给大家用Turtle绘制虎年图像,带给大家虎年的祝福! 虎年送头虎,全家乐悠悠,虎蹄为你开财路,虎尾为你拂忧愁. 虎耳为你撞鸿运,虎背为你驮康寿,让这头虎伴你左右,你不虎也虎 也希望大家在新年里,虎虎生威.虎年大吉 一.代码展示 本文是基于Turtle绘制的小老虎呢!本文的全

  • Python利用Turtle绘制哆啦A梦和小猪佩奇

    目录 1.哆啦A梦 2.小猪佩奇 3.Python代码实现(哆啦A梦) 4.Python代码实现(小猪佩奇 ) 1.哆啦A梦 “只要把愿望系在竹竿上请求月亮女神,心愿便能达成”.我超喜欢这句话. 哆啦A梦的创造要追溯到1969年的某个截稿日,作者藤子·F·不二雄的家里突然闯进了一只小猫,虽然很快就要截稿了,但作者还是和小猫玩了起来,还替小猫挠虱子,而这一挠就是几个小时.等作者发现时间不够用的时候,已经来不及完成稿子.这时作者像热锅上的蚂蚁走来走去,突然踢到了女儿的不倒翁玩具,于是作者灵光一现,把

  • 利用Seaborn绘制20个精美的pairplot图

    目录 参数 导入数据 默认情况 参数kind 参数hue 参数diag_kind 参数palette 参数markers 参数height 参数aspect 参数corner 参数vars 参数-x_vars/y_vars 参数-plot_kws/diag_kws 参数-dropna 返回值-PairGrid 大家好,我是Peter~ 本文记录的使用seaborn绘制pairplot图,主要是用来显示两两变量之间的关系(线性或非线性,有无较为明显的相关关系等),官网学习地址: https://s

随机推荐