AdaBoost 二分类问题

news/2024/11/16 15:23:01/

代码功能

生成数据集:
使用 make_classification 创建一个模拟分类问题的数据集。
数据集包含 10 个特征,其中 5 个是有用特征,2 个是冗余特征。
数据集划分:
将数据分为训练集(70%)和测试集(30%)以便评估模型性能。
定义 AdaBoost 模型:
使用 DecisionTreeClassifier 作为弱分类器(基础分类器),设置 max_depth=1 表示单层决策树(决策桩)。
设置 n_estimators=50,表示最多构建 50 个弱分类器。
设置 learning_rate=1.0,控制每个弱分类器对最终模型的贡献权重。
训练和预测:
使用 fit 方法在训练集上训练模型。
使用 predict 方法在测试集上进行预测。
评估模型:
使用 accuracy_score 和 classification_report 评估模型的准确率和分类性能。
特征重要性分析(可选):
提取模型中的特征重要性,分析每个特征对模型的贡献。
在这里插入图片描述

代码

import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.ensemble import AdaBoostClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, classification_report# 1. 生成数据集
X, y = make_classification(n_samples=500,  # 样本数n_features=10,  # 特征数n_informative=5,  # 有用特征数n_redundant=2,  # 冗余特征数random_state=42,  # 随机种子
)# 将数据划分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 2. 定义并训练 AdaBoost 模型
base_estimator = DecisionTreeClassifier(max_depth=1)  # 弱分类器:决策树(单层)
ada_model = AdaBoostClassifier(base_estimator=base_estimator,  # 基础模型n_estimators=50,  # 最大弱分类器数量learning_rate=1.0,  # 学习率random_state=42,
)
ada_model.fit(X_train, y_train)# 3. 预测
y_pred = ada_model.predict(X_test)# 4. 评估模型
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)
print("\nClassification Report:\n", classification_report(y_test, y_pred))# 5. 重要性分析(可选)
feature_importances = ada_model.feature_importances_
print("\nFeature Importances:")
for i, importance in enumerate(feature_importances):print(f"Feature {i + 1}: {importance:.4f}")

http://www.ppmy.cn/news/1547483.html

相关文章

Llama架构及代码详解

Llama的框架图如图: 源码中含有大量分布式训练相关的代码,读起来比较晦涩难懂,所以我们对llama自顶向下进行了解析及复现,我们对其划分成三层,分别是顶层、中层、和底层,如下: Llama的整体组成…

搜维尔科技:CyberGlove触觉反馈数据手套出货调试

CyberGlove触觉反馈数据手套出货调试 搜维尔科技:CyberGlove触觉反馈数据手套出货调试

异步提交Django

在Django中,异步提交通常涉及前端使用AJAX(Asynchronous JavaScript and XML)或其他现代技术(如Fetch API)发送请求,而后端处理这些请求并返回响应。这种方式允许网页在不重新加载的情况下与服务器进行交互…

图论-代码随想录刷题记录[JAVA]

文章目录 前言深度优先搜索理论基础所有可达路径岛屿数量岛屿最大面积孤岛的总面积沉没孤岛水流问题Floyd 算法dijkstra(朴素版)最小生成树之primkruskal算法 前言 新手小白记录第一次刷代码随想录 1.自用 抽取精简的解题思路 方便复盘 2.代码尽量多加注…

前端学习八股资料CSS(二)

更多详情:爱米的前端小笔记,更多前端内容,等你来看!这些都是利用下班时间整理的,整理不易,大家多多👍💛➕🤔哦!你们的支持才是我不断更新的动力!找…

Flutter下拉刷新上拉加载的简单实现方式二

一个简单的Flutter应用程序,展示了如何实现下拉刷新和上拉加载更多的功能。 import package:flutter/cupertino.dart; import package:flutter/material.dart;class MyRefreshDemoPage extends StatefulWidget {const MyRefreshDemoPage({super.key});overrideMyRe…

计算机视觉 ---常见图像文件格式及其特点

常见的图像文件格式及其特点如下: JPEG(Joint Photographic Experts Group) 特点: 有损压缩:通过丢弃一些图像数据来实现高压缩比,能显著减小文件大小,适合用于存储照片等色彩丰富的图像。但过…

pycharm连接oracle数据库查询数据

查询当前python版本在 Terminal中使用命令 pip version Python-oracledb 的默认精简模式可以连接到 Oracle 数据库 12.1 或更高版本。如果要连接到 Oracle 数据库 11.2,则需要通过在代码中调用 oracledb.init_oracle_client() 来启用厚模式。否则会提示版本不支持。…