利用python中的matplotlib打印混淆矩阵实例

前面说过混淆矩阵是我们在处理分类问题时,很重要的指标,那么如何更好的把混淆矩阵给打印出来呢,直接做表或者是前端可视化,小编曾经就尝试过用前端(D5)做出来,然后截图,显得不那么好看。。

代码:

import itertools
import matplotlib.pyplot as plt
import numpy as np

def plot_confusion_matrix(cm, classes,
       normalize=False,
       title='Confusion matrix',
       cmap=plt.cm.Blues):
 """
 This function prints and plots the confusion matrix.
 Normalization can be applied by setting `normalize=True`.
 """
 if normalize:
  cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
  print("Normalized confusion matrix")
 else:
  print('Confusion matrix, without normalization')

 print(cm)

 plt.imshow(cm, interpolation='nearest', cmap=cmap)
 plt.title(title)
 plt.colorbar()
 tick_marks = np.arange(len(classes))
 plt.xticks(tick_marks, classes, rotation=45)
 plt.yticks(tick_marks, classes)

 fmt = '.2f' if normalize else 'd'
 thresh = cm.max() / 2.
 for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
  plt.text(j, i, format(cm[i, j], fmt),
     horizontalalignment="center",
     color="white" if cm[i, j] > thresh else "black")

 plt.tight_layout()
 plt.ylabel('True label')
 plt.xlabel('Predicted label')
 plt.show()
 # plt.savefig('confusion_matrix',dpi=200)

cnf_matrix = np.array([
 [4101, 2, 5, 24, 0],
 [50, 3930, 6, 14, 5],
 [29, 3, 3973, 4, 0],
 [45, 7, 1, 3878, 119],
 [31, 1, 8, 28, 3936],
])

class_names = ['Buildings', 'Farmland', 'Greenbelt', 'Wasteland', 'Water']

# plt.figure()
# plot_confusion_matrix(cnf_matrix, classes=class_names,
#      title='Confusion matrix, without normalization')

# Plot normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
      title='Normalized confusion matrix')

在放矩阵位置,放一下你的混淆矩阵就可以,当然可视化混淆矩阵这一步也可以直接在模型运行中完成。

补充知识:混淆矩阵(Confusion matrix)的原理及使用(scikit-learn 和 tensorflow)

原理

在机器学习中, 混淆矩阵是一个误差矩阵, 常用来可视化地评估监督学习算法的性能. 混淆矩阵大小为 (n_classes, n_classes) 的方阵, 其中 n_classes 表示类的数量. 这个矩阵的每一行表示真实类中的实例, 而每一列表示预测类中的实例 (Tensorflow 和 scikit-learn 采用的实现方式). 也可以是, 每一行表示预测类中的实例, 而每一列表示真实类中的实例 (Confusion matrix From Wikipedia 中的定义). 通过混淆矩阵, 可以很容易看出系统是否会弄混两个类, 这也是混淆矩阵名字的由来.

混淆矩阵是一种特殊类型的列联表(contingency table)或交叉制表(cross tabulation or crosstab). 其有两维 (真实值 "actual" 和 预测值 "predicted" ), 这两维都具有相同的类("classes")的集合. 在列联表中, 每个维度和类的组合是一个变量. 列联表以表的形式, 可视化地表示多个变量的频率分布.

使用混淆矩阵( scikit-learn 和 Tensorflow)

下面先介绍在 scikit-learn 和 tensorflow 中计算混淆矩阵的 API (Application Programming Interface) 接口函数, 然后在一个示例中, 使用这两个 API 函数.

scikit-learn 混淆矩阵函数 sklearn.metrics.confusion_matrix API 接口

skearn.metrics.confusion_matrix(
 y_true, # array, Gound true (correct) target values
 y_pred, # array, Estimated targets as returned by a classifier
 labels=None, # array, List of labels to index the matrix.
 sample_weight=None # array-like of shape = [n_samples], Optional sample weights
)

在 scikit-learn 中, 计算混淆矩阵用来评估分类的准确度.

按照定义, 混淆矩阵 C 中的元素 Ci,j 等于真实值为组 i , 而预测为组 j 的观测数(the number of observations). 所以对于二分类任务, 预测结果中, 正确的负例数(true negatives, TN)为 C0,0; 错误的负例数(false negatives, FN)为 C1,0; 真实的正例数为 C1,1; 错误的正例数为 C0,1.

如果 labels 为 None, scikit-learn 会把在出现在 y_true 或 y_pred 中的所有值添加到标记列表 labels 中, 并排好序.

Tensorflow 混淆矩阵函数 tf.confusion_matrix API 接口

tf.confusion_matrix(
 labels, # 1-D Tensor of real labels for the classification task
 predictions, # 1-D Tensor of predictions for a givenclassification
 num_classes=None, # The possible number of labels the classification task can have
 dtype=tf.int32, # Data type of the confusion matrix
 name=None, # Scope name
 weights=None, # An optional Tensor whose shape matches predictions
)

Tensorflow tf.confusion_matrix 中的 num_classes 参数的含义, 与 scikit-learn sklearn.metrics.confusion_matrix 中的 labels 参数相近, 是与标记有关的参数, 表示类的总个数, 但没有列出具体的标记值. 在 Tensorflow 中一般是以整数作为标记, 如果标记为字符串等非整数类型, 则需先转为整数表示. 如果 num_classes 参数为 None, 则把 labels 和 predictions 中的最大值 + 1, 作为num_classes 参数值.

tf.confusion_matrix 的 weights 参数和 sklearn.metrics.confusion_matrix 的 sample_weight 参数的含义相同, 都是对预测值进行加权, 在此基础上, 计算混淆矩阵单元的值.

使用示例

#!/usr/bin/env python
# -*- coding: utf8 -*-
"""
Author: klchang
Description:
  A simple example for tf.confusion_matrix and sklearn.metrics.confusion_matrix.
Date: 2018.9.8
"""
from __future__ import print_function
import tensorflow as tf
import sklearn.metrics

y_true = [1, 2, 4]
y_pred = [2, 2, 4]

# Build graph with tf.confusion_matrix operation
sess = tf.InteractiveSession()
op = tf.confusion_matrix(y_true, y_pred)
op2 = tf.confusion_matrix(y_true, y_pred, num_classes=6, dtype=tf.float32, weights=tf.constant([0.3, 0.4, 0.3]))
# Execute the graph
print ("confusion matrix in tensorflow: ")
print ("1. default: \n", op.eval())
print ("2. customed: \n", sess.run(op2))
sess.close()

# Use sklearn.metrics.confusion_matrix function
print ("\nconfusion matrix in scikit-learn: ")
print ("1. default: \n", sklearn.metrics.confusion_matrix(y_true, y_pred))
print ("2. customed: \n", sklearn.metrics.confusion_matrix(y_true, y_pred, labels=range(6), sample_weight=[0.3, 0.4, 0.3]))

以上这篇利用python中的matplotlib打印混淆矩阵实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • 详解使用python绘制混淆矩阵(confusion_matrix)

    Summary 涉及到分类问题,我们经常需要通过可视化混淆矩阵来分析实验结果进而得出调参思路,本文介绍如何利用python绘制混淆矩阵(confusion_matrix),本文只提供代码,给出必要注释. Code​ # -*-coding:utf-8-*- from sklearn.metrics import confusion_matrix import matplotlib.pyplot as plt import numpy as np #labels表示你不同类别的代号,比如这里的de

  • 解决python中用matplotlib画多幅图时出现图形部分重叠的问题

    1.解决方法:使用函数 tight_layout() 2.具体使用方法 import matplotlib.pyplot as plt fig = plt.figure() ''' 具体的画图程序 ''' fig.tight_layout() fig.tight_layout() 功能:使得子图横纵坐标更加紧凑,主要用于自动调整图区的大小以及间距,使所有的绘图及其标题.坐标轴标签等都可以不重叠的完整显示在画布上. 参数: Pad:用于设置绘图区边缘与画布边缘的距离大小 w_pad:用于设置绘图区

  • python sklearn包——混淆矩阵、分类报告等自动生成方式

    preface:做着最近的任务,对数据处理,做些简单的提特征,用机器学习算法跑下程序得出结果,看看哪些特征的组合较好,这一系列流程必然要用到很多函数,故将自己常用函数记录上.应该说这些函数基本上都会用到,像是数据预处理,处理完了后特征提取.降维.训练预测.通过混淆矩阵看分类效果,得出报告. 1.输入 从数据集开始,提取特征转化为有标签的数据集,转为向量.拆分成训练集和测试集,这里不多讲,在上一篇博客中谈到用StratifiedKFold()函数即可.在训练集中有data和target开始. 2.

  • Python matplotlib可视化实例解析

    例1 使用Python+matplotlib绘图进行可视化,在图形中创建轴域并设置轴域的位置和大小,同时演示设置坐标轴标签和图例位置的用法. 参考代码: 运行结果: 例2 绘制正线余弦图像,然后设置图例字体.标题.位置.阴影.背景色.边框颜色.分栏.符号位置等属性. 运行效果: 例3 生成模拟数据,创建两个子图,分别绘制正弦曲线和余弦曲线,把两个子图的图例显示在一起,并显示于子图之外. 运行效果: 例4 生成模拟数据,绘制正弦曲线.余弦曲线和两个散点图,然后分别为曲线和散点图设置图例,在一个图形

  • 利用python中的matplotlib打印混淆矩阵实例

    前面说过混淆矩阵是我们在处理分类问题时,很重要的指标,那么如何更好的把混淆矩阵给打印出来呢,直接做表或者是前端可视化,小编曾经就尝试过用前端(D5)做出来,然后截图,显得不那么好看.. 代码: import itertools import matplotlib.pyplot as plt import numpy as np def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cma

  • 利用Python中unittest实现简单的单元测试实例详解

    前言 单元测试的重要性就不多说了,可恶的是Python中有太多的单元测试框架和工具,什么unittest, testtools, subunit, coverage, testrepository, nose, mox, mock, fixtures, discover,再加上setuptools, distutils等等这些,先不说如何写单元测试,光是怎么运行单元测试就有N多种方法,再因为它是测试而非功能,是很多人没兴趣触及的东西.但是作为一个优秀的程序员,不仅要写好功能代码,写好测试代码一样

  • Python利用Seaborn绘制多标签的混淆矩阵

    Seaborn - 绘制多标签的混淆矩阵.召回.精准.F1 导入seaborn\matplotlib\scipy\sklearn等包: import seaborn as sns from matplotlib import pyplot as plt from scipy.special import softmax from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_sco

  • 利用Python中的Xpath实现一个在线汇率转换器

    前言 在之前的语法里面,我们记得有一个初识Python之汇率转换篇,在那个程序里面我们发现可以运用一些基础的语法写一个汇率计算,但是学到后面的小伙伴就会发现这个小程序有一定的弊端. 首先,它不可以实时的获取汇率的值,每次都需要我们自己去定义一个汇率转换值,这个就会显得不是很智能,有点机械,所以我们这一个利用爬虫爬取一个网址里面的汇率值(一直在更新的),这里我们利用Xpath来获取这个数据值 其次我们发现在之前的程序里面,我们好像只能输入两位数的货币数据,这一次我们通过正负索引的方法,只获取除了最

  • Matplotlib绘制混淆矩阵的实现

    对于机器学习多分类模型来说,其评价指标除了精度之外,常用的还有混淆矩阵和分类报告,下面来展示一下如何绘制混淆矩阵,这在论文中经常会用到. 代码如下: import itertools import matplotlib.pyplot as plt import numpy as np # 绘制混淆矩阵 def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blue

  • matplotlib画混淆矩阵与正确率曲线的实例代码

    混淆矩阵 混淆矩阵(Confusion Matrix)是机器学习中用来总结分类模型预测结果的一个分析表,是模式识别领域中的一种常用的表达形式.它以矩阵的形式描绘样本数据的真实属性和分类预测结果类型之间的关系,是用来评价分类器性能的一种常用方法. 我们可以通过一个简单的例子来直观理解混淆矩阵 #!/usr/bin/python3.5 # -*- coding: utf-8 -*- import numpy as np import matplotlib.pyplot as plt plt.rcPa

  • Python中使用matplotlib绘制mqtt数据实时图像功能

    目录 效果图 mqtt发布 mqtt订阅 matplotlib绘制动态图 matplotlib绘制mqtt数据实时图像 效果图 mqtt发布 本代码中publish是一个死循环,数据一直往外发送. import random import time from paho.mqtt import client as mqtt_client import json from datetime import datetime broker = 'broker.emqx.io' port = 1883 t

  • 利用python中pymysql操作MySQL数据库的新手指南

    目录 一. pymysql介绍 二. 连接数据库的完整流程 1. 引入pymysql模块 2. 创建连接对象 3. 使用连接对象创建游标对象 4. 准备需要使用的sql语句 5. 使用游标对象执行sql语句(如果是数据修改的操作,会返回受影响的行数) 6. 如果执行语句是查询操作,需要使用游标对象获取查询结果 7. 关闭游标对象 8. 关闭连接对象 三. 完整的简易源码 总结 一. pymysql介绍 pymysql 是在 Python3.x 版本中用于连接和操作 MySQL 服务器的一个库.

  • 利用Python中的pandas库对cdn日志进行分析详解

    前言 最近工作工作中遇到一个需求,是要根据CDN日志过滤一些数据,例如流量.状态码统计,TOP IP.URL.UA.Referer等.以前都是用 bash shell 实现的,但是当日志量较大,日志文件数G.行数达数千万亿级时,通过 shell 处理有些力不从心,处理时间过长.于是研究了下Python pandas这个数据处理库的使用.一千万行日志,处理完成在40s左右. 代码 #!/usr/bin/python # -*- coding: utf-8 -*- # sudo pip instal

  • 用python中的matplotlib绘制方程图像代码

    import numpy as np import matplotlib.pyplot as plt def main(): # 设置x和y的坐标范围 x=np.arange(-2,2,0.01) y=np.arange(-2,2,0.01) # 转化为网格 x,y=np.meshgrid(x,y) z=np.power(x,2)+np.power(y,2)-1 plt.contour(x,y,z,0) plt.show() main() 绘制的时候要保证x,y,z的维度相同 结果如下: 以上这

随机推荐