Python数据相关系数矩阵和热力图轻松实现教程

对其中的参数进行解释

plt.subplots(figsize=(9, 9))设置画面大小,会使得整个画面等比例放大的

sns.heapmap()这个当然是用来生成热力图的啦

df是DataFrame, pandas的这个类还是很常用的啦~

df.corr()就是得到这个dataframe的相关系数矩阵

把这个矩阵直接丢给sns.heapmap中做参数就好啦

sns.heapmap中annot=True,意思是显式热力图上的数值大小。

sns.heapmap中square=True,意思是将图变成一个正方形,默认是一个矩形

sns.heapmap中cmap="Blues"是一种模式,就是图颜色配置方案啦,我很喜欢这一款的。

sns.heapmap中vmax是显示最大值

import seaborn as sns
import matplotlib.pyplot as plt
def test(df):
 dfData = df.corr()
 plt.subplots(figsize=(9, 9)) # 设置画面大小
 sns.heatmap(dfData, annot=True, vmax=1, square=True, cmap="Blues")
 plt.savefig('./BluesStateRelation.png')
 plt.show()

补充知识:python混淆矩阵(confusion_matrix)FP、FN、TP、TN、ROC,精确率(Precision),召回率(Recall),准确率(Accuracy)详述与实现

一、FP、FN、TP、TN

你这蠢货,是不是又把酸葡萄和葡萄酸弄“混淆“”啦!!!

上面日常情况中的混淆就是:是否把某两件东西或者多件东西给弄混了,迷糊了。

在机器学习中, 混淆矩阵是一个误差矩阵, 常用来可视化地评估监督学习算法的性能.。混淆矩阵大小为 (n_classes, n_classes) 的方阵, 其中 n_classes 表示类的数量。

其中,这个矩阵的一行表示预测类中的实例(可以理解为模型预测输出,predict),另一列表示对该预测结果与标签(Ground Truth)进行判定模型的预测结果是否正确,正确为True,反之为False。

在机器学习中ground truth表示有监督学习的训练集的分类准确性,用于证明或者推翻某个假设。有监督的机器学习会对训练数据打标记,试想一下如果训练标记错误,那么将会对测试数据的预测产生影响,因此这里将那些正确打标记的数据成为ground truth。

此时,就引入FP、FN、TP、TN与精确率(Precision),召回率(Recall),准确率(Accuracy)。

以猫狗二分类为例,假定cat为正例-Positive,dog为负例-Negative;预测正确为True,反之为False。我们就可以得到下面这样一个表示FP、FN、TP、TN的表:

此时如下代码所示,其中scikit-learn 混淆矩阵函数 sklearn.metrics.confusion_matrix API 接口,可以用于绘制混淆矩阵

skearn.metrics.confusion_matrix(
 y_true, # array, Gound true (correct) target values
 y_pred, # array, Estimated targets as returned by a classifier
 labels=None, # array, List of labels to index the matrix.
 sample_weight=None # array-like of shape = [n_samples], Optional sample weights
)

完整示例代码如下:

__author__ = "lingjun"
# welcome to attention:小白CV

import seaborn as sns
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
sns.set()

f, (ax1,ax2) = plt.subplots(figsize = (10, 8),nrows=2)
y_true = ["dog", "dog", "dog", "cat", "cat", "cat", "cat"]
y_pred = ["cat", "cat", "dog", "cat", "cat", "cat", "cat"]
C2= confusion_matrix(y_true, y_pred, labels=["dog", "cat"])
print(C2)
print(C2.ravel())
sns.heatmap(C2,annot=True)

ax2.set_title('sns_heatmap_confusion_matrix')
ax2.set_xlabel('Pred')
ax2.set_ylabel('True')
f.savefig('sns_heatmap_confusion_matrix.jpg', bbox_inches='tight')

保存的图像如下所示:

这个时候我们还是不知道skearn.metrics.confusion_matrix做了些什么,这个时候print(C2),打印看下C2究竟里面包含着什么。最终的打印结果如下所示:

[[1 2]
 [0 4]]
[1 2 0 4]

解释下上面这几个数字的意思:

C2= confusion_matrix(y_true, y_pred, labels=["dog", "cat"])中的labels的顺序就分布是0、1,negative和positive

注:labels=[]可加可不加,不加情况下会自动识别,自己定义

cat为1-positive,其中真实值中cat有4个,4个被预测为cat,预测正确T,0个被预测为dog,预测错误F;

dog为0-negative,其中真实值中dog有3个,1个被预测为dog,预测正确T,2个被预测为cat,预测错误F。

所以:TN=1、 FP=2 、FN=0、TP=4。

TN=1:预测为negative狗中1个被预测正确了

FP=2 :预测为positive猫中2个被预测错误了

FN=0:预测为negative狗中0个被预测错误了

TP=4:预测为positive猫中4个被预测正确了

这时候再把上面猫狗预测结果拿来看看,6个被预测为cat,但是只有4个的true是cat,此时就和右侧的红圈对应上了。

y_pred = ["cat", "cat", "dog", "cat", "cat", "cat", "cat"]
y_true = ["dog", "dog", "dog", "cat", "cat", "cat", "cat"]

二、精确率(Precision),召回率(Recall),准确率(Accuracy)

有了上面的这些数值,就可以进行如下的计算工作了

准确率(Accuracy):这三个指标里最直观的就是准确率: 模型判断正确的数据(TP+TN)占总数据的比例

"Accuracy: "+str(round((tp+tn)/(tp+fp+fn+tn), 3))

召回率(Recall): 针对数据集中的所有正例label(TP+FN)而言,模型正确判断出的正例(TP)占数据集中所有正例的比例;FN表示被模型误认为是负例但实际是正例的数据;召回率也叫查全率,以物体检测为例,我们往往把图片中的物体作为正例,此时召回率高代表着模型可以找出图片中更多的物体!

"Recall: "+str(round((tp)/(tp+fn), 3))

精确率(Precision):针对模型判断出的所有正例(TP+FP)而言,其中真正例(TP)占的比例。精确率也叫查准率,还是以物体检测为例,精确率高表示模型检测出的物体中大部分确实是物体,只有少量不是物体的对象被当成物体。

"Precision: "+str(round((tp)/(tp+fp), 3))

还有:

("Sensitivity: "+str(round(tp/(tp+fn+0.01), 3)))
("Specificity: "+str(round(1-(fp/(fp+tn+0.01)), 3)))
("False positive rate: "+str(round(fp/(fp+tn+0.01), 3)))
("Positive predictive value: "+str(round(tp/(tp+fp+0.01), 3)))
("Negative predictive value: "+str(round(tn/(fn+tn+0.01), 3)))

三.绘制ROC曲线,及计算以上评价参数

如下为统计数据:

__author__ = "lingjun"
# E-mail: 1763469890@qq.com

from sklearn.metrics import roc_auc_score, confusion_matrix, roc_curve, auc
from matplotlib import pyplot as plt
import numpy as np
import torch
import csv

def confusion_matrix_roc(GT, PD, experiment, n_class):
 GT = GT.numpy()
 PD = PD.numpy()

 y_gt = np.argmax(GT, 1)
 y_gt = np.reshape(y_gt, [-1])
 y_pd = np.argmax(PD, 1)
 y_pd = np.reshape(y_pd, [-1])

 # ---- Confusion Matrix and Other Statistic Information ----
 if n_class > 2:
  c_matrix = confusion_matrix(y_gt, y_pd)
  # print("Confussion Matrix:\n", c_matrix)
  list_cfs_mtrx = c_matrix.tolist()
  # print("List", type(list_cfs_mtrx[0]))

  path_confusion = r"./records/" + experiment + "/confusion_matrix.txt"
  # np.savetxt(path_confusion, (c_matrix))
  np.savetxt(path_confusion, np.reshape(list_cfs_mtrx, -1), delimiter=',', fmt='%5s')

 if n_class == 2:
  list_cfs_mtrx = []
  tn, fp, fn, tp = confusion_matrix(y_gt, y_pd).ravel()

  list_cfs_mtrx.append("TN: " + str(tn))
  list_cfs_mtrx.append("FP: " + str(fp))
  list_cfs_mtrx.append("FN: " + str(fn))
  list_cfs_mtrx.append("TP: " + str(tp))
  list_cfs_mtrx.append(" ")
  list_cfs_mtrx.append("Accuracy: " + str(round((tp + tn) / (tp + fp + fn + tn), 3)))
  list_cfs_mtrx.append("Sensitivity: " + str(round(tp / (tp + fn + 0.01), 3)))
  list_cfs_mtrx.append("Specificity: " + str(round(1 - (fp / (fp + tn + 0.01)), 3)))
  list_cfs_mtrx.append("False positive rate: " + str(round(fp / (fp + tn + 0.01), 3)))
  list_cfs_mtrx.append("Positive predictive value: " + str(round(tp / (tp + fp + 0.01), 3)))
  list_cfs_mtrx.append("Negative predictive value: " + str(round(tn / (fn + tn + 0.01), 3)))

  path_confusion = r"./records/" + experiment + "/confusion_matrix.txt"
  np.savetxt(path_confusion, np.reshape(list_cfs_mtrx, -1), delimiter=',', fmt='%5s')

 # ---- ROC ----
 plt.figure(1)
 plt.figure(figsize=(6, 6))

 fpr, tpr, thresholds = roc_curve(GT[:, 1], PD[:, 1])
 roc_auc = auc(fpr, tpr)

 plt.plot(fpr, tpr, lw=1, label="ATB vs NotTB, area=%0.3f)" % (roc_auc))
 # plt.plot(thresholds, tpr, lw=1, label='Thr%d area=%0.2f)' % (1, roc_auc))
 # plt.plot([0, 1], [0, 1], '--', color=(0.6, 0.6, 0.6), label='Luck')

 plt.xlim([0.00, 1.0])
 plt.ylim([0.00, 1.0])
 plt.xlabel("False Positive Rate")
 plt.ylabel("True Positive Rate")
 plt.title("ROC")
 plt.legend(loc="lower right")
 plt.savefig(r"./records/" + experiment + "/ROC.png")
 print("ok")

def inference():
 GT = torch.FloatTensor()
 PD = torch.FloatTensor()
 file = r"Sensitive_rename_inform.csv"
 with open(file, 'r', encoding='UTF-8') as f:
  reader = csv.DictReader(f)
  for row in reader:
   # TODO
   max_patient_score = float(row['ai1'])
   doctor_gt = row['gt2']

   print(max_patient_score,doctor_gt)

   pd = [[max_patient_score, 1-max_patient_score]]
   output_pd = torch.FloatTensor(pd).to(device)

   if doctor_gt == "+":
    target = [[1.0, 0.0]]
   else:
    target = [[0.0, 1.0]]
   target = torch.FloatTensor(target) # 类型转换, 将list转化为tensor, torch.FloatTensor([1,2])
   Target = torch.autograd.Variable(target).long().to(device)

   GT = torch.cat((GT, Target.float().cpu()), 0) # 在行上进行堆叠
   PD = torch.cat((PD, output_pd.float().cpu()), 0)

 confusion_matrix_roc(GT, PD, "ROC", 2)

if __name__ == "__main__":
 inference()

若是表格里面有中文,则记得这里进行修改,否则报错

with open(file, 'r') as f:

以上这篇Python数据相关系数矩阵和热力图轻松实现教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • python 计算两个列表的相关系数的实现

    用pandas计算相关系数 计算相关系数用pandas,比如我想知道风速大小与风向紊乱(标准差来衡量)之间的相关系数,下面是代码: import pandas as pd import pylab as plt #每小时的阵风风速平均值 all_gust_spd_mean_list = [8.21529411764706, 7.872941176470587, 7.829411764705882, 8.354117647058825, 9.025882352941174, 9.384523809

  • Python绘制热力图示例

    本文实例讲述了Python绘制热力图操作.分享给大家供大家参考,具体如下: 示例一: # -*- coding: utf-8 -*- from pyheatmap.heatmap import HeatMap import numpy as np N = 10000 X = np.random.rand(N) * 255 # [0, 255] Y = np.random.rand(N) * 255 data = [] for i in range(N): tmp = [int(X[i]), in

  • Python+pandas计算数据相关系数的实例

    本文主要演示pandas中DataFrame对象corr()方法的用法,该方法用来计算DataFrame对象中所有列之间的相关系数(包括pearson相关系数.Kendall Tau相关系数和spearman秩相关). >>> import numpy as np >>> import pandas as pd >>> df = pd.DataFrame({'A':np.random.randint(1, 100, 10), 'B':np.random

  • python绘制热力图heatmap

    本文实例为大家分享了python绘制热力图的具体代码,供大家参考,具体内容如下 python的热力图是用皮尔逊相关系数来查看两者之间的关联性. #encoding:utf-8 import numpy as np import pandas as pd from matplotlib import pyplot as plt from matplotlib import cm from matplotlib import axes import pylab pylab.mpl.rcParams[

  • Python数据相关系数矩阵和热力图轻松实现教程

    对其中的参数进行解释 plt.subplots(figsize=(9, 9))设置画面大小,会使得整个画面等比例放大的 sns.heapmap()这个当然是用来生成热力图的啦 df是DataFrame, pandas的这个类还是很常用的啦~ df.corr()就是得到这个dataframe的相关系数矩阵 把这个矩阵直接丢给sns.heapmap中做参数就好啦 sns.heapmap中annot=True,意思是显式热力图上的数值大小. sns.heapmap中square=True,意思是将图变

  • python数据可视化Seaborn画热力图

    目录 1.引言 2. 栗子 3. 数据预处理 4. 画热力图 5. 添加数值 6. 调色板优化 1.引言 热力图的想法很简单,用颜色替换数字. 现在,这种可视化风格已经从最初的颜色编码表格走了很长一段路.热力图被广泛用于地理空间数据.这种图通常用于描述变量的密度或强度,模式可视化.方差甚至异常可视化等. 鉴于热力图有如此多的应用,本文将介绍如何使用Seaborn 来创建热力图. 2. 栗子 首先我们导入Pandas和Numpy库,这两个库可以帮助我们进行数据预处理. import pandas

  • Python数据可视化Pyecharts制作Heatmap热力图

    目录 HeatMap:热力图 1.基本设置 2.热力图数据项 Demo 举例 1.基础热力图 本文介绍基于 Python3 的 Pyecharts 制作 Heatmap(热力图 时需要使用的设置参数和常用模板案例,可根据实际情况对案例中的内容进行调整即可. 使用 Pyecharts 进行数据可视化时可提供直观.交互丰富.可高度个性化定制的数据可视化图表.案例中的代码内容基于 Pyecharts 1.x 版本 . HeatMap:热力图 1.基本设置 class HeatMap( # 初始化配置项

  • Python数据可视化之基于pyecharts实现的地理图表的绘制

    一.例子:百度迁徙 百度地图春节人口迁徙大数据(简称百度迁徙),是百度在2014年春运期间推出的一项技术项目.百度迁徙利用大数据,对其拥有的LBS(基于地理位置的服务)大数据进行计算分析,采用的可视化呈现方式,动态.即时.直观地展现中国春节前后人口大迁徙的轨迹与特征. 网址:https://qianxi.baidu.com/2021/ 二.基础语法介绍 语法 说明 from pyecharts.charts import Geo 导入地图库 Geo() Pyecharts地理图表绘制 .add_

  • 学会Python数据可视化必须尝试这7个库

    目录 一.Seaborn 二.Plotly 三.Geoplotlib 四.Gleam 五.ggplot 六.Bokeh 七.Missingo 一.Seaborn Seaborn 建于 matplotlib 库的之上.它有许多内置函数,使用这些函数,只需简单的代码行就可以创建漂亮的绘图.它提供了多种高级的可视化绘图和简单的语法,如方框图.小提琴图.距离图.关节图.成对图.热图等. 安装 ip install seaborn 主要特征: 可用于确定两个变量之间的关系. 在分析单变量或双变量分布时进行

  • python数据可视化JupyterLab实用扩展程序Mito

    目录 遇见 Mito 如何启动 Mito 数据透视表 Mito 令人印象深刻的功能 可视化数据 自动代码生成 Mito 安装 JupyterLab 是 Jupyter 主打的最新数据科学生产工具,某种意义上,它的出现是为了取代Jupyter Notebook. 它作为一种基于 web 的集成开发环境,你可以使用它编写notebook.操作终端.编辑markdown文本.打开交互模式.查看csv文件及图片等功能. JupyterLab 最棒的体验就是有丰富的扩展插件,我记得过去我们不得不依赖 nu

  • Python数据可视化Pyecharts库的使用教程

    目录 一.Pyecharts 概述 1.1 Pyecharts 特性 1.2 Pyecharts 入门案例 二.Pyecharts 配置项 2.1 全局配置项 2.2 系列配置项 三.Pyecharts 的总结 一.Pyecharts 概述 Pyechart 是一个用于生成 Echarts 图表(Echarts 是基于 Javascript 的开源可视化图表库)的 Python 第三方库. 1.1 Pyecharts 特性 根据官方文档的介绍,Pyecharts 的特性如下: 1.简洁的 API

  • 详解Python+Matplotlib绘制面积图&热力图

    目录 1.绘制面积图 2.绘制热力图 1.绘制面积图 面积图常用于描述某指标随时间的变化程度.其面积也通常可以有一定的含义. 绘制面积图使用的是plt.stackplot()方法. 以小学时期学的 常见的追击相遇问题中的速度时间图像为例,下边绘制出一幅简单的v-t图像. 全局字体设为默认的黑体,时间为从第0秒到第10秒,描述的是甲乙两个物体的速度.显然,面积则表示位移. 标题部分字体使用楷体(将系统中的TTF字体文件"STKAITI.TTF"复制到了当前目录下). import mat

  • Python数据可视化之Seaborn的使用详解

    目录 1. 安装 seaborn 2.准备数据 3.背景与边框 3.1 设置背景风格 3.2 其他 3.3 边框控制 4. 绘制 散点图 5. 绘制 折线图 5.1 使用 replot()方法 5.2 使用 lineplot()方法 6. 绘制直方图 displot() 7. 绘制条形图 barplot() 8. 绘制线性回归模型 9. 绘制 核密度图 kdeplot() 9.1 一般核密度图 9.2 边际核密度图 10. 绘制 箱线图 boxplot() 11. 绘制 提琴图 violinpl

  • Python数据可视化之Pyecharts使用详解

    目录 1. 安装Pyecharts 2. 图表基础 2.1 主题风格 2.2 图表标题 2.3 图例 2.4 提示框 2.5 视觉映射 2.6 工具箱 2.7 区域缩放 3. 柱状图 Bar模块 4. 折线图/面积图 Line模块 4.1 折线图 4.2 面积图 5.饼形图 5.1 饼形图 5.2 南丁格尔玫瑰图 6. 箱线图 Boxplot模块 7. 涟漪特效散点图 EffectScatter模块 8. 词云图 WordCloud模块 9. 热力图 HeatMap模块 10. 水球图 Liqu

随机推荐