评估 机器学习 回归模型 的性能和准确度

devtools/2024/11/7 18:01:05/

      回归 是一种常用的预测模型,用于预测一个连续因变量和一个或多个自变量之间的关系。

那么,最后评估 回归模型 的性能和准确度非常重要,可以帮助我们判断模型是否有效并进行改进。

接下来,和大家分享如何评估 回归模型 的性能和准确度。

一、 评估指标

1.1 均方误差(MSE)

      均方误差(Mean Squared Error, MSE衡量的是预测值与真实值之间的平均平方差异。MSE越小,模型的预测精度越高。由于平方误差将偏差放大,因此MSE对异常值(Outliers)比较敏感。

MSE=\frac{1}{n}\sum_{i=1}^{n}\left ( y_{i}-\hat{y}_{i} \right )^{2}

  •  y_{i} 是第  i 个样本的真实值。\hat{y}_{i} 是第  i 个样本的预测值。n 是样本总数。

from sklearn.metrics import mean_squared_error# y_true 是真实值数组,y_pred 是预测值数组
mse = mean_squared_error(y_true, y_pred)
print("Mean Squared Error (MSE):", mse)

1.2 均方根误差(RMSE)

        均方根误差(Root Mean Squared Error, RMSE是MSE的平方根,具有与原数据相同的量纲(单位),因此更容易解释。它同样对异常值敏感。 

RMSE=\sqrt{\frac{1}{n}\sum_{i=1}^{n}\left ( y_{i}-\hat{y}_{i} \right )^{2}}

import numpy as nprmse = np.sqrt(mean_squared_error(y_true, y_pred))
print("Root Mean Squared Error (RMSE):", rmse)

1.3 平均绝对误差(MAE)

       平均绝对误差(Mean Absolute Error, MAE衡量的是预测值与真实值之间的平均绝对差异。相比MSE和RMSE,MAE对异常值不那么敏感。

 MAE=\frac{1}{n}\sum_{i=1}^{n} \left | y_{i}-\hat{y}_{i} \right |

from sklearn.metrics import mean_absolute_errormae = mean_absolute_error(y_true, y_pred)
print("Mean Absolute Error (MAE):", mae)

1.4. 决定系数(R²)

       决定系数衡量的是模型解释数据变异的比例。其取值范围在0到1之间,值越接近1,模型解释能力越强。如果R²为0,表示模型没有解释任何数据变异;如果R²为1,表示模型完美地解释了数据变异。 

 R^{2}=\frac{\sum_{i=1}^{n}\left ( y_{i}-\hat{y}_{i} \right )^{2}}{\sum_{i=1}^{n}\left ( y_{i}-\bar{y}_{i} \right )^{2}}

  • \bar{y}_{i}是真实值的平均值。

from sklearn.metrics import r2_scorer2 = r2_score(y_true, y_pred)
print("R² (Coefficient of Determination):", r2)

二、 评估图

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression# 生成示例数据
np.random.seed(0)
X = 2 * np.random.rand(1000, 1)
y = 4 + 3 * X + np.random.randn(1000, 1)# 拆分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)# 训练线性回归模型
model = LinearRegression()
model.fit(X_train, y_train)# 预测
y_train_pred = model.predict(X_train)
y_test_pred = model.predict(X_test)

2.1  真实值与预测值的散点图

我们可以通过散点图比较真实值与预测值,直观展示模型的预测效果。 

plt.scatter(X_test, y_test, color='black', label='Actual Values')
plt.scatter(X_test, y_test_pred, color='blue', label='Predicted Values')
plt.plot(X_test, y_test_pred, color='red', linewidth=2, label='Regression Line')
plt.xlabel('X')
plt.ylabel('y')
plt.title('Actual vs Predicted Values')
plt.legend()
plt.show()

2.2  预测误差的分布图 

 预测误差(真实值与预测值的差异)的分布图可以帮助我们了解模型误差的分布情况。

errors = y_test - y_test_predplt.hist(errors, bins=20, edgecolor='black')
plt.xlabel('Prediction Error')
plt.ylabel('Frequency')
plt.title('Distribution of Prediction Errors')
plt.show()

2.3  学习曲线 

       习曲线展示了训练误差和验证误差随训练集大小的变化情况,有助于我们诊断模型是否存在欠拟合或过拟合问题。 

from sklearn.model_selection import learning_curvetrain_sizes, train_scores, test_scores = learning_curve(model, X, y, cv=5, scoring='neg_mean_squared_error')train_scores_mean = -train_scores.mean(axis=1)
test_scores_mean = -test_scores.mean(axis=1)plt.plot(train_sizes, train_scores_mean, label='Training error')
plt.plot(train_sizes, test_scores_mean, label='Validation error')
plt.ylabel('MSE')
plt.xlabel('Training set size')
plt.title('Learning Curves')
plt.legend()
plt.show()

       以上是详细介绍如何评估 回归模型 的性能和准确度,包括各个评估指标的原理、公式推导以及在Python中的实现。

参考:

机器学习模型评估的方法总结(回归、分类模型的评估)_分类模型评估方法-CSDN博客

模型评估指标总结(预测指标、分类指标、回归指标)_常见模型误差评价指标-CSDN博客

机器学习笔记:回归模型评估指标——MAE、MSE、RMSE、MAPE、R2等 - Hider1214 - 博客园

持续更新中。。。  


http://www.ppmy.cn/devtools/132081.html

相关文章

Qt 窗口可见性 之 close函数和hide函数

close函数 基本功能 close() 方法的主要功能是关闭窗口,并触发一系列与关闭相关的事件和信号。调用此方法后,窗口将不再可见,但窗口对象本身仍然存在,并且可以被再次显示(通过调用 show() 方法)。 事件处…

Hive专栏概述

Hive专栏概述 Hive“出身名门”,是最初由Facebook公司开发的数据仓库工具。它简单且容易上手,是深入学习Hadoop技术的一个很好的切入点。专栏内容包括:Hive的安装和配置,其核心组件和架构,Hive数据操作语言&#xff0c…

场馆场地预定预约源码全开源uniapp+搭建教程

一.介绍 是一款基于ThinkPHPUniApp开发的多场馆场地预定小程序,提供运动场馆运营解决方案,适用于体育馆、羽毛球馆、兵乒球馆、篮球馆、网球馆等场馆 二.搭建环境 亲测的搭建环境: 系统环境:CentOS、 运行环境:宝…

PHP-FPM 性能配置优化

PHP-FPM 性能配置优化 4 核 8 G 服务器大约可以开启 500 个 PHP-FPM,极限吞吐量在 580 qps (Query Per Second 每秒查询数)左右。 Nginx php-fpm 是怎么工作的? php-fpm 全称是 PHP FastCGI Process Manager 的简称,…

golang学习2

下列哪个不是Go语言的关键字? A. defer B. break C. function D. var 答案:C 解析:Go语言的关键字中没有function,其他三个都是Go语言的关键字。 下列哪个是Go语言的数据类型? A. String B. Char C. Byte D. Float64 …

Ubuntu Linux 搭建邮件服务器(postfix + dovecot)

准备工作 1. 一台公网服务器(需要不被服务商限制发件收件的,也就是端口25、110、143、465、587、993、995不被限制),如有防火墙或安全组需要把这些端口开放 2. 一个域名,最好是com cn org的一级域名 3. 域名备案(如果服务器是国外的则不需要备案) 一、配置域名解析 …

Spring MVC 完整生命周期和异常处理流程图

先要明白 // 1. 用户发来请求: localhost:8080/user/1// 2. 处理器映射器(HandlerMapping)的工作 // 它会找到对应的Controller和方法 GetMapping("/user/{id}") public User getUser(PathVariable Long id) {return userService.getById(id); }// 3. 处理器适配…

FreeRTOS | 开中断与临界区(第十四天)

点击上方"蓝字"关注我们 00、上节回顾 RTOS | 那么什么是RTOS?三大操作系统?(第十四天)FreeRTOS | 原理介绍和资源get(第十四天)FreeRTOS | STM32F407 FreeRTOS移植(第十四天)FreeRTOS | 任务管理(第十四天)FreeRTOS | 内核控制函数和时间管理(第十四天)01、开关中…