浅谈sklearn中predict与predict_proba区别

predict_proba 返回的是一个 n 行 k 列的数组,列是标签(有排序), 第 i 行 第 j 列上的数值是模型预测 第 i 个预测样本为某个标签的概率,并且每一行的概率和为1。

predict 直接返回的是预测 的标签。

具体见下面示例:

# conding :utf-8
from sklearn.linear_model import LogisticRegression
import numpy as np
x_train = np.array([[1,2,3],
          [1,3,4],
          [2,1,2],
          [4,5,6],
          [3,5,3],
          [1,7,2]]) 

y_train = np.array([3, 3, 3, 2, 2, 2]) 

x_test = np.array([[2,2,2],
          [3,2,6],
          [1,7,4]]) 

clf = LogisticRegression()
clf.fit(x_train, y_train) 

# 返回预测标签
print(clf.predict(x_test)) 

# 返回预测属于某标签的概率
print(clf.predict_proba(x_test)) 

# [2 3 2]
#
# [[0.56651809 0.43348191]
# [0.15598162 0.84401838]
# [0.86852502 0.13147498]]
# 分析结果:
# 标签是 2,3 共两个,所以predict_proba返回的为2列,且是排序的(第一列为标签2,第二列为标签3),
# 返回矩阵的行数是测试样本个数 因此为3行
# 预测[2,2,2]的标签是2的概率为0.56651809,3的概率为0.43348191
#
# 预测[3,2,6]的标签是2的概率为0.15598162,3的概率为0.84401838
#
# 预测[1,7,4]的标签是2的概率为0.86852502,3的概率为0.13147498 

补充知识:sklearn中predict与predict_proba的识别结果不一致

今天训练了好久的决策树模型在测试的时候发现个bug,使用predict得到的结果居然不是predict_proba中最大数值的索引!因为脚本中需要模型的置信度,所以希望拿到predict_proba的类别概率。

经过胡乱分析发现predict_proba得到的维度比总类别数少了几个,经过测试发现就是这个造成的,即训练集中有部分类别样本数为0。这个问题比较隐蔽,记录一下方便天涯沦落人绕坑。

Tip:在sklearn的train_test_split中有一个参数可以强制测试集和训练集的数据分布一致,也就不会导致缺类别的问题。

以上这篇浅谈sklearn中predict与predict_proba区别就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

(0)

相关推荐

  • Python使用sklearn实现的各种回归算法示例

    本文实例讲述了Python使用sklearn实现的各种回归算法.分享给大家供大家参考,具体如下: 使用sklearn做各种回归 基本回归:线性.决策树.SVM.KNN 集成方法:随机森林.Adaboost.GradientBoosting.Bagging.ExtraTrees 1. 数据准备 为了实验用,我自己写了一个二元函数,y=0.5*np.sin(x1)+ 0.5*np.cos(x2)+0.1*x1+3.其中x1的取值范围是0~50,x2的取值范围是-10~10,x1和x2的训练集一共有5

  • 对Keras中predict()方法和predict_classes()方法的区别说明

    1 predict()方法 当使用predict()方法进行预测时,返回值是数值,表示样本属于每一个类别的概率,我们可以使用numpy.argmax()方法找到样本以最大概率所属的类别作为样本的预测标签. 2 predict_classes()方法 当使用predict_classes()方法进行预测时,返回的是类别的索引,即该样本所属的类别标签.以卷积神经网络中的图片分类为例说明,代码如下: 补充知识:keras中model.evaluate.model.predict和model.predi

  • python sklearn包——混淆矩阵、分类报告等自动生成方式

    preface:做着最近的任务,对数据处理,做些简单的提特征,用机器学习算法跑下程序得出结果,看看哪些特征的组合较好,这一系列流程必然要用到很多函数,故将自己常用函数记录上.应该说这些函数基本上都会用到,像是数据预处理,处理完了后特征提取.降维.训练预测.通过混淆矩阵看分类效果,得出报告. 1.输入 从数据集开始,提取特征转化为有标签的数据集,转为向量.拆分成训练集和测试集,这里不多讲,在上一篇博客中谈到用StratifiedKFold()函数即可.在训练集中有data和target开始. 2.

  • 深入浅析Python 中的sklearn模型选择

    1.主要功能如下: 1.classification分类 2.Regression回归 3.Clustering聚类 4.Dimensionality reduction降维 5.Model selection模型选择 6.Preprocessing预处理 2.主要模块分类: 1.sklearn.base: Base classes and utility function基础实用函数 2.sklearn.cluster: Clustering聚类 3.sklearn.cluster.biclu

  • 浅谈sklearn中predict与predict_proba区别

    predict_proba 返回的是一个 n 行 k 列的数组,列是标签(有排序), 第 i 行 第 j 列上的数值是模型预测 第 i 个预测样本为某个标签的概率,并且每一行的概率和为1. predict 直接返回的是预测 的标签. 具体见下面示例: # conding :utf-8 from sklearn.linear_model import LogisticRegression import numpy as np x_train = np.array([[1,2,3], [1,3,4]

  • 浅谈mybatis中的#和$的区别 以及防止sql注入的方法

    mybatis中的#和$的区别 1. #将传入的数据都当成一个字符串,会对自动传入的数据加一个双引号.如:order by #user_id#,如果传入的值是111,那么解析成sql时的值为order by "111", 如果传入的值是id,则解析成的sql为order by "id". 2. $将传入的数据直接显示生成在sql中.如:order by $user_id$,如果传入的值是111,那么解析成sql时的值为order by user_id,  如果传入的

  • 浅谈mybatis中的#和$的区别

    1. #将传入的数据都当成一个字符串,会对自动传入的数据加一个双引号.如:order by #user_id#,如果传入的值是111,那么解析成sql时的值为order by "111", 如果传入的值是id,则解析成的sql为order by "id". 2. $将传入的数据直接显示生成在sql中.如:order by $user_id$,如果传入的值是111,那么解析成sql时的值为order by user_id, 如果传入的值是id,则解析成的sql为ord

  • 浅谈Java中replace与replaceAll区别

    看门见山 1.java中replace API: replace(char oldChar, char newChar):寓意为:返回一个新的字符串,它是通过用 newChar 替换此字符串中出现的所有 oldChar 得到的. replace(CharSequence target, CharSequence replacement):寓意为:使用指定的字面值替换序列替换此字符串所有匹配字面值目标序列的子字符串. replaceAll(String regex, String replacem

  • 浅谈java 中equals和==的区别

    本文实例为大家分享了java 中equals和==的区别的具体代码,供大家参考,具体内容如下 java9举例代码: String str1 = "abc"; String str2 = "abc"; String str3 = new String("abc"); String str4 = new String("abc"); 当: str1 == str2    输出:true 当:str1.equals(str2); 输

  • 浅谈SpringMVC中Interceptor和Filter区别

    Interceptor 主要作用:拦截用户请求,进行处理,比如判断用户登录情况.权限验证,只要针对Controller请求进行处理,是通过HandlerInterceptor. Interceptor分两种情况,一种是对会话的拦截,实现spring的HandlerInterceptor接口并注册到mvc的拦截队列中,其中preHandle()方法在调用Handler之前进行拦截(上图步骤3),postHandle()方法在视图渲染之前调用(上图步骤5),afterCompletion()方法在返

  • 浅谈c#中const与readonly区别

    const 的概念就是一个包含不能修改的值的变量. 常数表达式是在编译时可被完全计算的表达式.因此不能从一个变量中提取的值来初始化常量. 如果 const int a = b+1;b是一个变量,显然不能再编译时就计算出结果,所以常量是不可以用变量来初始化的. readonly 允许把一个字段设置成常量,但可以执行一些运算,可以确定它的初始值. 因为 readonly 是在计算时执行的,当然它可以用某些变量初始化. readonly 是实例成员,所以不同的实例可以有不同的常量值,这使readonl

  • 浅谈python中copy和deepcopy中的区别

    在下是个编程爱好者,最近将魔爪伸向了Python编程.....遇到copy和deepcopy感到很困惑,现在针对这两个方法进行区分,一种是浅复制(copy),一种是深度复制(deepcopy). 首先说一下deepcopy,所谓的深度复制,在这里我理解的是完全复制然后变成一个新的对象,复制的对象和被复制的对象没有任何关系,彼此之间无论怎么改变都相互不影响. 然后说一下copy,在这里我分为两类来说,一种是字典数据类型的copy函数,一种是copy包的copy函数. 一.字典数据类型的copy函数

  • 浅谈java中math类中三种取整函数的区别

    math类中三大取整函数 1.ceil 2.floor 3.round 其实三种取整函数挺简单的.只要记住三个函数名翻译过来的汉语便能轻松理解三大函数,下面一一介绍 1.ceil,意思是天花板,java中叫做向上取整,大于等于该数字的最接近的整数 例: math.ceil(13.2)=14 math.ceil(-13.2)=-13 2.floor,意思是地板,java中叫做向下取整,小于等于该数字的最接近的整数 例: math.floor(13.2)=13 math.floor(-13.2)=-

  • 浅谈Java中Collection和Collections的区别

    1.java.util.Collection 是一个集合接口.它提供了对集合对象进行基本操作的通用接口方法.Collection接口在Java 类库中有很多具体的实现.Collection接口的意义是为各种具体的集合提供了最大化的统一操作方式. Collection ├List │├LinkedList │├ArrayList │└Vector │ └Stack └Set 2.java.util.Collections 是一个包装类.它包含有各种有关集合操作的静态多态方法.此类不能实例化,就像一

随机推荐