sklearn的predict_proba使用说明

发现个很有用的方法——predict_proba

今天在做数据预测的时候用到了,感觉很不错,所以记录分享一下,以后可能会经常用到。

我的理解:predict_proba不同于predict,它返回的预测值为,获得所有结果的概率。(有多少个分类结果,每行就有多少个概率,以至于它对每个结果都有一个可能,如0、1就有两个概率)

举例:

获取数据及预测代码:

from sklearn.linear_model import LogisticRegression
import numpy as np

train_X = np.array(np.random.randint(0,10,size=30).reshape(10,3))
train_y = np.array(np.random.randint(0,2,size=10))
test_X = np.array(np.random.randint(0,10,size=12).reshape(4,3))

model = LogisticRegression()
model.fit(train_X,train_y)
test_y = model.predict_proba(test_X)

print(train_X)
print(train_y)
print(test_y)

训练数据

[[2 9 8]
 [0 8 5]
 [7 1 2]
 [8 4 6]
 [8 8 3]
 [7 2 7]
 [6 4 3]
 [1 4 4]
 [1 9 3]
 [3 4 7]]

训练结果,与训练数据一一对应:

[1 1 1 0 1 1 0 0 0 1]

测试数据:

[[4 3 0]  #测试数据
 [3 0 4]
 [2 9 5]
 [2 8 5]]

测试结果,与测试数据一一对应:

[[0.48753831 0.51246169]
 [0.58182694 0.41817306]
 [0.85361393 0.14638607]
 [0.57018655 0.42981345]]

可以看出,有四行两列,每行对应一条预测数据,两列分别对应 对于0、1的预测概率(左边概率大于0.5则为0,反之为1)

我们来看看使用predict方法获得的结果:

test_y = model.predict(test_X)
print(test_y)

输出结果:[1,0,0,0]

所以有的情况下predict_proba还是很有用的,它可以获得对每种可能结果的概率,使用predict则是直接获得唯一的预测结果,所以在使用的时候,应该灵活使用。

补充一个知识点:关于预测结果标签如何与原来标签相对应

predict_proba返回所有标签值可能性概率值,这些值是如何排序的呢?

返回模型中每个类的样本概率,其中类按类self.classes_进行排序。

其中关键的步骤为numpy的unique方法,即通过np.unique(Label)方法,对Label中的所有标签值进行从小到大的去重排序。得到一个从小到大唯一值的排序。这也就对应于predict_proba的行返回结果。

补充知识: python sklearn decision_function、predict_proba、predict

看代码~

import matplotlib.pyplot as plt
import numpy as np
from sklearn.svm import SVC
X = np.array([[-1,-1],[-2,-1],[1,1],[2,1],[-1,1],[-1,2],[1,-1],[1,-2]])
y = np.array([0,0,1,1,2,2,3,3])
# y=np.array([1,1,2,2,3,3,4,4])
# clf = SVC(decision_function_shape="ovr",probability=True)
clf = SVC(probability=True)
clf.fit(X, y)
print(clf.decision_function(X))
'''
对于n分类,会有n个分类器,然后,任意两个分类器都可以算出一个分类界面,这样,用decision_function()时,对于任意一个样例,就会有n*(n-1)/2个值。
任意两个分类器可以算出一个分类界面,然后这个值就是距离分类界面的距离。
我想,这个函数是为了统计画图,对于二分类时最明显,用来统计每个点离超平面有多远,为了在空间中直观的表示数据以及画超平面还有间隔平面等。
decision_function_shape="ovr"时是4个值,为ovo时是6个值。
'''
print(clf.predict(X))
clf.predict_proba(X) #这个是得分,每个分类器的得分,取最大得分对应的类。
#画图
plot_step=0.02
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
           np.arange(y_min, y_max, plot_step))

Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) #对坐标风格上的点进行预测,来画分界面。其实最终看到的类的分界线就是分界面的边界线。
Z = Z.reshape(xx.shape)
cs = plt.contourf(xx, yy, Z, cmap=plt.cm.Paired)
plt.axis("tight")

class_names="ABCD"
plot_colors="rybg"
for i, n, c in zip(range(4), class_names, plot_colors):
  idx = np.where(y == i) #i为0或者1,两个类
  plt.scatter(X[idx, 0], X[idx, 1],
        c=c, cmap=plt.cm.Paired,
        label="Class %s" % n)
plt.xlim(x_min, x_max)
plt.ylim(y_min, y_max)
plt.legend(loc='upper right')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Decision Boundary')
plt.show()

以上这篇sklearn的predict_proba使用说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • python sklearn库实现简单逻辑回归的实例代码

    Sklearn简介 Scikit-learn(sklearn)是机器学习中常用的第三方模块,对常用的机器学习方法进行了封装,包括回归(Regression).降维(Dimensionality Reduction).分类(Classfication).聚类(Clustering)等方法.当我们面临机器学习问题时,便可根据下图来选择相应的方法. Sklearn具有以下特点: 简单高效的数据挖掘和数据分析工具 让每个人能够在复杂环境中重复使用 建立NumPy.Scipy.MatPlotLib之上 代

  • Python机器学习库scikit-learn安装与基本使用教程

    本文实例讲述了Python机器学习库scikit-learn安装与基本使用.分享给大家供大家参考,具体如下: 引言 scikit-learn是Python的一个开源机器学习模块,它建立在NumPy,SciPy和matplotlib模块之上能够为用户提供各种机器学习算法接口,可以让用户简单.高效地进行数据挖掘和数据分析. scikit-learn安装 python 中安装许多模板库之前都有依赖关系,安装 scikit-learn 之前需要以下先决条件: Python(>= 2.6 or >= 3

  • 基于sklearn实现Bagging算法(python)

    本文使用的数据类型是数值型,每一个样本6个特征表示,所用的数据如图所示: 图中A,B,C,D,E,F列表示六个特征,G表示样本标签.每一行数据即为一个样本的六个特征和标签. 实现Bagging算法的代码如下: from sklearn.ensemble import BaggingClassifier from sklearn.tree import DecisionTreeClassifier from sklearn.preprocessing import StandardScaler i

  • sklearn的predict_proba使用说明

    发现个很有用的方法--predict_proba 今天在做数据预测的时候用到了,感觉很不错,所以记录分享一下,以后可能会经常用到. 我的理解:predict_proba不同于predict,它返回的预测值为,获得所有结果的概率.(有多少个分类结果,每行就有多少个概率,以至于它对每个结果都有一个可能,如0.1就有两个概率) 举例: 获取数据及预测代码: from sklearn.linear_model import LogisticRegression import numpy as np tr

  • 浅谈sklearn中predict与predict_proba区别

    predict_proba 返回的是一个 n 行 k 列的数组,列是标签(有排序), 第 i 行 第 j 列上的数值是模型预测 第 i 个预测样本为某个标签的概率,并且每一行的概率和为1. predict 直接返回的是预测 的标签. 具体见下面示例: # conding :utf-8 from sklearn.linear_model import LogisticRegression import numpy as np x_train = np.array([[1,2,3], [1,3,4]

  • Python sklearn中的.fit与.predict的用法说明

    我就废话不多说了,大家还是直接看代码吧~ clf=KMeans(n_clusters=5) #创建分类器对象 fit_clf=clf.fit(X) #用训练器数据拟合分类器模型 clf.predict(X) #也可以给新数据数据对其预测 print(clf.cluster_centers_) #输出5个类的聚类中心 y_pred = clf.fit_predict(X) #用训练器数据X拟合分类器模型并对训练器数据X进行预测 print(y_pred) #输出预测结果 补充知识:sklearn中

  • 一文搞懂Python Sklearn库使用

    目录 1.LabelEncoder 2.OneHotEncoder 3.sklearn.model_selection.train_test_split随机划分训练集和测试集 4.pipeline 5 perdict 直接返回预测值 6 sklearn.metrics中的评估方法 7 GridSearchCV 8 StandardScaler 9 PolynomialFeatures 4.10+款机器学习算法对比 4.1 生成数据 4.2 八款主流机器学习模型 4.3 树模型 - 随机森林 4.

  • 关于数据库连接池Druid使用说明

    根据综合性能,可靠性,稳定性,扩展性,易用性等因素替换成最优的数据库连接池. Druid:druid-1.0.29 数据库 Mysql.5.6.17 替换目标:替换掉C3P0,用druid来替换 替换原因: 1.性能方面 hikariCP>druid>tomcat-jdbc>dbcp>c3p0 .hikariCP的高性能得益于最大限度的避免锁竞争. 2.druid功能最为全面,sql拦截等功能,统计数据较为全面,具有良好的扩展性. 3.综合性能,扩展性等方面,可考虑使用druid或

  • Linux 新的API signalfd、timerfd、eventfd使用说明

    三种新的fd加入linux内核的的版本: signalfd:2.6.22 timerfd:2.6.25 eventfd:2.6.22 三种fd的意义: lsignalfd 传统的处理信号的方式是注册信号处理函数:由于信号是异步发生的,要解决数据的并发访问,可重入问题.signalfd可以将信号抽象为一个文件描述符,当有信号发生时可以对其read,这样可以将信号的监听放到select.poll.epoll等监听队列中. ltimerfd 可以实现定时器的功能,将定时器抽象为文件描述符,当定时器到期

  • 基于Bootstrap的标签页组件及bootstrap-tab使用说明

    bootstrap-tab bootstrap-tab组件是对原生的bootstrap-tab组件的封装,方便开发者更方便地使用,主要包含以下功能: tab页初始化 关闭tab页 新增tab 显示tab页 获取tab页ID 使用 Step1 :引入样式 <link rel="stylesheet" href="bootstrap/css/bootstrap.min.css" rel="external nofollow" > <

  • php header 详细使用说明与使用心得第1/2页

    不管页面有多少header,它会执行最后一个,不过是有条件的,例如: header('Location:http://www.jb51.net'); header('Location:http://www.g.cn'); header('Location:http://www.baidu.com'); 这个就会跳到百度 header('Location:http://www.jb51.net');echo '我们'; header('Location:http://www.g.cn'); hea

  • C#中yield用法使用说明

    在迭代器块中用于向枚举数对象提供值或发出迭代结束信号.它的形式为下列之一: yield return <expression>; yield break; 备注: 计算表达式并以枚举数对象值的形式返回:expression 必须可以隐式转换为迭代器的 yield 类型. yield 语句只能出现在 iterator 块中,该块可用作方法.运算符或访问器的体.这类方法.运算符或访问器的体受以下约束的控制:不允许不安全块. 方法.运算符或访问器的参数不能是 ref 或 out. yield 语句不

  • php header()函数使用说明

    header()函数使用说明: 一.作用:   ~~~~~~~~~          PHP只是以HTTP协议将HTML文档的标头送到浏览器,告诉浏览器具体怎么处理这个页面,至于传送的内容则需要熟悉一下HTTP协议了,与PHP无关了,可参照http://www.w3.org/Protocols/rfc2616/rfc2616.          传统的标头一定包含下面三种标头之一,并只能出现一次.          Location:  xxxx:yyyy/zzzz          Conte

随机推荐