MapReduce实现KNN算法分类推测鸢尾花种类

news/2024/11/22 7:23:59/

文章目录

  • 代码地址
  • 一、KNN算法简介
  • 二、KNN算法示例:推测鸢尾花种类
  • 三、MapReduce+Hadoop实现KNN鸢尾花分类:
    • 1. 实现环境
    • 2.pom.xml
  • 3.设计思路及代码
      • 1. KNN_Driver类
      • 2. MyData类
      • 3. KNN_Mapper类
    • 4. KNN_Reducer类

代码地址

https://gitcode.net/m0_56745306/knn_classifier.git

一、KNN算法简介

该部分内容参考自:https://zhuanlan.zhihu.com/p/45453761

  • KNN(K-Nearest Neighbor) 算法是机器学习算法中最基础、最简单的算法之一。它既能用于分类,也能用于回归。KNN通过测量不同特征值之间的距离来进行分类。

  • KNN算法的思想非常简单:对于任意n维输入向量,分别对应于特征空间中的一个点,输出为该特征向量所对应的类别标签或预测值。

  • 对于一个需要预测的输入向量x,我们只需要在训练数据集中寻找k个与向量x最近的向量的集合,然后把x的类别预测为这k个样本中类别数最多的那一类。
    在这里插入图片描述
    如图所示,ω1、ω2、ω3分别代表训练集中的三个类别。其中,与xu最相近的5个点(k=5)如图中箭头所指,很明显与其最相近的5个点中最多的类别为ω1,因此,KNN算法将xu的类别预测为ω1。

二、KNN算法示例:推测鸢尾花种类

鸢尾花数据集记载了三类花(Setosa,versicolor,virginica)以及它们的四种属性(花萼长度、花萼宽度、花瓣长度、花瓣宽度)。例如:

4.9,3.0,1.4,0.2,setosa
6.4,3.2,4.5,1.5,versicolor
6.0,2.2,5.0,1.5,virginica

对于给定的测试数据,我们需要根据它的四种信息判断其属于哪一种鸢尾花。并输出它的序号:
例如:

#假设该数据为第一条数据(对应序号为0)
5.7,3.0,4.2,1.2  

输出可以为:

0 setosa

三、MapReduce+Hadoop实现KNN鸢尾花分类:

1. 实现环境

  • Ubuntu20.04
  • Hadoop3.3.5
  • Java8
  • Maven3.9.1

2.pom.xml

<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"><modelVersion>4.0.0</modelVersion><groupId>org.example</groupId><artifactId>KNN_Classifier</artifactId><version>1.0-SNAPSHOT</version><packaging>jar</packaging><name>KNN_Classifier</name><url>http://maven.apache.org</url><build><plugins><plugin><groupId>org.apache.maven.plugins</groupId><artifactId>maven-shade-plugin</artifactId><version>3.1.0</version><executions><execution><phase>package</phase><goals><goal>shade</goal></goals></execution></executions><configuration><filters><filter><artifact>*:*</artifact><excludes><exclude>module-info.class</exclude><exclude>META-INF/*.SF</exclude><exclude>META-INF/*.DSA</exclude><exclude>META-INF/*.RSA</exclude></excludes></filter></filters><transformers><transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer"><!-- main()所在的类,注意修改 --><mainClass>KNN_Classifier.KNN_Driver</mainClass></transformer></transformers></configuration></plugin><plugin><groupId>org.apache.maven.plugins</groupId><artifactId>maven-compiler-plugin</artifactId><version>3.8.1</version><configuration><source>8</source><target>8</target><encoding>UTF-8</encoding></configuration></plugin></plugins></build><properties><project.build.sourceEncoding>UTF-8</project.build.sourceEncoding><java.version>17</java.version><maven.compiler.source>17</maven.compiler.source><maven.compiler.target>17</maven.compiler.target></properties><dependencies><dependency><groupId>junit</groupId><artifactId>junit</artifactId><version>4.11</version><scope>test</scope></dependency><dependency><groupId>org.apache.hadoop</groupId><artifactId>hadoop-common</artifactId><version>3.3.5</version></dependency><!-- https://mvnrepository.com/artifact/org.apache.hadoop/hadoop-hdfs --><dependency><groupId>org.apache.hadoop</groupId><artifactId>hadoop-hdfs</artifactId><version>3.3.5</version></dependency><dependency><groupId>org.apache.hadoop</groupId><artifactId>hadoop-mapreduce-client-core</artifactId><version>3.3.5</version></dependency></dependencies>
</project>

3.设计思路及代码

1. KNN_Driver类

Diriver类主要负责初始化job的各项属性,同时将训练数据加载到缓存中去,以便于Mapper读取。同时为了记录测试数据量,在conf中设置testDataNum用于在map阶段记录。

package KNN_Classifier;import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.util.GenericOptionsParser;public class KNN_Driver {public static void main(String[] args) throws Exception {Configuration conf = new Configuration();GenericOptionsParser optionParser = new GenericOptionsParser(conf, args);String[] remainingArgs = optionParser.getRemainingArgs();if (remainingArgs.length != 3) {System.err.println("Usage: KNN_Classifier <training dataset> <test dataset> <output>");System.exit(2);}conf.setInt("K",5);//设置KNN算法的K值conf.setInt("testDataNum",0);//设置全局计数器,记录测试数据数目conf.setInt("dimension",4);//设置向量维度Job job = Job.getInstance(conf, "KNN_Classifier");job.setJarByClass(KNN_Driver.class);job.setMapperClass(KNN_Mapper.class);job.setReducerClass(KNN_Reducer.class);//将训练数据添加到CacheFile中job.addCacheFile(new Path(remainingArgs[0]).toUri());FileInputFormat.addInputPath(job, new Path(remainingArgs[1]));FileOutputFormat.setOutputPath(job, new Path(remainingArgs[2]));job.waitForCompletion(true);System.exit(0);}
}

2. MyData类

这个类对每条测试数据进行封装,同时用于计算向量距离。

package KNN_Classifier;import java.util.Vector;public class MyData {//向量维度private Integer dimension;//向量坐标private Vector<Double>vec = new Vector<Double>();//属性,这里是水仙花的种类private String attr = new String();public  void setAttr(String attr){this.attr = attr;}public void setVec(Vector<Double> vec) {this.dimension = vec.size();for(Double d : vec){this.vec.add(d);}}public double calDist(MyData data1)//计算两条数据之间的欧式距离{try{if(this.dimension != data1.dimension)throw new Exception("These two vectors have different dimensions.");}catch (Exception e){System.out.println(e.getMessage());System.exit(-1);}double dist = 0;for(int i = 0;i<dimension;i++){dist += Math.pow(this.vec.get(i)-data1.vec.get(i),2);}dist = Math.sqrt(dist);return dist;}public String getAttr() {return attr;}
}

3. KNN_Mapper类

  • setup:用于加载缓存中的训练数据到Mapper的列表当中,同时读取K值、维度等必要信息。

  • readTrainingData:由setup调用,加载缓存训练数据。

  • Gaussian:用于计算欧式距离x所占权重,它的公式为:
    f ( x ) = a e ( x − b ) 2 − 2 c 2 f(x) = ae^{\frac{(x-b)^2}{-2c^2}} f(x)=ae2c2(xb)2
    它的图像为:

在这里插入图片描述

∣ x ∣ |x| x绝对值增加, f ( x ) f(x) f(x)的值越来越小,可以反映距离对权重的影响:即欧式距离越大,权重越小,对标签的影响也越小。

实际上高斯函数各个参数的确定需要对样本数据经过多次交叉验证得出,但为了简单起见,这里另a=1,b=0,c=0.9即可(这种情况下训练的结果比较好一些)。

  • map:对得到的测试数据进行KNN算法处理,它的伪代码如下:

    map(key,val): #key为样本数据偏移量,val为该行数据testData = getTestData ; #从val中读取测试数据信息K_Nearest = Empty ; #K最近邻,可以用最大堆来实现for trainingData in trainingDataSet : #遍历可以改为用KDTree优化dist = CalDist(testData,trainingData) ;if sizeof(K_Nearest) < K : #如果此时还未达到K值,直接添加K_Nearest.add(dist,trainingData.attr) ;else :if dist < K_Nearest.maxDist : #如果计算得出的距离大于当前K个点之中最大距离,则替换之replace pair with maxDist to (dist,trainingData.attr) ; calculate weight sum for every attr ; #为每种标签计算权重和write(idx,max_weight_attr); #写入序号,最大权重标签,完成分类
    

综上,下面是KNN_Mapper的代码:

package KNN_Classifier;import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.*;
import java.net.URI;
import java.io.BufferedReader;
import java.io.FileReader;
import javafx.util.Pair;public class KNN_Mapper extends Mapper<LongWritable, Text, LongWritable, Text> {private Text text = new Text();//输出Val值private LongWritable longWritable = new LongWritable();//输出K值private Integer K;//K值private Configuration conf;//全局配置private Integer dimension;//维度private List<MyData> training_data = new ArrayList<>();private void readTrainingData(URI uri)//读取训练数据到training_data中{System.err.println("Read Training Data");try{Path patternsPath = new Path(uri.getPath());String patternsFileName = patternsPath.getName().toString();BufferedReader reader = new BufferedReader(new FileReader(patternsFileName));String line;Vector<Double>vec = new Vector<>();while ((line = reader.readLine()) != null) {// TODO: your code here//String[] strings = line.split(",");for(int i=0;i<dimension;i++){vec.add(Double.valueOf(strings[i]));}MyData myData = new MyData();myData.setVec(vec);myData.setAttr(strings[dimension]);System.out.println(strings[dimension]);training_data.add(myData);vec.clear();}reader.close();}catch (FileNotFoundException e){e.printStackTrace();}catch (IOException e){e.printStackTrace();}System.err.println("Read End");}private double Gaussian(double dist){//a = 1,b=0,c = 0.9,2*c^2 = 1.62double weight = Math.exp(-Math.pow(dist,2)/(1.62));return weight;}@Overridepublic void setup(Context context) throws IOException,InterruptedException {conf = context.getConfiguration();this.K = conf.getInt("K",1);this.dimension = conf.getInt("dimension",1);URI[] uri = context.getCacheFiles();readTrainingData(uri[0]);}@Overridepublic void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {String line = value.toString();try {String[] strings = line.split(",");if (strings.length!=dimension) {throw new Exception("Error line format in the table.");}//获取测试数据信息Vector<Double>vec = new Vector<>();for(String s:strings){System.err.println("S: "+s);vec.add(Double.valueOf(s));}MyData testData = new MyData();testData.setVec(vec);//计算与样本的K近邻//存放K近邻的优先级队列,元素类型为<距离,属性>PriorityQueue<Pair<Double,String>>K_nearst = new PriorityQueue<>((a,b)->(a.getKey()>b.getKey())?-1:1);double dist;for(MyData data : this.training_data){dist = testData.calDist(data);if(K_nearst.size()<this.K){K_nearst.add(new Pair<>(dist,data.getAttr()));}else{if(dist < K_nearst.peek().getKey()){K_nearst.poll();K_nearst.add(new Pair<>(dist,data.getAttr()));}}}//获取到K近邻后,通过高斯函数处理每条数据,并累加相同属性的权值,通过Hash_table实现Hashtable<String,Double>weightTable = new Hashtable<>();while(!K_nearst.isEmpty()){double d = K_nearst.peek().getKey();String attr = K_nearst.peek().getValue();double w = this.Gaussian(d);if(!weightTable.contains(attr)){weightTable.put(attr,w);}else{weightTable.put(attr,weightTable.get(attr)+w);}K_nearst.poll();}//选取权重最大的标签作为输出Double max_weight = Double.MIN_VALUE;String target_attr = "";for(Iterator<String> itr = weightTable.keySet().iterator();itr.hasNext();){String hash_key = (String)itr.next();Double hash_val = weightTable.get(hash_key);if(hash_val > max_weight){target_attr = hash_key;max_weight = hash_val;}}text.set(target_attr);//获取测试数据条数,用作下标计数longWritable.set(conf.getLong("testDataNum",0));conf.setLong("testDataNum",longWritable.get()+1);//计数加一context.write(longWritable,text);}catch (Exception e) {System.err.println(e.toString());System.exit(-1);}}
}

4. KNN_Reducer类

由于Mapper类已经完成了所有工作,所以传入到Reducer中的键值对都是Index,Attr的形式,直接写入即可。

package KNN_Classifier;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Reducer;
import java.io.IOException;public class KNN_Reducer extends Reducer<LongWritable, Text,LongWritable,Text> {public void reduce(LongWritable key, Iterable<Text> values,Context context) throws IOException, InterruptedException {for(Text val:values){context.write(key,val);}}
}

http://www.ppmy.cn/news/91220.html

相关文章

燕千云ChatGPT应用,用过的都说香

本期受访人物&#xff1a;张礼军 甄知科技联合创始人&#xff0c;CTO 首席产品官 2022年底&#xff0c;基于人工智能技术驱动的自然语言工具横空出世&#xff0c;一经推出&#xff0c;ChatGPT迅速火遍全球&#xff0c;几乎各行各业都在探索ChatGPT具体业务场景的应用&#xf…

Revit幕墙:用幕墙巧做屋面瓦及如何快速幕墙?

一、Revit中用幕墙巧做屋面瓦 屋面瓦重复性很高&#xff0c;我们如何快速的创建呢?下面我们来学会快速用幕墙来创建屋面瓦的技巧。 1.新建“公制轮廓-竖挺”族&#xff0c;以此来创建瓦的族(以便于载入项目中使用) 2.在轮廓族中绘制瓦的轮廓(轮廓需要闭合)&#xff0c;将族名称…

零基础web安全入门学习路线

相信很多新手都会遇到以下几个问题 1.零基础想学渗透怎么入手&#xff1f; 2.学习web渗透需要从哪里开始&#xff1f; 这让很多同学都处于迷茫状态而迟迟不下手&#xff0c;小编就在此贴给大家说一下web渗透的学习路线&#xff0c;希望对大家有帮助 同时本博客也会按照学习路…

【数据库】无效数据:软件测试对无效数据的处理

目录 一、无效数据的常见场景 &#xff08;1&#xff09;测试阶段 &#xff08;2&#xff09;测试方法 二、无效数据的概念 三、无效数据的影响 四、无效数据的识别 五、无效数据的处理方法 &#xff08;1&#xff09;拒绝无效数据 ① 拒绝无效数据的概念 ② 拒绝…

递归的学习

递归是一种解决计算问题的方法&#xff0c;其中解决方案取决于同一类问题的更小子集 说明&#xff1a; 1.自己调用自己&#xff0c;说过说每个函数对应着一种解决方案&#xff0c;自己调用自己意味着解决方案都是一样的 2.每次调用&#xff0c;函数处理的数据会较上次递减&a…

函数 prctl 系统调用

prctl是一个系统调用&#xff0c;用于控制和修改进程的行为和属性。它可以在Linux系统上使用&#xff0c;提供了各种功能和选项来管理进程的不同方面。 以下是prctl函数的基本原型&#xff1a; #include <sys/prctl.h>int prctl(int option, unsigned long arg2, unsig…

【MySQL】字段截取拼接修改数据

需求&#xff1a; 将数据库中的某一个字段的前6位替换成一个新的字符串&#xff0c;其它位置不变。 拼接函数&#xff1a; CONCAT(A,B)&#xff1a;将A和B拼接起来。 截取函数&#xff1a; LEFT(str,3)&#xff1a;截取str的前3位&#xff1b; select left(sqlstudy.com,…

《操作系统》by李治军 | 实验5.pre - switch_to 汇编代码详解

目录 【前言】 一、栈帧的处理 1. 什么是栈帧 2. 为什么要处理栈帧 3. 执行 switch_to 前的内核栈 4. 栈帧处理代码分析 二、PCB 的比较 1. 根据 PCB 判断进程切换与否 2. PCB 比较代码分析 三、PCB 的切换 1. 什么是 PCB 的切换 2. PCB 切换代码分析 四、TSS 内核…