分类算法可视化方法

server/2024/9/24 12:31:48/

可视化方法可以用于帮助理解分类算法的决策边界、性能和在不同数据集上的行为。

下面列举几个常见的可视化方法。

1. 决策边界可视化

这种方法用于可视化不同分类算法在二维特征空间中如何分隔不同类别。对于理解决策树、支持向量机(SVM)、逻辑回归和k近邻(k-NN)等模型的行为非常有用。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression# 生成一个二维的合成数据集
X, y = make_classification(n_samples=200, n_features=2, n_classes=2, n_clusters_per_class=1, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 定义分类器
classifiers = {'逻辑回归': LogisticRegression(),'支持向量机': SVC(),'决策树': DecisionTreeClassifier(),'k近邻': KNeighborsClassifier()
}# 可视化决策边界函数
def plot_decision_boundaries(X, y, model, ax, title):h = .02  # 网格步长# 创建网格x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))# 模型训练并预测model.fit(X, y)Z = model.predict(np.c_[xx.ravel(), yy.ravel()])Z = Z.reshape(xx.shape)# 绘制决策边界ax.contourf(xx, yy, Z, alpha=0.8)ax.scatter(X[:, 0], X[:, 1], c=y, edgecolors='k', marker='o')ax.set_title(title)ax.set_xlim(xx.min(), xx.max())ax.set_ylim(yy.min(), yy.max())# 创建子图
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.ravel()# 为每个分类器绘制决策边界
for idx, (name, clf) in enumerate(classifiers.items()):plot_decision_boundaries(X_train, y_train, clf, axes[idx], title=name)plt.tight_layout()
plt.show()

2. 混淆矩阵可视化

混淆矩阵是一种用来评估分类模型性能的工具,展示了预测类别与真实类别的匹配情况。

示例代码
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay# 使用逻辑回归分类器为例
model = LogisticRegression()
model.fit(X_train, y_train)
y_pred = model.predict(X_test)# 计算混淆矩阵
cm = confusion_matrix(y_test, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['Class 0', 'Class 1'])# 绘制混淆矩阵
disp.plot(cmap=plt.cm.Blues)
plt.title('混淆矩阵')
plt.show()

3. 学习曲线

学习曲线展示了模型在训练集和验证集上的表现随训练样本数量的变化情况,用于检测模型是否欠拟合或过拟合。

示例代码
from sklearn.model_selection import learning_curve
from sklearn.model_selection import ShuffleSplit# 使用支持向量机分类器为例
model = SVC()
cv = ShuffleSplit(n_splits=10, test_size=0.2, random_state=0)
train_sizes, train_scores, test_scores = learning_curve(model, X, y, cv=cv, n_jobs=-1, train_sizes=np.linspace(0.1, 1.0, 5))# 计算平均和标准差
train_scores_mean = np.mean(train_scores, axis=1)
train_scores_std = np.std(train_scores, axis=1)
test_scores_mean = np.mean(test_scores, axis=1)
test_scores_std = np.std(test_scores, axis=1)# 绘制学习曲线
plt.figure(figsize=(8, 6))
plt.grid()plt.fill_between(train_sizes, train_scores_mean - train_scores_std,train_scores_mean + train_scores_std, alpha=0.1, color="r")
plt.fill_between(train_sizes, test_scores_mean - test_scores_std,test_scores_mean + test_scores_std, alpha=0.1, color="g")plt.plot(train_sizes, train_scores_mean, 'o-', color="r", label="训练分数")
plt.plot(train_sizes, test_scores_mean, 'o-', color="g", label="验证分数")plt.xlabel("训练样本数量")
plt.ylabel("得分")
plt.title("学习曲线")
plt.legend(loc="best")
plt.show()

这些可视化方法和代码示例可以更好地理解和展示分类算法的行为及其效果。


http://www.ppmy.cn/server/113506.html

相关文章

江协科技STM32学习- P11 中断系统,EXTI外部中断

🚀write in front🚀 🔎大家好,我是黄桃罐头,希望你看完之后,能对你有所帮助,不足请指正!共同学习交流 🎁欢迎各位→点赞👍 收藏⭐️ 留言📝​…

插入、希尔、冒泡、选择排序

目录 1.插入排序 2.希尔排序 3.冒泡排序 4.选择排序 5.完整代码以及时间测试 1.插入排序 即每次把要插入的元素插入已经有序的数组中&#xff0c;经过不断向前比较&#xff0c;来插入目标元素 void InsertSort(int* a, int n) {for (int i 0; i < n-1;i){int end i;…

TeeChart助力科研软件:高效实现数据可视化

在当今的科学研究中&#xff0c;数据可视化已经成为理解和传播复杂信息的关键工具。尤其是在物理研究领域&#xff0c;科学家们经常需要处理大量的数据&#xff0c;并通过可视化将这些数据转化为更易理解的形式。TeeChart作为一个强大且灵活的图形展示工具&#xff0c;能够帮助…

【Unity小技巧】物体遮挡轮廓描边效果

前言&#xff1a; 效果展示&#xff1a; 遮挡描边 Demo下载 所用插件 QuickOutline描边插件&#xff08;在Demo里&#xff09; 实现步骤 物体挂载Outline组件&#xff0c;做如下处理 Outline Mode&#xff08;描边模式&#xff09;&#xff1a;Outline Hidden(遮挡模式显示…

基础学习之——git 的使用方式

git 是一种分布式版本控制系统&#xff08;Distributed Version Control System, DVCS&#xff09;&#xff0c;用于有效地管理代码和文件的变更历史。它最初由林纳斯托瓦兹&#xff08;Linus Torvalds&#xff09;于2005年为管理Linux内核开发而设计&#xff0c;并很快因其效率…

[数据集][目标检测]智慧农业草莓叶子病虫害检测数据集VOC+YOLO格式4040张9类别

数据集格式&#xff1a;Pascal VOC格式YOLO格式(不包含分割路径的txt文件&#xff0c;仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数)&#xff1a;4040 标注数量(xml文件个数)&#xff1a;4040 标注数量(txt文件个数)&#xff1a;4040 标注…

算法学习:模拟

题源&#xff1a;回文日期 题目&#xff1a; 下面我们对题目进行分析&#xff0c;首先涉及到日期&#xff0c;我们很敏感的考虑到日期的合法性&#xff0c;而日期的合法性中又分为普通日期和特殊日期&#xff08;闰年二月&#xff09;。 再结合这道题目&#xff0c;对8位数的…

AIGC大模型智能抠图(清除背景):Sanster/IOPaint,python(2)

AIGC大模型智能抠图&#xff08;清除背景&#xff09;&#xff1a;Sanster/IOPaint&#xff0c;python&#xff08;2&#xff09; 在文章&#xff08;1&#xff09;的基础上&#xff0c;尝试用大模型扣除图中的某些主要景物。 1、首先&#xff0c;安装插件&#xff1a; pip i…