机器学习-作业2-贝叶斯网络

news/2025/1/16 5:04:01/

作业2

实现能处理连续属性的贝叶斯网络。

思路

  1. 怎么自动判断该属性是离散还是连续:计算该属性的不同值有多少个,超过10个就认为是连续,否则是离散的。
  2. 离散的:统计该类、该属性、该值的各个数量,计算概率,为后续计算做准备
  3. 连续的:统计该类、该属性的平均值、方差、标准差,为后续计算做准备
  4. 预测:
    1. 离散的:answer=whichjinmax(P(Cj)∏P(wi∣Cj))answer=which \ \ j \ \ in \ \ max(P(C_j)\prod P(w_i|C_j))answer=which  j  in  max(P(Cj)P(wiCj))
    2. 连续的:answer=whichjinmax(12πσje−(x−uj)22σj2)answer=which \ \ j \ \ in \ \ max(\frac{1}{\sqrt{2π}σ_j} e^{-\frac{(x-u_j)^2}{2σ_j^2}})answer=which  j  in  max(2πσj1e2σj2(xuj)2)
'''
6. 实现了训练、预测、输出精确度的功能
7. 实现自动判断特征是离散 还是连续,分别对离散、连续特征进行计算
8. 离散连续的判断依据:整个训练集中 该特征的特征值的数量 是否大于10?
'''
import numpy as np
import pandas as pd
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_splitclass naive_bayes():def __init__(self):self.pi=3.1415926passdef fit(self,x,y):self.sign_feature=[]#标记特征为离散 false  连续truefor k in range(x.shape[1]):#特征kcount=x[x.columns[k]].unique()#计算特征k有多少不同的值if len(count)<10:#如果小于10个 处理为离散值 否则处理为连续值self.sign_feature.append('false')#标记该特征为离散else:self.sign_feature.append('true')#标记该特征为连续self.classes=y[y.columns[0]].unique()#统计y不同的值 self.x_avg={}#计算类j特征k的各个数值平均self.s_2={}#方差self.s={}#标准差self.class_condition_prob={}#类条件概率初始化为字典for j in  self.classes:#遍历类jfor k in range(x.shape[1]):#特征kif self.sign_feature[k]:#如果是连续值x_avg=0x_x=[]for i in range(x.shape[0]):#遍历xif y[y.columns[0]][i]==j :x_avg+=x[x.columns[k]][i]x_x.append(x[x.columns[k]][i])x_avg/=len(x_x)s_2=0for i in range(len(x_x)):s_2+=(x_x[i]-x_avg)*(x_x[i]-x_avg)s_2/=len(x_x)s=np.sqrt(s_2)self.x_avg[(j,k)]=x_avg#平均值self.s_2[(j,k)]=s_2#方差self.s[(j,k)]=s#标准差else:class_count=y[y.columns[0]].value_counts()#即每个类中样本数  类数*属于该类的样本数#类先验概率self.class_prior=class_count/len(y)  #属于该类的样本数/总样本数p_x_y=x[(y==j).values][x.columns[k]].value_counts()for i in p_x_y.index:self.class_condition_prob[(j,k,i)] = p_x_y[i]/class_count[j]def predict(self,x):ans=[]for i in range(len(x)):res=[]for j in self.classes:#遍历类 p_x_y=1for k in range(x.shape[1]):#遍历特征值if self.sign_feature[k]:#如果是连续值temp0=1/(np.sqrt(2*self.pi)*self.s[tuple([j]+[k])])temp1=-(self.x_avg[tuple([j]+[k])]-x[x.columns[k]][i])*(self.x_avg[tuple([j]+[k])]-x[x.columns[k]][i])temp2=temp1/(2*self.s_2[tuple([j]+[k])])p_x_y*=temp0*np.exp(temp2)else:p_y=self.class_prior[j]#类先验概率 p_x_y*=p_ytemp=x[x.columns[k]][i]p_x_y*=self.class_condition_prob[tuple([j]+[k]+[temp])]res.append(p_x_y)#获得每个类下的概率ans.append( self.classes[np.argmax(res)] )#  np.argmax(res) 返回最大值的索引return ansdef accuracy(self,y_hat,y):#计算准确率ans=0for i in range(len(y_hat)):if y_hat[i]==y[i]:ans+=1print(y_hat==y)#查看错误了哪个? 只错了1个return ans/len(y)#构造数据集
x,y=make_blobs(n_samples=500,  #样本数n_features=2,    #特征值数centers=2,  #类标签数cluster_std=[1.0,3.0],  #每个类别的方差shuffle=True,   #默认为Truerandom_state=8)feature1=np.random.randint(5,size=x.shape[0])x=np.column_stack((feature1,x))#增加离散的特征值x_train,x_test,y_train,y_test=train_test_split(x,y,random_state=8)x_list=['x1','x2','x3']
y_list=['y']
x_train=pd.DataFrame(x_train,columns=x_list)
x_test=pd.DataFrame(x_test,columns=x_list)
y_train=pd.DataFrame(y_train,columns=y_list)#使用列表形式model=naive_bayes()
model.fit(x_train,y_train)#训练y_hat=model.predict(x_test)#预测
print(y_hat)#输出预测结果score=model.accuracy(y_hat,y_test)
print(score)#输出准确率

http://www.ppmy.cn/news/43045.html

相关文章

react-router原理

前端路由的原理 自己来监听URL的改变&#xff0c;改变url&#xff0c;渲染不同的组件(页面)&#xff0c;但是页面不要进行强制刷新&#xff08;a元素不行&#xff09;。 hash模式&#xff0c;localhost:3000/#/abc 优势就是兼容性更好&#xff0c;在老版IE中都可以运行缺点是…

【hello Linux】Linux软件管理器yum

目录 1.Linux软件管理器yum 1.1 关于lrzsz 1.2 使用yum时的注意事项 1.3 查看软件包&#xff1a;yum list 1.4 安装软件&#xff1a;yum install 1.5 卸载软件&#xff1a;yum remove 1.6 更新yum源 1.7 实战项目 Linux&#x1f337; 1.Linux软件管理器yum 在windows系统下有应…

技术创业者必读:从验证想法到技术产品商业化的全方位解析

导语 | 技术创业之路往往充满着挑战和不确定性&#xff0c;对于初入创业领域的人来说&#xff0c;如何验证自己的创业想法是否有空间、如何选择靠谱的投资人、如何将技术产品商业化等问题都需要认真思考和解决。在「TVP 技术夜未眠」第六期直播中&#xff0c;正马软件 CTO、腾讯…

ChatGPT风口下的中外“狂飙”,一文看懂微软、谷歌、百度、腾讯、华为、字节跳动们在做什么?

毫无疑问&#xff0c;ChatGPT正成为搅动市场情绪的buzzword。 历史经历过无线电&#xff0c;半导体&#xff0c;计算机&#xff0c;移动通讯&#xff0c;互联网&#xff0c;移动互联网&#xff0c;社交媒体&#xff0c;云计算等多个时代&#xff0c;产业界也一直在寻找Next Bi…

Java开发 - MySQL主从复制初体验

前言 前面已经学到了很多知识&#xff0c;大部分也都是偏向于应用方面&#xff0c;在应用实战这条路上&#xff0c;博主一直觉得只有实战才是学习中最快的方式。今天带来主从复制给大家&#xff0c;在刚刚开始动手写的时候&#xff0c;才想到似乎忽略了一些重要的东西&#xf…

Ubuntu搭建web站点并发布公网访问【内网穿透】

文章目录前言1. 本地环境服务搭建2. 局域网测试访问3. 内网穿透3.1 ubuntu本地安装cpolar3.2 创建隧道3.3 测试公网访问4. 配置固定二级子域名4.1 保留一个二级子域名4.2 配置二级子域名4.3 测试访问公网固定二级子域名前言 网&#xff1a;我们通常说的是互联网&#xff1b;站…

蓝牙设备如何自定义UUID

如何自定义UUID 所有 BLE 自定义服务和特性必须使用 128 位 UUID 来识别&#xff0c;并且要确保基本 UUID 与 BLE 定义的基本 UUID&#xff08;00000000-0000-1000-8000-00805F9B34FB&#xff09;不一样。基本 UUID 是一个 128 位的数值&#xff0c;根据该值可定义标准UUID&am…

ACM8629 立体声50W/100W单声道I2S数字输入D类音频功放IC

概述 ACM8629 一款高度集成、高效率的双通道数字输入功放。供电电压范围在4.5V-26.4V,数字接口电源支持3.3V 。在4 欧负载&#xff0c;BTL模式下输出功率可以到250W1%THDN&#xff0c;在2欧负载&#xff0c;PBTL模式下单通道可以输出1100W 1%THDN. ACM8629采用新型PWM脉宽调制架…