3 决策树及Python实现

news/2024/11/24 19:27:28/

1 主要思想

1.1 数据

在这里插入图片描述

1.2 训练和使用模型

训练:建立模型(树)
测试:使用模型(树)
在这里插入图片描述
Weka演示ID3(终端用户模式)

  • 双击weka.jar
  • 选择Explorer
  • 载入weather.arff
  • 选择trees–>ID3
  • 构建树,观察结果

建立决策树流程

  • Step 1. 选择一个属性
  • Step 2. 将数据集分成若干子集
  • Step 3.1 对于决策属性值唯一的子集, 构建叶结点
  • Step 3.2 对于决策属性值不唯一的子集, 递归调用本函数

演示: 利用txt文件, 按照决策树的属性划分数据集

2 信息熵

问题: 使用哪个属性进行数据的划分?
随机变量YYY的信息熵为 (YYY为决策变量):
H(Y)=E[I(yi)]=∑i=1np(yi)log⁡1p(yi)=−∑i=1np(yi)log⁡p(yi),H(Y) = E[I(y_i)] = \sum_{i=1}^n p(y_i)\log \frac{1}{p(y_i)} = - \sum_{i=1}^n p(y_i)\log p(y_i), H(Y)=E[I(yi)]=i=1np(yi)logp(yi)1=i=1np(yi)logp(yi),
其中 0log⁡0=00 \log 0 = 00log0=0.
随机变量YYY关于XXX的条件信息熵为(XXX为条件变量):
H(Y∣X)=∑i=1mp(xi)H(Y∣X=xi)=−∑i,jp(xi,yj)log⁡p(yj∣xi).\begin{array}{ll} H(Y | X) & = \sum_{i=1}^m p(x_i) H(Y | X = x_i)\\ & = - \sum_{i, j} p(x_i, y_j) \log p(y_j | x_i). \end{array} H(YX)=i=1mp(xi)H(YX=xi)=i,jp(xi,yj)logp(yjxi).
XXXYYY带来的信息增益: H(Y)−H(Y∣X)H(Y) - H(Y | X)H(Y)H(YX).

3 程序分析

版本1. 使用sklearn (调包侠)
这里使用了数据集是数值型。

import numpy as np
import scipy as sp
import time, sklearn, math
from sklearn.model_selection import train_test_split
import sklearn.datasets, sklearn.neighbors, sklearn.tree, sklearn.metricsdef sklearnDecisionTreeTest():#Step 1. Load the datasettempDataset = sklearn.datasets.load_breast_cancer()x = tempDataset.datay = tempDataset.target# Split for training and testingx_train, x_test, y_train, y_test = train_test_split(x, y, test_size = 0.2)#Step 2. Build classifiertempClassifier = sklearn.tree.DecisionTreeClassifier(criterion='entropy')tempClassifier.fit(x_train, y_train)#Step 3. Test#precision, recall, thresholds = sklearn.metrics.precision_recall_curve(y_test, tempClassifier.predict(x_test))tempAccuracy = sklearn.metrics.accuracy_score(y_test, tempClassifier.predict(x_test))tempRecall = sklearn.metrics.recall_score(y_test, tempClassifier.predict(x_test))#Step 4. Outputprint("precision = {}, recall = {}".format(tempAccuracy, tempRecall))sklearnDecisionTreeTest()

版本2. 自己重写重要函数

  1. 信息熵
#计算给定数据集的香农熵
def calcShannonEnt(paraDataSet):numInstances = len(paraDataSet)labelCounts = {}	#定义空字典for featVec in paraDataSet:currentLabel = featVec[-1]if currentLabel not in labelCounts.keys():labelCounts[currentLabel] = 0labelCounts[currentLabel] += 1shannonEnt = 0.0for key in labelCounts:prob = float(labelCounts[key])/numInstancesshannonEnt -= prob * math.log(prob, 2) #以2为底return shannonEnt
  1. 划分数据集
#dataSet 是数据集,axis是第几个特征,value是该特征的取值。
def splitDataSet(dataSet, axis, value):resultDataSet = []for featVec in dataSet:if featVec[axis] == value:#当前属性不需要reducedFeatVec = featVec[:axis]reducedFeatVec.extend(featVec[axis+1:])resultDataSet.append(reducedFeatVec)return resultDataSet
  1. 选择最好的特征划分
#该函数是将数据集中第axis个特征的值为value的数据提取出来。
#选择最好的特征划分
def chooseBestFeatureToSplit(dataSet):#决策属性不算numFeatures = len(dataSet[0]) - 1baseEntropy = calcShannonEnt(dataSet)bestInfoGain = 0.0bestFeature = -1for i in range(numFeatures):#把第i列属性的值取出来生成一维数组featList = [example[i] for example in dataSet]#剔除重复值uniqueVals = set(featList)newEntropy = 0.0for value in uniqueVals:subDataSet = splitDataSet(dataSet, i, value)prob = len(subDataSet) / float(len(dataSet))newEntropy += prob*calcShannonEnt(subDataSet)infoGain = baseEntropy - newEntropyif(infoGain > bestInfoGain):bestInfoGain = infoGainbestFeature = ireturn bestFeature
  1. 构建叶节点
#如果剩下的数据中无特征,则直接按最大百分比形成叶节点
def majorityCnt(classList):classCount = {}for vote in classList:if vote not in classCount.keys():classCount[vote] = 0classCount += 1;sortedClassCount = sorted(classCount.iteritems(), key = operator.itemgette(1), reverse = True)return sortedClassCount[0][0]
  1. 创建决策树
#创建决策树
def createTree(dataSet, paraFeatureName):featureName = paraFeatureName.copy()classList = [example[-1] for example in dataSet]#Already pureif classList.count(classList[0]) == len(classList):return classList[0]#No more attributeif len(dataSet[0]) == 1:#if len(dataSet) == 1:return majorityCnt(classList)bestFeat = chooseBestFeatureToSplit(dataSet)#print(dataSet)#print("bestFeat:", bestFeat)bestFeatureName = featureName[bestFeat]myTree = {bestFeatureName:{}}del(featureName[bestFeat])featvalue = [example[bestFeat] for example in dataSet]uniqueVals = set(featvalue)for value in uniqueVals:subfeatureName = featureName[:]myTree[bestFeatureName][value] = createTree(splitDataSet(dataSet, bestFeat, value), subfeatureName)return myTree
  1. 分类和返回预测结果
#Classify and return the precision
def id3Classify(paraTree, paraTestingSet, featureNames, classValues):tempCorrect = 0.0tempTotal = len(paraTestingSet)tempPrediction = classValues[0]for featureVector in paraTestingSet:print("Instance: ", featureVector)tempTree = paraTreewhile True:for feature in featureNames:try:tempTree[feature]splitFeature = featurebreakexcept:i = 1 #Do nothingattributeValue = featureVector[featureNames.index(splitFeature)]print(splitFeature, " = ", attributeValue)tempPrediction = tempTree[splitFeature][attributeValue]if tempPrediction in classValues:breakelse:tempTree = tempPredictionprint("Prediction = ", tempPrediction)if featureVector[-1] == tempPrediction:tempCorrect += 1return tempCorrect/tempTotal
  1. 构建测试代码
def mfID3Test():#Step 1. Load the datasetweatherData = [['Sunny','Hot','High','FALSE','N'],['Sunny','Hot','High','TRUE','N'],['Overcast','Hot','High','FALSE','P'],['Rain','Mild','High','FALSE','P'],['Rain','Cool','Normal','FALSE','P'],['Rain','Cool','Normal','TRUE','N'],['Overcast','Cool','Normal','TRUE','P'],['Sunny','Mild','High','FALSE','N'],['Sunny','Cool','Normal','FALSE','P'],['Rain','Mild','Normal','FALSE','P'],['Sunny','Mild','Normal','TRUE','P'],['Overcast','Mild','High','TRUE','P'],['Overcast','Hot','Normal','FALSE','P'],['Rain','Mild','High','TRUE','N']]featureName = ['Outlook', 'Temperature', 'Humidity', 'Windy']classValues = ['P', 'N']tempTree = createTree(weatherData, featureName)print(tempTree)#print(createTree(mydata, featureName))#featureName = ['Outlook', 'Temperature', 'Humidity', 'Windy']print("Before classification, feature names = ", featureName)tempAccuracy = id3Classify(tempTree, weatherData, featureName, classValues)print("The accuracy of ID3 classifier is {}".format(tempAccuracy))def main():sklearnDecisionTreeTest()mfID3Test()main()

4 讨论

符合人类思维的模型;
信息增益只是一种启发式信息;
与各个属性值“平行”的划分。

其它决策树:

  • C4.5:处理数值型数据
  • CART:使用gini指数

http://www.ppmy.cn/news/28720.html

相关文章

JAVA开发(JAVA垃圾回收的几种常见算法)

JAVA GC 是JAVA虚拟机中的一个系统或者说是一个服务,专门是用于内存回收,交还给虚拟机的功能。 JAVA语言相对其他语言除了跨平台性,还有一个最重要的功能是JAVA语言封装了对内存的自动回收。俗称垃圾回收器。所以有时候我们不得不承认&#…

【ElasticSearch系列-01】初识以及安装elasticSearch

elasticSearch入门和安装一,elasticSearch入门1,什么是elasticSearch2,elasticSearch的底层优点2.1,全文检索2.2,倒排索引2.2.1,正排索引2.2.2,倒排索引2.2.3,倒排索引解决的问题2.2…

mysql数据库常见面试题

慢查询排查优化 排查 slow_query_log设置为on,就会记录慢查询sql;long_query_time可以设置慢查询sql的阈值时间;slow_query_log_file表示记录慢查询sql的日志路径。即我们可以通过打开记录慢查询的开关,设置慢查询的时间阈值&…

sql-labs-Less1

靶场搭建好了,访问题目路径 http://127.0.0.1/sqli-labs-master/Less-1/ 我最开始在做sql-labs靶场的时候很迷茫,不知道最后到底要得到些什么,而现在我很清楚,sql注入可以获取数据库中的信息,而获取信息就是我们的目标…

电脑系统崩溃怎么修复教程

系统崩溃了怎么办? 如今的软件是越来越复杂、越来越庞大。由系统本身造成的崩溃即使是最简单的操作,比如关闭系统或者是对BIOS进行升级都可能会对PC合操作系统造成一定的影响。下面一起来看看电脑系统崩溃修复方法步骤。 工具/原料: 系统版本&#xf…

数据结构与算法(六):图结构

图是一种比线性表和树更复杂的数据结构,在图中,结点之间的关系是任意的,任意两个数据元素之间都可能相关。图是一种多对多的数据结构。 一、基本概念 图(Graph)是由顶点的有穷非空集合和顶点之间边的集合组成&#x…

Java-Springboot整合支付宝接口

文章目录一、创建支付宝沙箱二、使用内网穿透 nat app三、编写java程序四、访问一、创建支付宝沙箱 跳转 : 支付宝沙箱平台 1、进入控制台 2、创建小程序,编写名称和绑定商家即可 3、返回第一个页面,往下滑进入沙箱 4、进行相关的配置&a…

【Linux】零成本在家搭建自己的私人服务器解决方案

我这个人自小时候以来就特喜欢永久且免费的东西,也因此被骗过(花巨款买了永久超级会员最后就十几天)。 长大后骨子里也是喜欢永久且免费的东西,所以我不买服务器,用GitHubPage或者GiteePage搭建自己的静态私人博客&…