前面的分类算法中,我们只用到了一个指标,准确率(Accuracy),也就是正确分类的样本数占总样本数的比例。现在让我们来思考一个问题,如果有一种癌症,1000个人中只有1个人会得,也就是患这个癌症的概率为0.1%。那么这个时候,我们不用机器学习,给我1000个人预测是否患癌,我只要全部猜没有,那么我就只会有1个人判错,我的准确率达到了99.9%。那么如果我们用机器学习来训练出一个预测一个人是否患有这个癌症的模型,就算这个模型最后的准确率达到了98%,那也是没有意义的。还没有我们人来“猜”的对。怎么办呢?怎么防止我们的模型进入这样的高准确率陷阱呢?
我们来认识下混淆矩阵(confusion matrix), 精准率(precision),召回率(recall).
什么是混淆矩阵,精准率,召回率?
拿二分类问题举例,我们把样本的真实分类值作为一个维度,把样本预测分类值作为一个维度,两个维度组成的矩阵就是混淆矩阵了。那么对于二分类问题,样本的真实值与预测值有4组可能的组合。它们组成了混淆矩阵。
TN(True Negative): 样本的真实值是0,预测值也是0,预测0正确。
FP(False Positive): 样本的真实值是0,预测值是1,错误的把0预测成1了。
FN(False Negative): 样本的真实值是1,预测值是0,错误的把1预测成0 了。
TP(True Positive): 样本的真实值是1,预测值是1,预测1正确。
精确率:在所有预测的1类的样本(TP+FP)中,预测正确的样本(TP)所占有的比例。举个例子,在10000个人中,我们预测出了20个患癌的人,在这20个人中,真的是患癌的人有8人,没有患癌的有12人,那么精准率就是8/20.
召回率:在所有真实类别为1的样本(FN+TP)中,被正确预测为1的样本(TP)所占的比例。举个例子,10000个样本中,一共10个真的患癌症的人,而被预测出来患癌的人有8人,没有被预测出来的有2人
混淆矩阵有什么用?
混淆矩阵可以很清楚的看出样本预测正确或错误的情况,其对角线元素都是预测正确的样本。
精准率和召回率有什么用?
在倾斜比较厉害的数据样本中,高的准确率已经不足以来判断一个模型的性能好坏。这个时候,我们就要用到精准率和召回率了。
什么是倾斜比较厉害的数据呢?就是对于分类问题来说,有些分类样本在总样本中数量远远大于其他分类样本。 比如10000个样本里面,有9900都是0类,只有100个是1类的数据样本 。
代码实现:
我们使用了手写字识别的数据集,为了让该数据集变的倾斜,我们把数字9作为一类,非9的数字作为另一类。用逻辑回归来做二分类问题。
Scikit-learn中的混淆矩阵,精准率,召回率实现
Scikit-learn中的性能指标方法都在metrics中。混淆矩阵,精准率,召回率分别对应confusion_matrix, precision_score,recall_score方法,只需要把测试样本的真实值集合和预测值集合传进去就好了。
领取专属 10元无门槛券
私享最新 技术干货