混淆矩阵是用于评估分类模型性能的一种常用工具,通过可视化混淆矩阵可以直观地了解模型在不同类别上的分类情况。下面是如何很好地可视化混淆矩阵的方法:
以下是一个示例代码,展示了如何使用Python和matplotlib库可视化混淆矩阵:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
# 准备数据
y_true = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2])
y_pred = np.array([0, 0, 2, 0, 2, 1, 0, 1, 2])
# 计算混淆矩阵
cm = confusion_matrix(y_true, y_pred)
# 可视化混淆矩阵
plt.imshow(cm, cmap=plt.cm.Blues)
# 添加颜色条
plt.colorbar()
# 添加坐标轴标签
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
# 添加类别标签
tick_marks = np.arange(len(np.unique(y_true)))
plt.xticks(tick_marks, np.unique(y_true))
plt.yticks(tick_marks, np.unique(y_true))
# 显示图像
plt.show()
这样,就可以得到一个清晰易懂的混淆矩阵可视化图像。对于更复杂的混淆矩阵,可以根据需要进行进一步的美化和调整。
领取专属 10元无门槛券
手把手带您无忧上云