在机器学习领域,混淆矩阵是一个评估分类模型性能的重要工具。它不仅展示了模型预测的准确性,还揭示了模型在不同类别上的表现。本文介绍两种在Python中绘制混淆矩阵的方法:ConfusionMatrixDisplay()
和 imshow()
,以及两种好看的colorbar:coolwarm_r
,GnBu
, 以增强可视化效果。
目录
- ConfusionMatrixDisplay()
- 基本用法:
- 参数和方法:
- 示例:
- 示例代码:
- imshow()
- 基本用法:
- 参数:
- 示例:
- 示例代码:
- 两种 colorbar
ConfusionMatrixDisplay()
ConfusionMatrixDisplay()
是一个来自 scikit-learn 库的类,用于可视化混淆矩阵。
sklearn.metrics.ConfusionMatrixDisplay 的官方社区描述:
- 中文社区:https://scikit-learn.org.cn/view/582.html
- 英文社区:https://scikit-learn.org/stable/modules/generated/sklearn.metrics.ConfusionMatrixDisplay.html
基本用法:
ConfusionMatrixDisplay 可以通过以下方式创建:
python">from sklearn.metrics import ConfusionMatrixDisplay# 假设 cm 是一个混淆矩阵
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot()
参数和方法:
confusion_matrix
: 参数,一个形状为 (n_classes, n_classes) 的 ndarray,表示混淆矩阵。display_labels
: 参数,一个形状为 (n_classes,) 的 ndarray,默认为 None。用于设置绘图时的标签。如果为 None,则显示标签从 0 到 n_classes - 1。plot()
: 方法,绘制混淆矩阵的可视化。
示例:
示例代码:
python">from sklearn.metrics import ConfusionMatrixDisplay
import os
import matplotlib.pyplot as pltimport numpy as np
import numpy.random as npr
npr.seed(0)# Save path
save_path = './plot'
os.makedirs(save_path, exist_ok=True)# Generate random data 0~1
n = 10
data = npr.rand(n, n) * 0.8
for i in range(n):data[i, i] = 1.0# Plot confusion matrix
fig, ax = plt.subplots(figsize=(8, 8))cm = ConfusionMatrixDisplay(data, display_labels=np.arange(n))
cm.plot(ax=ax, cmap="GnBu", include_values=False, xticks_rotation=90) # GnBu, coolwarm_rax.set_xlabel('Trials', fontsize=20)
ax.set_ylabel('Trials', fontsize=20)plt.title(f'Confusion matrix', fontsize=30)
plt.tight_layout()plt.savefig(f'{save_path}/confu_mat_1-2.png', dpi=300)
plt.show()
imshow()
imshow()
是一个来自 Matplotlib 库的函数,用于在图形用户界面(GUI)中显示图像。这个函数可以处理多种类型的图像数据,包括灰度图和彩色图,是 Matplotlib 中用于图像显示的基础函数之一。
matplotlib.pyplot.imshow 的官方描述:https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.imshow.html
基本用法:
python">import matplotlib.pyplot as plt
import numpy as np# 创建一个随机数组作为图像数据
image_data = np.random.rand(10, 10)# 使用 imshow() 显示图像
plt.imshow(image_data)
plt.colorbar() # 显示颜色条
plt.show()
参数:
imshow()
函数接受多个参数,以下是一些常用的参数:
X
: 图像数据,可以是 2D 数组(灰度图)或 3D 数组(彩色图)。cmap
: 颜色映射表,用于定义颜色。例如,cmap=‘gray’ 表示灰度图,cmap=‘viridis’ 是一种常用的彩色映射。norm
: 归一化对象,用于调整数据值到 [0, 1] 范围。aspect
: 图像的纵横比,可以是 ‘auto’、‘equal’ 或一个数值。interpolation
: 插值方法,用于定义图像的缩放方式,如 ‘nearest’、‘bilinear’、‘bicubic’ 等。alpha
: 图像的透明度。
imshow()
返回一个 AxesImage 对象,这个对象包含了图像的显示信息,可以用来进一步定制图像的显示效果。
示例:
ConfusionMatrixDisplay()
内置函数定义了所绘制的混淆矩阵必须为方针,而imshow()
可以绘制行列数不等的矩形:
示例代码:
python">from mpl_toolkits.axes_grid1 import make_axes_locatableimport os
import matplotlib.pyplot as pltimport numpy as np
import numpy.random as npr
npr.seed(0)# Save path
save_path = './plot'
os.makedirs(save_path, exist_ok=True)# Generate random data 0~1
m = 6
n = 10
data = npr.rand(m, n) * 0.8
if m == n:for i in range(n):data[i, i] = 1.0fig, ax = plt.subplots(figsize=(n, m))
cm = ax.imshow(data, cmap='coolwarm_r', interpolation="nearest", vmin=0.0, vmax=1.0) # coolwarm_r, GnBu# # 绘制一条对角线
# ax.plot([-0.5, n + 0.5], [-0.5, n + 0.5], color='black', alpha=0.2)ax.set_xticks(np.arange(n))
ax.set_yticks(np.arange(m))ax.set_xticklabels(np.arange(n), fontsize=15, rotation=90)
ax.set_yticklabels(np.arange(m), fontsize=15)plt.xlabel('N', fontsize=20)
plt.ylabel('M', fontsize=20)plt.title(f'Confusion matrix', fontsize=30)divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="4%", pad=0.2)
cb = fig.colorbar(cm, cax=cax)
cb.ax.tick_params(labelsize=15)plt.tight_layout()plt.savefig(f'{save_path}/confu_mat_3-1.png', dpi=300)
plt.show()
两种 colorbar
-
coolwarm_r
-
GnBu
更多 colorbar:https://astromsshin.github.io/science/code/matplotlib_cm/index.html