梯度提升算法决策过程的逐步可视化

news/2024/10/30 11:30:35/

梯度提升算法是最常用的集成机器学习技术之一,该模型使用弱决策树序列来构建强学习器。这也是XGBoost和LightGBM模型的理论基础,所以在这篇文章中,我们将从头开始构建一个梯度增强模型并将其可视化。

梯度提升算法介绍

梯度提升算法(Gradient Boosting)是一种集成学习算法,它通过构建多个弱分类器,然后将它们组合成一个强分类器来提高模型的预测准确率。

梯度提升算法的原理可以分为以下几个步骤:

  1. 初始化模型:一般来说,我们可以使用一个简单的模型(比如说决策树)作为初始的分类器。
  2. 计算损失函数的负梯度:计算出每个样本点在当前模型下的损失函数的负梯度。这相当于是让新的分类器去拟合当前模型下的误差。
  3. 训练新的分类器:用这些负梯度作为目标变量,训练一个新的弱分类器。这个弱分类器可以是任意的分类器,比如说决策树、线性模型等。
  4. 更新模型:将新的分类器加入到原来的模型中,可以用加权平均或者其他方法将它们组合起来。
  5. 重复迭代:重复上述步骤,直到达到预设的迭代次数或者达到预设的准确率。

由于梯度提升算法是一种串行算法,所以它的训练速度可能会比较慢,我们以一个实际的例子来介绍:

假设我们有一个特征集Xi和值Yi,要计算y的最佳估计

我们从y的平均值开始

每一步我们都想让F_m(x)更接近y|x。

在每一步中,我们都想要F_m(x)一个更好的y给定x的近似。

首先,我们定义一个损失函数

然后,我们向损失函数相对于学习者Fm下降最快的方向前进:

因为我们不能为每个x计算y,所以不知道这个梯度的确切值,但是对于训练数据中的每一个x_i,梯度完全等于步骤m的残差:r_i!

所以我们可以用弱回归树h_m来近似梯度函数g_m,对残差进行训练:

然后,我们更新学习器

这就是梯度提升,我们不是使用损失函数相对于当前学习器的真实梯度g_m来更新当前学习器F_{m},而是使用弱回归树h_m来更新它。

也就是重复下面的步骤

1、计算残差:

2、将回归树h_m拟合到训练样本及其残差(x_i, r_i)上

3、用步长\alpha更新模型

看着很复杂对吧,下面我们可视化一下这个过程就会变得非常清晰了

决策过程可视化

这里我们使用sklearn的moons 数据集,因为这是一个经典的非线性分类数据

 import numpy as npimport sklearn.datasets as dsimport pandas as pdimport matplotlib.pyplot as pltimport matplotlib as mplfrom sklearn import treefrom itertools import product,isliceimport seaborn as snsmoonDS = ds.make_moons(200, noise = 0.15, random_state=16)moon = moonDS[0]color = -1*(moonDS[1]*2-1)df =pd.DataFrame(moon, columns = ['x','y'])df['z'] = colordf['f0'] =df.y.mean()df['r0'] = df['z'] - df['f0']df.head(10)

让我们可视化数据:

下图可以看到,该数据集是可以明显的区分出分类的边界的,但是因为他是非线性的,所以使用线性算法进行分类时会遇到很大的困难。

那么我们先编写一个简单的梯度增强模型:

 def makeiteration(i:int):"""Takes the dataframe ith f_i and r_i and approximated r_i from the features, then computes f_i+1 and r_i+1"""clf = tree.DecisionTreeRegressor(max_depth=1)clf.fit(X=df[['x','y']].values, y = df[f'r{i-1}'])df[f'r{i-1}hat'] = clf.predict(df[['x','y']].values)eta = 0.9df[f'f{i}'] = df[f'f{i-1}'] + eta*df[f'r{i-1}hat']df[f'r{i}'] = df['z'] - df[f'f{i}']rmse = (df[f'r{i}']**2).sum()clfs.append(clf)rmses.append(rmse)

上面代码执行3个简单步骤:

将决策树与残差进行拟合:

 clf.fit(X=df[['x','y']].values, y = df[f'r{i-1}'])df[f'r{i-1}hat'] = clf.predict(df[['x','y']].values)

然后,我们将这个近似的梯度与之前的学习器相加:

 df[f'f{i}'] = df[f'f{i-1}'] + eta*df[f'r{i-1}hat']

最后重新计算残差:

 df[f'r{i}'] = df['z'] - df[f'f{i}']

步骤就是这样简单,下面我们来一步一步执行这个过程。

第1次决策

Tree Split for 0 and level 1.563690960407257

第2次决策

Tree Split for 1 and level 0.5143677890300751

第3次决策

Tree Split for 0 and level -0.6523728966712952

第4次决策

Tree Split for 0 and level 0.3370491564273834

第5次决策

Tree Split for 0 and level 0.3370491564273834

第6次决策

Tree Split for 1 and level 0.022058885544538498

第7次决策

Tree Split for 0 and level -0.3030575215816498

第8次决策

Tree Split for 0 and level 0.6119407713413239

第9次决策

可以看到通过9次的计算,基本上已经把上面的分类进行了区分

我们这里的学习器都是非常简单的决策树,只沿着一个特征分裂!但整体模型在每次决策后边的越来越复杂,并且整体误差逐渐减小。

 plt.plot(rmses)

这也就是上图中我们看到的能够正确区分出了大部分的分类

如果你感兴趣可以使用下面代码自行实验:

https://avoid.overfit.cn/post/533a0736b7554ef6b8464a5d8ba964ab

作者:Tanguy Renaudie


http://www.ppmy.cn/news/29889.html

相关文章

代码随想录中:回溯算法的基础

回溯算法是一种暴力的搜索方式;回溯法一般与递归同时存在。 回溯法,一般可以解决如下几种问题: 组合问题:N个数里面按一定规则找出k个数的集合切割问题:一个字符串按一定规则有几种切割方式子集问题:一个…

LeetCode 236.二叉树的最近公共祖先

给定一个二叉树, 找到该树中两个指定节点的最近公共祖先。百度百科中最近公共祖先的定义为:“对于有根树 T 的两个节点 p、q,最近公共祖先表示为一个节点 x,满足 x 是 p、q 的祖先且 x 的深度尽可能大(一个节点也可以是它自己的祖…

大型三甲医院云HIS系统源码 强大的电子病历+完整文档

医院HIS系统源码云HIS系统:SaaS运维平台多医院入驻强大的电子病历完整文档 有源码,有演示 一、系统概述 采用主流成熟技术,软件结构简洁、代码规范易阅读,SaaS应用,全浏览器访问前后端分离,多服务协同&am…

刷题记录:牛客NC13950 Alliances 到树上联通点集的最短距离

传送门:牛客 题目描述: 题目较长,此处省略 输入: 7 1 2 1 3 2 4 2 5 3 6 3 7 2 2 6 7 1 4 3 5 1 2 1 1 1 5 2 1 2 输出: 2 1 1一道比较复杂的树题.需要一些复杂的讨论以及LCA知识 对于LCA,可以使用树链剖分进行解决 然后我们看一下题目,我们会发现有这样一个简单的结论,那就…

华为OD机试Golang解题 - 火星文计算 2 | 包含思路

华为Od必看系列 华为OD机试 全流程解析+经验分享,题型分享,防作弊指南)华为od机试,独家整理 已参加机试人员的实战技巧华为od 2023 | 什么是华为od,od 薪资待遇,od机试题清单华为OD机试真题大全,用 Python 解华为机试题 | 机试宝典文章目录 华为Od必看系列使用说明本期题目…

English Learning - L2 第1次小组纠音 [ɑː] [ɔː] [uː] 2023.2.25 周六

English Learning - L2 第1次小组纠音 [ɑː] [ɔː] [uː] 2023.2.25 周六共性问题分析大后元音 [ɑː]大后元音 [ɔː]后元音 [uː]我的发音问题后元音 [uː]大后元音 [ɑː] 和 [ɔː]纠音过程第一次第二次第三次共性问题分析 大后元音 [ɑː] 嘴唇过于松散,没…

OSPF与BFD联动配置

13.1.1BFD概念 BFD提供了一个通用的、标准化的、介质无关的、协议无关的快速故障检测机制,有以下两大优点: 对相邻转发引擎之间的通道提供轻负荷、快速故障检测。 用单一的机制对任何介质、任何协议层进行实时检测。 BFD是一个简单的“Hello”协议。两个系统之间建立BFD会…

5天带你读完《Effective Java》(四)

《Effective Java》是Java开发领域无可争议的经典之作,连Java之父James Gosling都说: “如果说我需要一本Java编程的书,那就是它了”。它为Java程序员提供了90个富有价值的编程准则,适合对Java开发有一定经验想要继续深入的程序员…