spark 3.4.4 利用Spark ML中的交叉验证、管道流实现鸢尾花分类预测案例选取最优模型

news/2024/11/27 0:55:05/

前情回顾

前面的案例中,介绍了怎么基于管道流实现啊鸢尾花案例,利用逻辑斯蒂回归模型预测。详细内容步骤可以参照相应的博客内容

本案例内容

在 Spark 中使用交叉验证结合逻辑回归(Logistic Regression)以及管道流(Pipeline)实现鸢尾花案例最优模型选择的详细介绍和示例代码(以 Scala 语言为例):

知识点介绍

1. 管道流(Pipeline)概述

在 Spark ML 中,管道流是一种用于将多个数据处理和机器学习阶段组合在一起的机制,使得整个机器学习流程更加清晰、易于管理和复用。一个典型的管道通常包含多个按顺序执行的阶段(Stage),比如数据预处理阶段(例如标准化数据、特征编码等)、特征选择阶段以及最终的模型训练阶段等。

例如,假设我们要处理一个包含文本特征和数值特征的数据集来进行分类任务。可能首先需要对文本特征进行词向量转换(一种特征工程操作),然后对所有特征进行标准化,最后使用分类模型进行训练。这些步骤就可以通过管道流按顺序组织起来,方便地进行整体操作。

2. 逻辑回归结合管道流及交叉验证的优势

  • 整合流程:通过管道流,可以把逻辑回归模型之前需要进行的一系列数据预处理步骤(如特征编码、归一化等)与逻辑回归模型训练本身整合在一个连贯的流程里。这样在进行交叉验证时,每次数据划分后都能自动按照统一的流程依次执行各阶段操作,确保数据处理和模型训练的一致性,避免了手动分别处理不同数据子集时可能出现的错误或不一致情况。
  • 高效与可复用:便于在不同数据集或者不同超参数调整场景下复用整个流程。对于逻辑回归的交叉验证来说,只需改变交叉验证中的参数网格(比如尝试不同的逻辑回归超参数组合)等设置,就能方便地重新运行整个包含数据预处理、模型训练和评估的流程,高效地找到最佳模型配置。

3. 交叉验证

交叉验证是一种统计方法,用于评估机器学习模型的性能。它通过将数据集分割成若干个互斥的子集,然后多次训练和测试模型来提高评估的可靠性和准确性。最常用的交叉验证方法是K折交叉验证(K-Fold Cross Validation),下面是其基本步骤:

K折交叉验证的基本步骤
准备数据:

        收集并清洗数据集,确保数据质量。
数据集划分:
        将整个数据集随机分成K个大小相等(或尽可能相等)的互斥子集,也称为“折”(fold)。K的常见值有5或10。
训练与验证:
        对于每个不同的K折:
                将其中的一个子集作为验证集。
                剩余的K-1个子集合并起来作为训练集。
                使用训练集来训练模型。
                利用验证集评估模型的表现,记录评估指标(例如准确率、F1分数、均方误差等)。
汇总评估:
        计算所有K次验证过程中的评估指标的平均值,以此作为模型性能的估计。
模型选择:
        根据交叉验证的结果,选择最佳的模型参数或者算法。
交叉验证的优点
        减少偏差:通过多次训练和测试,减少了由于特定训练/测试集的选择导致的模型性能评估的偏差。
        充分利用数据:几乎所有的数据都被用来训练和测试,从而提高了模型评估的可靠性。
        参数调整:有助于在不同参数设置之间做出更准确的比较,选择最优的模型配置。
注意事项
        在进行K折交叉验证时,重要的是确保每个折的数据分布尽量保持一致,这通常意味着要在分层的基础上进行划分,尤其是在处理分类问题时。
        如果数据集非常大,可以考虑使用留一法(Leave-One-Out, LOO)或重复K折交叉验证来增加评估的稳定性。

4. 网格搜索

网格搜索(Grid Search)是一种用于超参数优化的技术,在机器学习中广泛应用于模型选择和调优。它通过在指定的参数范围内系统地遍历所有可能的参数组合,来寻找最优的模型参数设置。以下是网格搜索的详细解释:

工作原理

  1. 定义参数网格:首先,为模型的每个超参数定义一个搜索范围,这些范围的组合构成了参数网格。例如,对于支持向量机(SVM)模型,你可能想要调整C(正则化参数)和gamma(核函数的系数)的值。

  2. 模型实例化:选择你要调优的模型类型,并实例化一个模型对象。

  3. 交叉验证:对于参数网格中的每一组参数,使用交叉验证(如k折交叉验证)来评估模型的性能。交叉验证通过将数据集划分为多个子集,并在不同的子集上训练和验证模型,来提供更可靠的性能估计。

  4. 性能评估:根据交叉验证的结果,计算每个参数组合对应的性能指标(如准确率、召回率、F1分数等)。

  5. 选择最优参数:比较所有参数组合的性能指标,选择性能最好的参数组合作为最优参数。

优点

  • 系统性:网格搜索通过遍历所有可能的参数组合,确保不会错过任何潜在的最优解。

  • 可靠性:使用交叉验证来评估模型性能,减少了过拟合的风险,并提供了更可靠的性能估计。

缺点

  • 计算成本高:当参数网格很大时,网格搜索的计算成本可能非常高,因为它需要训练多个模型并评估它们的性能。

  • 可能陷入局部最优:尽管网格搜索是系统性的,但它仍然受限于定义的参数范围。如果最优参数不在定义的范围内,网格搜索将无法找到它。

案例代码

package cn.lh.pblh123.spark2024.theorycourse.charpter9import cn.lh.pblh123.spark2024.theorycourse.charpter9.MLGMM.checkPathExistStatus
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, MulticlassClassificationEvaluator}
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorAssembler, VectorIndexer}
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleTypeobject CrossValidatorLR {def main(args: Array[String]): Unit = {// 检查命令行参数数量是否正确,确保程序正确使用if (args.length != 3) {System.err.println("Usage: <murl> <inputfile> <modelpath>")System.exit(1)}// 从命令行参数中提取变量val murl = args(0)val inputfile = args(1)val modelpath = args(2)// 创建SparkSession实例,用于数据处理和模型训练val spark = SparkSession.builder().appName(s"${this.getClass.getName}").master(murl).getOrCreate()// 加载数据val df = spark.read.option("header", true).csv(inputfile)// 显示数据样例和 schemadf.show(3, false)df.printSchema()// 数据预处理,转换特征和标签val dfDouble = df.select(col("sepal_length").cast(DoubleType), col("sepal_width").cast(DoubleType),col("petal_length").cast(DoubleType), col("petal_width").cast(DoubleType),col("species").alias("label"))dfDouble.printSchema()// 特征处理// 创建一个VectorAssembler实例,用于将多列特征组合成单一的特征向量val assembler = new VectorAssembler().setInputCols(Array("sepal_length", "sepal_width", "petal_length", "petal_width")).setOutputCol("features")// 使用VectorAssembler转换原始DataFrame,生成一个新的DataFrame,其中包含特征向量和标签列val dataFrame = assembler.transform(dfDouble).select("features", "label")// 显示转换后的DataFrame的前3行数据,以验证转换结果dataFrame.show(3, 0)// 获取标签列和特征列// 使用StringIndexer将标签列转换为索引形式,以便后续的机器学习算法能够处理val labelIndex = new StringIndexer().setInputCol("label").setOutputCol("labelIndex").fit(dataFrame)// 使用VectorIndexer对特征列进行索引,这有助于提高机器学习模型的效率和效果val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").fit(dataFrame)// 创建Logistic回归模型实例,设置标签列和特征列,以及模型的训练参数// 最大迭代次数设为100val logisticRegression = new LogisticRegression().setLabelCol("labelIndex").setFeaturesCol("indexedFeatures").setMaxIter(100)// 打印Logistic回归模型的参数,以便调试和优化println("logistricRegression parameters:\n" + logisticRegression.explainParams() + "\n")// 设置indexToString转换器val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndex.labels)// 设置逻辑回归流水线val lrpiple = new Pipeline().setStages(Array(labelIndex, featureIndexer, logisticRegression, labelConverter))// 划分训练集和测试集,利用随机种子val Array(trainingData, testData) = dataFrame.randomSplit(Array(0.7, 0.3), 1234L)trainingData.show(3, 0)testData.show(3, 0)// 创建参数网格,用于交叉验证val paraGrid = new ParamGridBuilder().addGrid(logisticRegression.elasticNetParam, Array(0.2, 0.8)).addGrid(logisticRegression.regParam, Array(0.01, 0.1, 0.3, 0.5)).build()// 打印参数网格,以便调试和优化println("paraGrid parameters:\n" + paraGrid + "\n")// 创建评估器,用于评估模型性能val evaluator = new MulticlassClassificationEvaluator().setLabelCol("labelIndex").setPredictionCol("prediction")// 创建交叉验证器,用于模型选择val crossValidator = new CrossValidator().setEstimator(lrpiple).setEvaluator(evaluator).setEstimatorParamMaps(paraGrid).setNumFolds(3).setSeed(1234L)// 训练最佳模型val CVModel = crossValidator.fit(trainingData)// 使用最佳模型进行预测val predictions = CVModel.transform(testData)// 显示预测结果predictions.select("predictedLabel", "label", "features", "probability").show(5, 0)// 评估模型性能val accuracy = evaluator.evaluate(predictions)println("Accuracy = " + accuracy)println("Test Error = " + (1.0 - accuracy))// 从交叉验证模型中提取最佳的管道模型val bestPipleModel = CVModel.bestModel.asInstanceOf[PipelineModel]// 从最佳管道模型中提取逻辑回归模型,该模型位于管道的第三个阶段val lrModel = bestPipleModel.stages(2).asInstanceOf[LogisticRegressionModel]// 打印最佳模型的参数和统计信息println("Best Model Parameters:\n" + lrModel.explainParams())println("Best Model Coefficients:\n" + lrModel.coefficientMatrix)println("Best Model Intercept:\n" + lrModel.interceptVector)println("Best Model Summary:\n" + lrModel.summary)println("Best Model Summary Accuracy:\n" + lrModel.summary.accuracy)println("Best Model Summary False Positive Rate:\n" + lrModel.summary.falsePositiveRateByLabel)println("Best Model Summary Precision:\n" + lrModel.summary.precisionByLabel)println("Best Model Summary Recall:\n" + lrModel.summary.recallByLabel)println("Best Model Summary FMeasure:\n" + lrModel.summary.fMeasureByLabel)// 计算AUCval binaryEvaluator = new BinaryClassificationEvaluator().setLabelCol("labelIndex").setRawPredictionCol("prediction").setMetricName("areaUnderROC")val auc = binaryEvaluator.evaluate(predictions)println("AUC = " + auc)// 打印Logistic回归模型的参数lrModel.explainParams()// 检查模型路径是否存在,并保存最佳模型checkPathExistStatus(modelpath)lrModel.save(modelpath)// 停止SparkSessionspark.stop()}}

运行结果

2024-11-25 16:32:54,035 WARN  [main] util.NativeCodeLoader (NativeCodeLoader.java:60) - Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
+------------+-----------+------------+-----------+-------+
|sepal_length|sepal_width|petal_length|petal_width|species|
+------------+-----------+------------+-----------+-------+
|5.1         |3.5        |1.4         |0.2        |setosa |
|4.9         |3.0        |1.4         |0.2        |setosa |
|4.7         |3.2        |1.3         |0.2        |setosa |
+------------+-----------+------------+-----------+-------+
only showing top 3 rowsroot|-- sepal_length: string (nullable = true)|-- sepal_width: string (nullable = true)|-- petal_length: string (nullable = true)|-- petal_width: string (nullable = true)|-- species: string (nullable = true)root|-- sepal_length: double (nullable = true)|-- sepal_width: double (nullable = true)|-- petal_length: double (nullable = true)|-- petal_width: double (nullable = true)|-- label: string (nullable = true)+-----------------+------+
|features         |label |
+-----------------+------+
|[5.1,3.5,1.4,0.2]|setosa|
|[4.9,3.0,1.4,0.2]|setosa|
|[4.7,3.2,1.3,0.2]|setosa|
+-----------------+------+
only showing top 3 rowslogistricRegression parameters:
aggregationDepth: suggested depth for treeAggregate (>= 2) (default: 2)
elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty (default: 0.0)
family: The name of family which is a description of the label distribution to be used in the model. Supported options: auto, binomial, multinomial. (default: auto)
featuresCol: features column name (default: features, current: indexedFeatures)
fitIntercept: whether to fit an intercept term (default: true)
labelCol: label column name (default: label, current: labelIndex)
lowerBoundsOnCoefficients: The lower bounds on coefficients if fitting under bound constrained optimization. (undefined)
lowerBoundsOnIntercepts: The lower bounds on intercepts if fitting under bound constrained optimization. (undefined)
maxBlockSizeInMB: Maximum memory in MB for stacking input data into blocks. Data is stacked within partitions. If more than remaining data size in a partition then it is adjusted to the data size. Default 0.0 represents choosing optimal value, depends on specific algorithm. Must be >= 0. (default: 0.0)
maxIter: maximum number of iterations (>= 0) (default: 100, current: 100)
predictionCol: prediction column name (default: prediction)
probabilityCol: Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities (default: probability)
rawPredictionCol: raw prediction (a.k.a. confidence) column name (default: rawPrediction)
regParam: regularization parameter (>= 0) (default: 0.0)
standardization: whether to standardize the training features before fitting the model (default: true)
threshold: threshold in binary classification prediction, in range [0, 1] (default: 0.5)
thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold (undefined)
tol: the convergence tolerance for iterative algorithms (>= 0) (default: 1.0E-6)
upperBoundsOnCoefficients: The upper bounds on coefficients if fitting under bound constrained optimization. (undefined)
upperBoundsOnIntercepts: The upper bounds on intercepts if fitting under bound constrained optimization. (undefined)
weightCol: weight column name. If this is not set or empty, we treat all instance weights as 1.0 (undefined)+-----------------+------+
|features         |label |
+-----------------+------+
|[4.4,3.0,1.3,0.2]|setosa|
|[4.4,3.2,1.3,0.2]|setosa|
|[4.6,3.1,1.5,0.2]|setosa|
+-----------------+------+
only showing top 3 rows+-----------------+------+
|features         |label |
+-----------------+------+
|[4.3,3.0,1.1,0.1]|setosa|
|[4.4,2.9,1.4,0.2]|setosa|
|[4.5,2.3,1.3,0.3]|setosa|
+-----------------+------+
only showing top 3 rowsparaGrid parameters:
[Lorg.apache.spark.ml.param.ParamMap;@7cb32ca52024-11-25 16:33:05,441 WARN  [main] blas.InstanceBuilder (InstanceBuilder.java:52) - Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
+-----------------+------+--------------+--------------------------------------------------------------+
|features         |label |predictedLabel|probability                                                   |
+-----------------+------+--------------+--------------------------------------------------------------+
|[4.3,3.0,1.1,0.1]|setosa|setosa        |[0.9757051466049288,0.024294044215834764,8.091792365185294E-7]|
|[4.4,2.9,1.4,0.2]|setosa|setosa        |[0.942414263844343,0.057581913147947784,3.823007709168387E-6] |
|[4.5,2.3,1.3,0.3]|setosa|setosa        |[0.6006671038922778,0.39930355820688734,2.9337900834715793E-5]|
|[4.9,3.6,1.4,0.1]|setosa|setosa        |[0.9914556600750227,0.008543921422201815,4.185027753908102E-7]|
|[5.0,3.0,1.6,0.2]|setosa|setosa        |[0.8901960393397007,0.10979575995270213,8.200707597259611E-6] |
+-----------------+------+--------------+--------------------------------------------------------------+
only showing top 5 rows+--------------+------+-----------------+--------------------------------------------------------------+
|predictedLabel|label |features         |probability                                                   |
+--------------+------+-----------------+--------------------------------------------------------------+
|setosa        |setosa|[4.3,3.0,1.1,0.1]|[0.9757051466049288,0.024294044215834764,8.091792365185294E-7]|
|setosa        |setosa|[4.4,2.9,1.4,0.2]|[0.942414263844343,0.057581913147947784,3.823007709168387E-6] |
|setosa        |setosa|[4.5,2.3,1.3,0.3]|[0.6006671038922778,0.39930355820688734,2.9337900834715793E-5]|
|setosa        |setosa|[4.9,3.6,1.4,0.1]|[0.9914556600750227,0.008543921422201815,4.185027753908102E-7]|
|setosa        |setosa|[5.0,3.0,1.6,0.2]|[0.8901960393397007,0.10979575995270213,8.200707597259611E-6] |
+--------------+------+-----------------+--------------------------------------------------------------+
only showing top 5 rowsAccuracy = 0.9607843137254901
Test Error = 0.03921568627450989
Best Model Parameters:
aggregationDepth: suggested depth for treeAggregate (>= 2) (default: 2)
elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty (default: 0.0, current: 0.2)
family: The name of family which is a description of the label distribution to be used in the model. Supported options: auto, binomial, multinomial. (default: auto)
featuresCol: features column name (default: features, current: indexedFeatures)
fitIntercept: whether to fit an intercept term (default: true)
labelCol: label column name (default: label, current: labelIndex)
lowerBoundsOnCoefficients: The lower bounds on coefficients if fitting under bound constrained optimization. (undefined)
lowerBoundsOnIntercepts: The lower bounds on intercepts if fitting under bound constrained optimization. (undefined)
maxBlockSizeInMB: Maximum memory in MB for stacking input data into blocks. Data is stacked within partitions. If more than remaining data size in a partition then it is adjusted to the data size. Default 0.0 represents choosing optimal value, depends on specific algorithm. Must be >= 0. (default: 0.0)
maxIter: maximum number of iterations (>= 0) (default: 100, current: 100)
predictionCol: prediction column name (default: prediction)
probabilityCol: Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities (default: probability)
rawPredictionCol: raw prediction (a.k.a. confidence) column name (default: rawPrediction)
regParam: regularization parameter (>= 0) (default: 0.0, current: 0.01)
standardization: whether to standardize the training features before fitting the model (default: true)
threshold: threshold in binary classification prediction, in range [0, 1] (default: 0.5)
thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold (undefined)
tol: the convergence tolerance for iterative algorithms (>= 0) (default: 1.0E-6)
upperBoundsOnCoefficients: The upper bounds on coefficients if fitting under bound constrained optimization. (undefined)
upperBoundsOnIntercepts: The upper bounds on intercepts if fitting under bound constrained optimization. (undefined)
weightCol: weight column name. If this is not set or empty, we treat all instance weights as 1.0 (undefined)
Best Model Coefficients:
-1.0160525354315355  2.5595568393649253   -0.9181804522594151     -1.9942533505022708  
0.46823886626877925  -1.1501375486462728  -0.0041242496510105284  -0.9534800212737365  
0.25465889870796954  -0.8280356114881423  1.0644125342983854      3.2745417344773724   
Best Model Intercept:
[4.022500795582984,3.9666625659384906,-7.9891633615214745]
Best Model Summary:
org.apache.spark.ml.classification.LogisticRegressionTrainingSummaryImpl@28377ea2
Best Model Summary Accuracy:
0.9595959595959596
Best Model Summary False Positive Rate:
[D@61c9f15d
Best Model Summary Precision:
[D@6a04b1bd
Best Model Summary Recall:
[D@329033ef
Best Model Summary FMeasure:
[D@13461194
AUC = 1.0


http://www.ppmy.cn/news/1550195.html

相关文章

Sickos1.1 详细靶机思路 实操笔记

Sickos1.1 详细靶机思路 实操笔记 免责声明 本博客提供的所有信息仅供学习和研究目的&#xff0c;旨在提高读者的网络安全意识和技术能力。请在合法合规的前提下使用本文中提供的任何技术、方法或工具。如果您选择使用本博客中的任何信息进行非法活动&#xff0c;您将独自承担…

业务架构、数据架构、应用架构和技术架构

TOGAF(The Open Group Architecture Framework)是一个广泛应用的企业架构框架&#xff0c;旨在帮助组织高效地进行架构设计和管理。 TOGAF 的核心就是由我们熟知的四大架构领域组成:业务架构、数据架构、应用架构和技术架构。 企业数字化架构设计中的最常见要素是4A 架构。 4…

java学习记录12

ArrayList方法总结 构造方法 ArrayList() 构造一个初始容量为 10 的空列表。 ArrayList(int initialCapacity) 构造一个具有指定初始容量的空列表。 实例方法 add(int index, E element) 在此list中的指定位置插入指定元素。 ArrayList<Integer> array…

SAP 零售方案 CAR 系统的介绍与研究

前言 当今时代&#xff0c;零售业务是充满活力和活力的业务领域之一。每天&#xff0c;由于销售运营和客户行为&#xff0c;它都会生成大量数据。因此&#xff0c;公司迫切需要管理数据并从中检索见解。它将帮助公司朝着正确的方向发展他们的业务。 这就是为什么公司用来处理…

Cmakelist.txt之win-c-udp-server

1.cmakelist.txt cmake_minimum_required(VERSION 3.16) ​ project(c_udp_server LANGUAGES C) ​ add_executable(c_udp_server main.c) ​ # link_directories("D:/Environment/mingw64/x86_64-w64-mingw32/lib") ​ target_link_libraries(c_udp_server wsock32…

Ubuntu24.04下的docker问题

按官网提示是可以安装成功的&#xff0c;但是curl无法使用https下载&#xff0c;会造成下述语句执行失败 # Add Dockers official GPG key: sudo apt-get update sudo apt-get install ca-certificates curl sudo install -m 0755 -d /etc/apt/keyrings sudo curl -fsSL https…

深入解析常见的设计模式

在本篇博文中&#xff0c;我们将逐个深入解析常见的设计模式&#xff0c;包括它们的目的、结构和具体示例&#xff0c;帮助你更好地理解和应用这些模式。 一、创建型模式 1. 单例模式&#xff08;Singleton&#xff09; 目的&#xff1a;确保一个类只有一个实例&#xff0c;…

前端数据可视化思路及实现案例

目录 一、前端数据可视化思路 &#xff08;一&#xff09;明确数据与目标 &#xff08;二&#xff09;选择合适的可视化图表类型 &#xff08;三&#xff09;数据与图表的绑定及交互设计 &#xff08;四&#xff09;页面布局与样式设计 二、具体案例&#xff1a;使用 Ech…