训练/测试、过拟合问题

news/2024/11/25 7:30:17/

在机器学习中,我们创建模型来预测某些事件的结果,比如之前使用重量和发动机排量,预测了汽车的二氧化碳排放量

要衡量模型是否足够好,我们可以使用一种称为训练/测试的方法

训练/测试是一种测量模型准确性的方法

之所以称为训练/测试,是因为我们将数据集分为两组:训练集和测试集

80% 用于训练,20% 用于测试

使用训练集来训练模型、

使用测试集来测试模型

训练模型意味着创建模型

测试模型意味着测试模型的准确

下面是模拟的数据:我们的数据集展示了商店中的 100 位顾客及其购物习惯

import numpy
import matplotlib.pyplot as plt# 使用 `numpy.random.seed()` 函数设定种子可以确保每次生成的随机数序列是相同的
# 从而保证算法的可重复性和稳定性
numpy.random.seed(2)x = numpy.random.normal(3, 1, 100)
y = numpy.random.normal(150, 40, 100) / xplt.scatter(x, y)
plt.show()

散点图如下

x 轴表示购买前的分钟数

y 轴表示在购买上花费的金额

训练集应该是原始数据的 80% 的随机选择

测试集应该是剩余的 20%

train_x = x[:80]
train_y = y[:80]test_x = x[80:]
test_y = y[80:]

显示与训练集相同的散点图

plt.scatter(train_x, train_y)
plt.show()

如下所示

import numpy
import matplotlib.pyplot as plt
numpy.random.seed(2)x = numpy.random.normal(3, 1, 100)
y = numpy.random.normal(150, 40, 100) / xtrain_x = x[:80]
train_y = y[:80]test_x = x[80:]
test_y = y[80:]plt.scatter(train_x, train_y)
plt.show()

为了确保测试集不是完全不同,我们还要看一下测试集

plt.scatter(test_x, test_y)
plt.show()

 

 进行拟合数据集,通过数据点画一条线,我们使用 matplotlib 模块的 plott() 方法

绘制穿过数据点的多项式回归线

import numpy
import matplotlib.pyplot as plt
numpy.random.seed(2)x = numpy.random.normal(3, 1, 100)# 对应位置逐个元素相除,可以用来进行归一化、标准化等数据预处理操作
y = numpy.random.normal(150, 40, 100) / xtrain_x = x[:80]
train_y = y[:80]test_x = x[80:]
test_y = y[80:]mymodel = numpy.poly1d(numpy.polyfit(train_x, train_y, 4))# 生成 0 ~ 6 之间的100个 等差数列用于拟合曲线
myline = numpy.linspace(0, 6, 100)plt.scatter(train_x, train_y)
plt.plot(myline, mymodel(myline))
plt.show()

此结果可以支持我们对数据集拟合多项式回归的建议,即使如果我们尝试预测数据集之外的值会给我们带来一些奇怪的结果。例如:该行表明某位顾客在商店购物 6 分钟,会完成一笔价值 200 的购物。这可能是过拟合的迹象

但是 R-squared 分数呢? R-squared score很好地指示了我的数据集对模型的拟合程度

 R2,也称为 R平方(R-squared),它测量 x 轴和 y 轴之间的关系,取值范围从 0 到 1,其中 0 表示没有关系,而 1 表示完全相关

sklearn 模块有一个名为 rs_score() 的方法,该方法将帮助我们找到这种关系

在这里,我们要衡量顾客在商店停留的时间与他们花费多少钱之间的关系

import numpy
from sklearn.metrics import r2_score
numpy.random.seed(2)x = numpy.random.normal(3, 1, 100)
y = numpy.random.normal(150, 40, 100) / xtrain_x = x[:80]
train_y = y[:80]test_x = x[80:]
test_y = y[80:]mymodel = numpy.poly1d(numpy.polyfit(train_x, train_y, 4))r2 = r2_score(train_y, mymodel(train_x))print(r2)

 因此,从上面的情况来看,在训练数据方面,我们已经建立了一个不错的模型

然后,我们要使用测试数据来测试模型,以检验是否给出相同的结果

import numpy
from sklearn.metrics import r2_score
numpy.random.seed(2)x = numpy.random.normal(3, 1, 100)
y = numpy.random.normal(150, 40, 100) / xtrain_x = x[:80]
train_y = y[:80]test_x = x[80:]
test_y = y[80:]mymodel = numpy.poly1d(numpy.polyfit(train_x, train_y, 4))r2 = r2_score(test_y, mymodel(test_x))print(r2)

 结果 0.809 表明该模型也适合测试集,我们确信可以使用该模型预测未来值

如果购买客户在商店中停留 5 分钟,他/她将花费多少钱?

import numpy
from sklearn.metrics import r2_score
numpy.random.seed(2)x = numpy.random.normal(3, 1, 100)
y = numpy.random.normal(150, 40, 100) / xtrain_x = x[:80]
train_y = y[:80]test_x = x[80:]
test_y = y[80:]mymodel = numpy.poly1d(numpy.polyfit(train_x, train_y, 4))print(mymodel(5))

 该例预测客户花费了 22.88 美元,似乎与图表相对应

 


http://www.ppmy.cn/news/82527.html

相关文章

第三十六回:BottomeSheet Widget

文章目录 概念介绍使用方法示例代码 我们在上一章回中介绍了AlertDialog Widget相关的内容,本章回中将介绍 BottomSheet Widget.闲话休提,让我们一起Talk Flutter吧。 概念介绍 我们在这里说的BottomSheet是一种弹出式窗口,和上一章回中介绍的AlertDia…

MapReduce序列化【用户流量使用统计】

目录 什么是序列化和反序列化? 序列化 反序列化 为什么要序列化? 序列化的主要应用场景 MapReduce实现序列化 自定义bean对象实现Writable接口 1.实现Writable接口 2.无参构造 3.重写序列化方法 4.重写反序列化方法 5.顺序一致 6.重写toStri…

北斗GPS校时器(卫星授时器)助力桥梁监控系统建设

北斗GPS校时器(卫星授时器)助力桥梁监控系统建设 北斗GPS校时器(卫星授时器)助力桥梁监控系统建设 一、系统概述   整个采集系统分散在桥梁的各个部位。桥梁按照区域划分为若干区段,在主要几个区段中安置着信号采集机…

入门级 使用 vertx进行tcp 开发 spring boot整合vertx开发tcp

Vertx 简介 准备 软件下载 网络调试工具 创建Spring boot 项目 导入依赖 <!-- vertx tcp开发依赖 --> <dependency><groupId>io.vertx</groupId><artifactId>vertx-core</artifactId><version>4.3.1</version> <…

微信小程序-页面生命周期方法

在经过上一篇文章的介绍之后&#xff0c;我们知道了大体的生命周期在什么时候执行&#xff0c;这次主要是以代码的形式来展示一下具体的阶段执行什么生命周期方法。 首先我们编写一个代码可以从首页跳转到日志页面&#xff1a; <!--index.wxml--> <text>首页</t…

51单片机四路开关电路+限位开关

#include <reg51.h> #include <intrins.h> unsigned char tmp; void send_char(unsigned char txd); void delay(unsigned int k); sbit key1 P1^0; sbit key2 P1^1; sbit key3 P1^2; sbit key4 P1^3; sbit key21 P2^1; // 限位开关1 zgf sbit key22 P2^2…

企业选择CRM系统的三个好处

跟随着全面放开的脚步&#xff0c;国内经济正在强势复苏&#xff0c;每家企业都在抢订单、找客户&#xff0c;想要提高企业竞争力还是要借助CRM客户管理系统&#xff0c;CRM系统客户信息管理的价值有哪些&#xff1f;从哪些方面助力企业发展。 一、高效率的管理线索 1.便捷录…

【Error】Python3.7 No module named ‘_sqlite3‘ 解决方案

场景&#xff1a;docker容器运行keybert时出现错误 No module named ‘_sqlite3‘&#xff0c;是容器环境没有sqlite的库&#xff0c;如下图所示&#xff1a; 本机是能够正常导入sqlite3的&#xff0c;虚拟环境conda下也有该库。 python3.8版本的不可用于python3.7中&#xff0…