吴恩达471机器学习入门课程2第4周——决策树

news/2024/11/30 6:03:12/

决策树

    • 1 导包
    • 2 问题描述
    • 3 one-hot 编码数据集
    • 4决策树
      • 4.1 计算熵
      • 4.2 划分数据集
    • 5 构建树

从头开始实现决策树,并将其应用于对蘑菇是可食用还是有毒的分类任务。

1 导包

import numpy as np
import matplotlib.pyplot as plt
from public_tests import *
%matplotlib inline

2 问题描述

假设您正在创办一家种植和销售野生蘑菇的公司。 由于并非所有蘑菇都可以食用,因此您希望能够根据其物理属性来判断给定的蘑菇是可食用的还是有毒的 您有一些可用于此任务的现有数据。 你能用这些数据来帮助你确定哪些蘑菇可以安全销售吗? 注意:使用的数据集仅用于说明目的。它并不意味着作为识别食用蘑菇的指南。

3 one-hot 编码数据集

Brown CapTapering Stalk ShapeSolitaryEdible
1111
1011
1000
1000
1111
0110
0000
1011
0101
1000
因此,
  • X_train 包含每个样本的三个特征

    • 帽子颜色(值为 1 表示棕色帽子,值为 0 表示红色帽子)
    • 茎形缩小(值为 1 表示“锥形茎形”,值为 0 表示“扩大”茎形)
    • 单独(值为 1 表示“是”,值为 0 表示“否”)
  • y_train 是蘑菇是否可食用

    • y = 1 表示可食用
    • y = 0 表示有毒
X_train = np.array([[1,1,1],[1,0,1],[1,0,0],[1,0,0],[1,1,1],[0,1,1],[0,0,0],[1,0,1],[0,1,0],[1,0,0]])
y_train = np.array([1,1,0,0,1,0,0,1,1,0])
X_train[:5]
array([[1, 1, 1],[1, 0, 1],[1, 0, 0],[1, 0, 0],[1, 1, 1]])

4决策树

在这个实践实验中,您将基于提供的数据集构建一个决策树。

  • 回想一下构建决策树的步骤:

    • 从根节点开始使用所有的样本
    • 计算基于所有可能特征分割时的信息增益,并选择具有最高信息增益的特征
    • 根据所选特征对数据集进行分割,并创建树的左右分支
    • 持续重复分裂过程直到满足停止标准
  • 在本实验中,您将实现以下函数,以便使用具有最高信息增益的特征将节点分成左右两个分支:

    • 计算节点上的熵
    • 基于给定特征在一个节点上将数据集分为左右两个分支
    • 计算在给定特征上分裂时的信息增益
    • 选择最大化信息增益的特征
  • 然后我们将使用您实现的辅助函数通过重复分裂过程来构建决策树,直到满足停止标准为止。

    • 对于本实验,我们选择的停止标准是设置最大深度为2。

4.1 计算熵

首先,您需要编写一个名为compute_entropy的帮助函数,用于计算节点上的熵(杂质度量)。

  • 函数接受一个numpy数组(y),该数组指示该节点中的示例是否可食用(1)或有毒(0)

请完成下面的compute_entropy()函数以:

  • 计算 p 1 p_1 p1,它是可食用示例(即在y中具有值= 1)的比例
  • 然后计算熵

H ( p 1 ) = − p 1 log 2 ( p 1 ) − ( 1 − p 1 ) log 2 ( 1 − p 1 ) H(p_1) = -p_1 \text{log}_2(p_1) - (1- p_1) \text{log}_2(1- p_1) H(p1)=p1log2(p1)(1p1)log2(1p1)

  • 注意
    • 对数使用基数为 2 2 2
    • 为了实现方便, 0 log 2 ( 0 ) = 0 0\text{log}_2(0) = 0 0log2(0)=0。也就是说,如果p_1 = 0 或者 p_1 = 1,则将熵设为0
    • 确保检查节点上的数据不是空的(即len(y) != 0),如果为空则返回0
def compute_entropy(y):entropy = 0if len(y) != 0:p1 = len(y[y==1])/len(y)if p1 != 0 and p1 !=1:entropy = -p1 * np.log2(p1)-(1-p1)*np.log2(1-p1)else:entropy = 0return entropy

4.2 划分数据集

下一步,您将编写一个名为 split_dataset 的帮助函数,该函数获取节点处的数据和要拆分的特征,并将其拆分为左右分支。稍后在实验室中,您将实现计算拆分效果如何的代码。

  • 该函数获取训练数据、该节点上数据点的索引列表以及要拆分的特征。
  • 它将数据拆分并返回左右分支的索引子集。
  • 例如,假设我们从根节点开始(因此 node_indices = [0,1,2,3,4,5,6,7,8,9]),并选择拆分特征为 0,即示例是否具有棕色帽子。
    • 然后函数的输出是,left_indices = [0,1,2,3,4,7,9]right_indices = [5,6,8]
索引棕色帽子收缩柄形状独立可食用
01111
11011
21000
31000
41111
50110
60000
71011
80101
91000
def split_dataset(X,node_indices,feature):Lnode_indices = []Rnote_indices = []for i in node_indices:if X[i][feature] ==1:Lnode_indices.append(i)else:Rnote_indices.append(i)return Lnode_indices,Rnote_indices

接下来,您将编写一个名为 information_gain 的函数,该函数接受训练数据、节点上的索引和要拆分的特征,并返回从拆分中获得的信息增益。

请完成下面所示的 compute_information_gain() 函数以计算

信息增益 = H ( p 1 node ) − ( w left H ( p 1 left ) + w right H ( p 1 right ) ) \text{信息增益} = H(p_1^\text{node})- (w^{\text{left}}H(p_1^\text{left}) + w^{\text{right}}H(p_1^\text{right})) 信息增益=H(p1node)(wleftH(p1left)+wrightH(p1right))

其中:

  • H ( p 1 node ) H(p_1^\text{node}) H(p1node) 是节点的熵
  • H ( p 1 left ) H(p_1^\text{left}) H(p1left) H ( p 1 right ) H(p_1^\text{right}) H(p1right) 是拆分后左、右分支的熵
  • w left w^{\text{left}} wleft w right w^{\text{right}} wright 分别是左、右分支中例子的比例

注意:

  • 您可以使用上面实现的 compute_entropy() 函数来计算熵
  • 我们提供了一些起始代码,该代码使用您之前实现的 split_dataset() 函数拆分数据集
def compute_information_gain(X,y,node_indices,feature):L_indices,R_indices = split_dataset(X,node_indices,feature)X_node, y_node = X[node_indices],y[node_indices]X_left, y_left = X[L_indices], y[L_indices]X_right, y_right = X[R_indices], y[R_indices]node_entropy = compute_entropy(y_node)L_entropy = compute_entropy(y_left)R_entropy = compute_entropy(y_right)left_w = len(X_left)/len(X_node)rigth_w = len(X_right)/len(X_node)w_entropy = left_w*L_entropy +rigth_w*R_entropyinformation_gain = node_entropy - w_entropyreturn information_gain

现在,让我们编写一个函数来获取最佳划分特征,方法是计算每个特征的信息增益,就像上面做的那样,并返回给出最大信息增益的特征。

请完成下面所示的 get_best_split() 函数。

  • 函数接受训练数据以及节点处数据点的索引
  • 函数的输出是提供最大信息增益的特征
    • 您可以使用 compute_information_gain() 函数遍历特征并为每个特征计算信息量
def get_best_split(X,y,node_indices):num = X.shape[1]best_feature = -1max_info_gain = 0for feature in range(num):info_gain = compute_information_gain(X,y,node_indices,feature)if info_gain > max_info_gain:max_info_gain = info_gainbest_feature = featurereturn best_feature

5 构建树

在本节中,我们使用您在上面实现的函数来生成决策树,方法是依次选择要拆分的最佳特征,直到达到停止条件(最大深度为 2)。

tree = []
root_indices = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
def build_tree_recursive(X, y, node_indices, branch_name, max_depth, current_depth):if current_depth == max_depth:formatting = " "*current_depth + "-"*current_depthprint(formatting, "%s leaf node with indices" % branch_name, node_indices)returnbest_feature = get_best_split(X, y, node_indices)tree.append((current_depth, branch_name, best_feature, node_indices))formatting = "-"*current_depthprint("%s Depth %d, %s: Split on feature: %d" % (formatting, current_depth, branch_name, best_feature))# Split the dataset at the best featureleft_indices, right_indices = split_dataset(X, node_indices, best_feature)# continue splitting the left and the right child. Increment current depthbuild_tree_recursive(X, y, left_indices, "Left", max_depth, current_depth+1)build_tree_recursive(X, y, right_indices, "Right", max_depth, current_depth+1)
build_tree_recursive(X_train, y_train, root_indices, "Root", max_depth=2, current_depth=0)
 Depth 0, Root: Split on feature: 2
- Depth 1, Left: Split on feature: 0-- Left leaf node with indices [0, 1, 4, 7]-- Right leaf node with indices [5]
- Depth 1, Right: Split on feature: 1-- Left leaf node with indices [8]-- Right leaf node with indices [2, 3, 6, 9]

http://www.ppmy.cn/news/371555.html

相关文章

参考代码XALTS1000UK01 U盾激活

XALTS1000UK01 U盾激活不了解决办法 安全中心,设备管理,解绑设备,卸载建行,重新安装,登录

建行tendyronU盾 插入电脑突然不能自动跳转IE跳出登录密码框

环境: 联想E14笔记本 Windows 10 专业版 64位 建行天地融网盾 IE浏览器10.0 问题描述: 建行tendyronU盾 插入电脑突然不能自动跳转IE跳出登录密码框,只跳出一个安装图标 解决方案: 1.前往正常可以打开U盾的电脑&#xff…

工行U盾检测不到,Giesecke Devrient StarKey无法安装 解决办法

工行U盾正确安装驱动后,插上u盾,就显示“Giesecke & Devrient StarKey”无法安装问题,支付时也检测不到U盾,说实话,工行在U盾方面处理的也实在弱智,很不好用。下面是本人找到的一个解决办法&#xff0c…

u盾如何在计算机上使用方法,u盾在电脑中具体使用操作过程

先删除驱动再重新安装(最好是最新的),装好驱动以后装入U盾,要把工行打开,然后从工行下你的个人信息到U盾上,它会有提示的,然后叫你输入密码,输入密码就可以了,以后付款插入U盾就行了&#xff0c…

Windows11下Edge浏览器登录工行农行并使用K宝U盾

随着Window10操作系统停止支持,Win11已经开始了普及。经过更换系统,对 默认的Edge浏览器进些配置,目前工商银行、农业银行的电子银行,已经可以通过Edge,以兼容IE的方式,正常登录和使用。 从支持性上看&…

工行u盾显示316_企业用户,插U盾点击U盾登录后,显示“无法显示该网页”

2个方法 请按照以下方法尝试解决:1、系统重新识别U盾。请您拔出U盾,更换USB插口插入,连接U盾的状态下重启电脑,完成后重新检测;2、重装驱动程序。首先,请拔下U盾,选择“开始-程序/所有程序-工行…

农村商业银行服务器未收到证书,不及时更新“证书” 当心网银U盾失效

近年来,网上银行业务发展特别迅速。有报告称,仅2008年中国网上银行交易量就高达320.9万亿元,同比增长30.6%。截至2008年底,全国个人网银客户已达1.48亿户,较年初增长52.81%。 为确保网银交易的安全,各家商业…