【机器学习】CatBoost 模型实践:回归与分类的全流程解析

news/2024/12/4 18:42:13/

一. 引言

本篇博客首发于掘金 https://juejin.cn/post/7441027173430018067。
PS:转载自己的文章也算原创吧。

在机器学习领域,CatBoost 是一款强大的梯度提升框架,特别适合处理带有类别特征的数据。本篇博客以脱敏后的保险数据集为例,展示如何利用 CatBoost 完成分类回归任务,并以可视化的方式解析特征重要性与结果。

我们将完成以下任务:

  1. 回归任务:预测保险索赔金额。
  2. 分类任务:判断保险案件是否需要调查。
  3. 可视化分析:利用散点图与分割线展示结果。

二. CatBoost 模型简介

CatBoost 是由俄罗斯搜索巨头 Yandex 于 2017 年开源的机器学习库,其名称来源于 “Category” 和 “Boosting” 的组合,旨在高效处理类别特征的梯度提升算法。与其他模型(如 XGBoost 和 LightGBM)相比,CatBoost 具有以下优势:

  • 支持类别特征:无需对类别特征进行独热编码,直接处理类别数据,避免数据膨胀。
  • 对缺失值的鲁棒性:无需特殊预处理即可直接处理缺失值。
  • 防止过拟合:内置多种正则化手段,减少梯度偏差和预测偏移,提高模型的准确性和泛化能力。
  • 对称树结构:采用对称决策树(Oblivious Trees),在每个层级使用相同的特征和分割点,提升训练和预测效率。

三. 实战项目环境与数据准备

本项目使用了脱敏后的保险数据集,包含以下特征:

  • 类别特征:险种代码、出险原因、医疗责任类别等。
  • 数值特征:基本保额、索赔金额等。
  • 标签:是否需要调查(分类任务)。

所有数据均已脱敏,支持迁移至其他表格数据集。

因为不好分享,所以后续第七节补充了一个基于sklearn "California Housing"数据集的流程代码与说明。


四. 回归任务:预测保险索赔金额

数据预处理

回归任务中,我们根据特征预测索赔金额。以下是数据清洗与预处理的关键步骤:

  1. 过滤无效数据:移除缺失或非法值的记录。
  2. 特征转换:将类别特征转为字符串类型。
  3. 分割数据集:按 80% 和 20% 的比例划分训练集与测试集。

4.1 模型训练与评估

我们使用 CatBoost 进行回归建模,模型参数包括:

  • 学习率:0.02
  • 深度:8
  • 迭代次数:10,000(支持提前停止)

以下是模型的关键代码:

from catboost import CatBoostRegressor# 初始化 CatBoost 回归模型
cat_regressor = CatBoostRegressor(iterations=10000,learning_rate=0.02,depth=8,eval_metric='RMSE',early_stopping_rounds=1500,random_seed=42
)# 训练模型
cat_regressor.fit(X_train, y_train,cat_features=categorical_features_indices,eval_set=(X_test, y_test),verbose=100
)

4.2 特征重要性分析

特征重要性是衡量特征对模型预测贡献程度的指标,可以帮助我们更好地理解模型。

# 获取特征重要性
feature_importances = cat_regressor.get_feature_importance()
feature_names = X_train.columns# 可视化特征重要性
import matplotlib.pyplot as plt
importance_df = pd.DataFrame({'Feature': feature_names,'Importance': feature_importances
}).sort_values(by='Importance', ascending=True)plt.figure(figsize=(10, 6))
plt.barh(importance_df['Feature'], importance_df['Importance'], color='salmon')
plt.xlabel('特征重要性')
plt.ylabel('特征名称')
plt.title('CatBoost 特征重要性分析')
plt.show()

结果展示
在这里插入图片描述


4.3 模型评估

我们可以均方误差 (MSE) 以及 平均绝对误差 (MAE) 来评估模型在测试集上的回归性能,同时展示模型的学习曲线:

# 获取训练和测试集的 RMSE
evals_result = cat_regressor.get_evals_result()
train_rmse = evals_result['learn']['RMSE']
test_rmse = evals_result['validation']['RMSE']# 绘制 RMSE 曲线
plt.figure(figsize=(10, 6))
plt.plot(train_rmse, label='训练集 RMSE')
plt.plot(test_rmse, label='测试集 RMSE')
plt.title('训练与测试集的 RMSE 学习曲线')
plt.xlabel('迭代次数')
plt.ylabel('RMSE')
plt.legend()
plt.show()

五. 分类任务:判别是否调查

5.1 数据标注与模型选择

分类任务以 是否调查 作为标签(1 表示需要调查,0 表示无需调查),特征包括所有数值和类别字段。

为了完成分类任务,我们选用 CatBoostClassifier。模型参数类似于回归模型,分类评估指标包括准确率、混淆矩阵和分类报告。


5.2 训练结果与模型评估

训练结果显示,分类准确率达 94.0%。以下是模型的分类报告:

分类报告 (训练集):precision    recall  f1-score   support0       0.96      0.98      0.97     130871       0.74      0.57      0.64      1354accuracy                           0.94     14441macro avg       0.85      0.77      0.80     14441
weighted avg       0.94      0.94      0.94     14441
5.3 代码示例
from catboost import CatBoostClassifier# 初始化分类
cat_classifier = CatBoostClassifier(iterations=1000,learning_rate=0.02,depth=8,eval_metric='Accuracy',early_stopping_rounds=150,random_seed=42
)# 模型训练
cat_classifier.fit(X_train, y_train,cat_features=categorical_features_indices,eval_set=(X_test, y_test),verbose=100
)

六. 可视化分析

为更直观地理解模型,我们利用散点图和分割线对预测结果进行展示:

  • 散点图:展示实际金额与预测金额的分布。
  • 分割线:通过 KMeans 聚类划分四个金额档次。

以下代码生成散点图与分割线:

# 使用 KMeans 聚类生成分割线
from sklearn.cluster import KMeanskmeans = KMeans(n_clusters=4, random_state=42)
df['cluster'] = kmeans.fit_predict(df[['预测金额']])# 绘制散点图
plt.figure(figsize=(12, 8))
plt.scatter(df['预测金额'], df['是否调查'], c=df['cluster'], cmap='tab10')
plt.title("预测金额与是否调查的散点图")
plt.xlabel("预测金额")
plt.ylabel("是否调查")
plt.colorbar(label='Cluster')
plt.show()

散点图展示

在这里插入图片描述


七. 补充学习

7.1 基础数据集

California Housing 数据集包含加利福尼亚州 20,640 个街区的人口、住房和收入信息。目标是预测每个街区的房价中位数 MedHouseVal

数据特征

  1. MedInc:街区的收入中位数。
  2. HouseAge:街区住房的平均年龄。
  3. AveRooms:每个街区的平均房间数。
  4. AveBedrms:每个街区的平均卧室数。
  5. Population:街区的总人口。
  6. AveOccup:每户的平均人数。
  7. Latitude:街区的纬度。
  8. Longitude:街区的经度。

7.2 实践步骤

7.2.1 导入数据与预处理

我们使用 Scikit-learn 加载数据并进行预处理。

from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import pandas as pd# 加载 California Housing 数据集
data = fetch_california_housing(as_frame=True)
df = data.frame# 特征和目标变量
X = df.drop(columns="MedHouseVal")
y = df["MedHouseVal"]# 数据划分
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 数据标准化
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)print(f"训练集大小: {X_train.shape}, 测试集大小: {X_test.shape}")

训练集大小: (16512, 8), 测试集大小: (4128, 8)


7.2.2 训练 CatBoost 回归模型

使用 CatBoost 对房价进行预测。

from catboost import CatBoostRegressor
from sklearn.metrics import mean_squared_error, mean_absolute_error# 初始化 CatBoost 回归模型
cat_regressor = CatBoostRegressor(iterations=1000,learning_rate=0.1,depth=6,eval_metric="RMSE",random_seed=42,verbose=100
)# 模型训练
cat_regressor.fit(X_train, y_train, eval_set=(X_test, y_test), verbose=100, early_stopping_rounds=50)# 模型预测
y_pred_train = cat_regressor.predict(X_train)
y_pred_test = cat_regressor.predict(X_test)# 模型评估
mse_train = mean_squared_error(y_train, y_pred_train)
mse_test = mean_squared_error(y_test, y_pred_test)
mae_test = mean_absolute_error(y_test, y_pred_test)print(f"训练集均方误差 (MSE): {mse_train}")
print(f"测试集均方误差 (MSE): {mse_test}")
print(f"测试集平均绝对误差 (MAE): {mae_test}")

输出如下:

0:	learn: 1.0934740	test: 1.0841841	best: 1.0841841 (0)	total: 1.24s	remaining: 20m 38s
100:	learn: 0.4867395	test: 0.5154868	best: 0.5154868 (100)	total: 1.54s	remaining: 13.7s
200:	learn: 0.4320149	test: 0.4798269	best: 0.4798269 (200)	total: 1.8s	remaining: 7.18s
300:	learn: 0.4020581	test: 0.4657293	best: 0.4657293 (300)	total: 2.07s	remaining: 4.8s
400:	learn: 0.3803801	test: 0.4582868	best: 0.4582868 (400)	total: 2.35s	remaining: 3.5s
500:	learn: 0.3633580	test: 0.4534430	best: 0.4534430 (500)	total: 2.61s	remaining: 2.6s
600:	learn: 0.3488402	test: 0.4491723	best: 0.4491723 (600)	total: 2.89s	remaining: 1.92s
700:	learn: 0.3358611	test: 0.4461323	best: 0.4461323 (700)	total: 3.17s	remaining: 1.35s
800:	learn: 0.3234759	test: 0.4431320	best: 0.4431320 (800)	total: 3.44s	remaining: 854ms
900:	learn: 0.3126821	test: 0.4403978	best: 0.4403978 (900)	total: 3.71s	remaining: 407ms
999:	learn: 0.3025414	test: 0.4386906	best: 0.4386902 (998)	total: 3.97s	remaining: 0usbestTest = 0.438690174
bestIteration = 998Shrink model to first 999 iterations.
训练集均方误差 (MSE): 0.09158491090576551
测试集均方误差 (MSE): 0.19244906768098075
测试集平均绝对误差 (MAE): 0.28701415230111493

7.2.3 可视化预测结果

展示预测值与实际值的对比,以及模型的特征重要性。

实际值与预测值对比
import matplotlib.pyplot as plt# 对比测试集的预测值和实际值
plt.figure(figsize=(10, 6))
plt.scatter(range(len(y_test)), y_test, color="blue", label="真实值", alpha=0.6)
plt.scatter(range(len(y_pred_test)), y_pred_test, color="red", label="预测值", alpha=0.6)
plt.title("真实房价与预测房价对比")
plt.xlabel("样本索引")
plt.ylabel("房价中位数")
plt.legend()
plt.show()

特征重要性分析
# 特征重要性可视化
feature_importances = cat_regressor.get_feature_importance()
feature_names = data.feature_namesplt.figure(figsize=(10, 6))
plt.barh(feature_names, feature_importances, color="skyblue")
plt.title("CatBoost 特征重要性")
plt.xlabel("重要性得分")
plt.ylabel("特征名称")
plt.show()

在这里插入图片描述


7.3 数据结果

  • 模型评估结果:
    • 训练集均方误差 (MSE): 0.09158491090576551
    • 测试集均方误差 (MSE): 0.19244906768098075
    • 测试集平均绝对误差 (MAE): 0.28701415230111493
  • 特征重要性解读:
    根据特征重要性分析,MedInc(收入中位数)对预测房价的影响最大,而经纬度特征(Latitude 和 Longitude)也提供了显著的信息。

八. 总结

通过本项目,我们完成了基于 CatBoost 的回归分类建模,并展示了预测结果的可视化。CatBoost 的强大功能和易用性使其在处理类别特征和缺失值的数据中表现优异。

希望本篇博客能为大家带来启发,助力实际项目的落地实现。如果对您有所帮助,也欢迎点赞与分享😊。

源码已上传到:https://github.com/YYForReal/ML-DL-RL-Learning/blob/main/ML-Learning/Catboost/


http://www.ppmy.cn/news/1552350.html

相关文章

基于 LlamaFactory 的 LoRA 微调模型支持 vllm 批量推理的实现

背景 LlamaFactory 的 LoRA 微调功能非常便捷,微调后的模型,没有直接支持 vllm 推理,故导致推理速度不够快。 LlamaFactory 目前支持通过 VLLM API 进行部署,调用 API 时的响应速度,仍然没有vllm批量推理的速度快。 …

FreeRTOS之ARM CR5栈结构操作示意图

FreeRTOS之ARM CR5栈结构操作示意图 1 FreeRTOS源码下载地址2 ARM CR5栈结构操作宏和接口2.1 portSAVE_CONTEXT宏2.1.1 portSAVE_CONTEXT源码2.1.2 portSAVE_CONTEXT宏操作栈结构变化示意图 2.2 portRESTORE_CONTEXT宏2.2.1 portRESTORE_CONTEXT源码2.2.2 portRESTORE_CONTEXT宏…

Nginx管理维护运维规范

Nginx管理维护运维规范 一、版本约束 软件的不同版本,在使用起来都有可能带来不可预知的影响,因此需要统一整理,固定下来,不允许轻易变更。 在Nginx开源版官网,点击右侧download可以看到各个版本的Nginx,…

第六届金盾信安杯Web题解

比赛一共4道Web题,比赛时只做出三道,那道文件上传没有做出来,所以这里是另外三道题的WP 分别是 fillllll_put hoverfly ssrf fillllll_put 涉及: 绕过exit() 死亡函数 php://filter 伪协议配合base64加解密 一句话木马 题目源码: $content参数在开头被…

Linux笔试题(自己整理,已做完,选择题)

详细Linux内容查看:day04【入门】Linux系统操作-CSDN博客 1、部分笔试题 本文的笔试题,主要是为了复习学习的day04【入门】Linux系统操作-CSDN博客的相关知识点。后续还会更新一些面试相关的题目。 欢迎一起学习

MongoDB集群分片安装部署手册

文章目录 一、集群规划1.1 集群安装规划1.2 端口规划1.3 目录创建 二、mongodb安装(三台均需要操作)2.1 下载、解压2.2 配置环境变量 三、mongodb组件配置3.1 配置config server的副本集3.1.1 config配置文件3.1.2 config server启动3.1.3 初始化config …

链表的分类以及双向链表的实现

1.链表的分类 1.1单向链表与双向链表的区别 单向链表只能从头结点遍历到尾结点。 双向链表不仅能从头结点遍历到尾结点,还能从尾结点遍历到头结点。 1.2带头链表与不带头链表的区别 1.3循环链表与不循环链表的区别(也称为带环链表与不带环链表&#x…

【嵌入式系统设计】LES3~5:Cortex-M4系统架构(上)第1节 ARM处理器,M4内核处理器,M4调试跟踪接口

关注作者了解更多 我的其他CSDN专栏 过程控制系统 工程测试技术 虚拟仪器技术 可编程控制器 工业现场总线 数字图像处理 智能控制 传感器技术 嵌入式系统 复变函数与积分变换 单片机原理 线性代数 大学物理 热工与工程流体力学 数字信号处理 光电融合集成电路…