使用scikit-learn中的KNN包实现对鸢尾花数据集或者自定义数据集的的预测
KNN算法有三要素:1.K值选择;2.距离选择;3.分类规则选择。
步骤1 导入数据集
步骤2 将数据集设置标签
步骤3 设置超参数
代码
from sklearn.neighbors import KNeighborsClassifier
import numpy as nppoint1=[[2.1, 3.2],[1.8, 1.5],[2.5, 2.8],[1.5, 2.5],[2.2, 3.5],[2.8, 2.7],[1.6, 3.1]
]point2=[[3.0, 2.7],[4.5, 3.8],[5.2, 4.5],[6.1, 5.5],[5.0, 3.0],[4.0, 4.2],[3.5, 5.0]
]point3=[[7.5, 4.8],[8.0, 6.3],[7.8, 5.5],[8.5, 7.0],[6.9, 5.7],[7.2, 6.5],[7.8, 6.9]
]point_concat=np.concatenate((point1,point2,point3),axis=0)
point_concat_label=np.concatenate((np.zeros(len(point1)),np.ones(len(point2)),np.ones(len(point3))+1),axis=0)
n_neighbors=5
knn=KNeighborsClassifier(n_neighbors=n_neighbors,algorithm='kd_tree',p=2)#使用KNN训练
knn.fit(point_concat,point_concat_label)#3.决策边界,设定未知点,坐标点网络
x1=np.linspace(0,10,100)
y1=np.linspace(0,10,100)
#生成坐标点网格
x_axis,y_axis=np.meshgrid(x1,y1)x_axis_ravel=x_axis.ravel()
y_axis_ravel=y_axis.ravel()xy_axis=np.c_[x_axis_ravel,y_axis_ravel]
#4.KNN预测与绘制决策边界
knn_predict_result=knn.predict(xy_axis)
print("")#有x,y坐标 及预测结果 等高线绘制
import matplotlib.pyplot as pltfig=plt.figure(figsize=(15,10))ax=fig.add_subplot(111)ax.contour(x_axis,y_axis,knn.predict(xy_axis).reshape(x_axis.shape))
#绘出原始点
ax.scatter(point_concat[point_concat_label==0,0],point_concat[point_concat_label==0,1],color='b',marker='^')
#画散点为1的
ax.scatter(point_concat[point_concat_label==1,0],point_concat[point_concat_label==1,1],color='r',marker='*')
#画散点为2的
ax.scatter(point_concat[point_concat_label==2,0],point_concat[point_concat_label==2,1],color='y',marker='s')plt.show()
结果
改变k值结果不同