机器学习-scikit-learn

news/2025/1/8 7:52:34/

文章目录

    • 前言
    • 线性回归模型-LinearRegression
    • 准备数据集
    • 使用LinearRegression
    • 总结

前言

scikit-learn是Python中最流行的机器学习库之一,它提供了各种各样的机器学习算法和工具,包括分类、回归、聚类、降维等。

scikit-learn的优点有:

  • 简单易用:scikit-learn 的接口简单易懂,可以让用户很容易地上手进行机器学习。
    统一的API:scikit-learn 的 API 非常统一,各种算法的使用方法基本一致,使得学习和使用变得更加方便。
  • 大量实现了机器学习算法:scikit-learn 实现了各种经典的机器学习算法,而且提供了丰富的工具和函数,使得算法的调试和优化变得更加容易。
  • 开源免费:scikit-learn 是完全开源的,而且是免费的,任何人都可以使用和修改它的代码。
  • 高效稳定:scikit-learn 实现了各种高效的机器学习算法,可以处理大规模数据集,并且在稳定性和可靠性方面表现出色。
    scikit-learn因为API非常的统一而且模型相对较简单所以非常适合入门机器学习。
    这里我的推荐方式是结合官方文档进行学习,不仅有每个模型的适用范围介绍还有代码样例。
    scikit-learn官网地址

线性回归模型-LinearRegression

LinearRegression模型是一种基于线性回归的模型,适用于解决连续变量的预测问题。该模型的基本思想是建立一个线性方程,将自变量与因变量之间的关系建模为一条直线,并利用训练数据拟合该直线,从而求出线性方程的系数,再用该方程对测试数据进行预测。

LinearRegression模型适用于自变量和因变量之间存在线性关系的问题,例如房价预测、销售预测、用户行为预测等。当然,当自变量和因变量之间的关系为非线性时,LinearRegression模型的表现会比较差。此时可以采用多项式回归、岭回归、Lasso回归等方法来解决。

准备数据集

在抛开其它因素影响后,学习时间和学习成绩之间存在着一定的线性关系,当然这里的学习时间指的是有效学习时间,表现为随着学习时间的增加成绩也会增加。所以我们准备一份学习时间和成绩的数据集。数据集内部分数据如下:

学习时间,分数
0.5,15
0.75,23
1.0,14
1.25,42
1.5,21
1.75,28
1.75,35
2.0,51
2.25,61
2.5,49

使用LinearRegression

  • 确定特征和目标

在学习时间和成绩间,学习时间为特征,也即自变量;成绩为标签也即因变量,所以我们需要在准备好的学习时间和成绩数据集中提取特征和标签。

import pandas as pd
import numpy as np
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression# 读取学习时间和成绩CSV数据文件
data = pd.read_csv('data/study_time_score.csv')
# 提取数据特征学习时间
X = data['学习时间']
# 提取数据目标(标签)分数
Y = data['分数']
  • 划分训练集和测试集

在特征及标签数据准备好以后,使用scikit-learn的LinearRegression进行训练,将数据集划分为训练集和测试集。

"""
将特征数据和目标数据划分为测试集和训练集
通过test_size=0.25将百分之二十五的数据划分为测试集
"""
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.25, random_state=0)x_train = X_train.values.reshape(-1, 1)model.fit(x_train, Y_train)
  • 选择模型,对数据进行拟合

将测试集和训练集准备好以后,我们就可以选择合适的模型对训练集进行拟合,以便能够预测出其它特征对应的目标

# 选择模型,选择模型为LinearRegression
model = LinearRegression()
# Scikit-learn中,机器学习模型的输入必须是一个二维数组。我们需要将一维数组转换为二维数组,才能在模型中使用。
x_train = X_train.values.reshape(-1, 1)
# 进行拟合
model.fit(x_train, Y_train)
  • 得到模型参数

由于数据集只包含学习时间和成绩两个是一个很简单的线性模型,其背后的数学公式也即y=ax+b,其中y因变量也就是成绩, x自变量也即学习时间。

"""
输出模型关键参数
Intercept: 截距 即b
Coefficients: 变量权重 即a
"""
print('Intercept:', model.intercept_)
print('Coefficients:', model.coef_)
  • 回测
    上面拟合模型只用到了测试集数据,下面我们需要使用测试集数据对模型的拟合进行一个回测,在使用训练集拟合后,我们就可以对特征测试集进行预测,通过得到的目标预测结果与实际目标的值进行比较,我们就可以得到模型的拟合度了。
# 转换为n行1列的二维数组
x_test = X_test.values.reshape(-1, 1)
# 在测试集上进行预测并计算评分
Y_pred = model.predict(x_test)
# 打印测试特征数据
print(x_test)
# 打印特征数据对应的预测结果
print(Y_pred)# 将预测结果与原特征数据对应的实际目标值进行比较,从而获得模型拟合度# R2 (R-squared):模型拟合优度,取值范围在0~1之间,越接近1表示模型越好的拟合了数据。
print("R2:", r2_score(Y_test, Y_pred))
  • 程序运行结果
  • 根据上述的代码我们需要确定LinearRegression模型的拟合度,也就是这些数据到底适合不适合使用线性模型进行拟合,程序的运行结果如下:
预测结果:
[47.43726068 33.05457106 49.83437561 63.41802692 41.84399249 37.8488009323.46611131 37.84880093 26.66226456 71.40841004 18.67188144 88.987252963.41802692 42.6430308  21.86803469 69.81033341 66.61418017 33.0545710658.62379705 50.63341392 18.67188144 41.04495418 20.26995807 77.8007165328.26034119 13.87765157 61.81995029 90.58532953 77.80071653 36.2507243184.19302303]
R2: 0.8935675710322939

总结

上述模型的拟合度达到了89%,如果你能接受大约10%的误差,则可以使用LinearRegression模型进行预测。当调整训练集大小小于25%时,模型的拟合度稍低于89%,数据集的大小和训练集的大小等因为都会影响模型的拟合度,需要不断尝试找到拟合效果的参数设定。


http://www.ppmy.cn/news/34410.html

相关文章

去中心化联邦学习思想

去中心化联邦学习是一种保护用户隐私的分散式机器学习方法。与集中式联邦学习相比,去中心化联邦学习更加注重保护用户数据隐私,同时也更具有扩展性和健壮性。 在去中心化联邦学习中,每个设备都使用本地数据进行模型训练,并将模型…

算法训练营第五十九天|LeetCode647、516

题目连接:647. 回文子串 - 力扣(LeetCode)个人思路:dp数组的含义是:dp[i][j]:s字符串下标i到下标j的字串是否是一个回文串这里我出现了错误为什么出错呢?代码如下:class Solution {p…

Leetcode27. 移除元素

目录一、题目描述:二、解决思路和代码1. 解决思路2. 代码一、题目描述: 给你一个数组 nums 和一个值 val,你需要 原地 移除所有数值等于 val 的元素,并返回移除后数组的新长度。 不要使用额外的数组空间,你必须仅使用…

ruoyi-cloud版本最新环境部署 - 详细步骤

项目地址: RuoYi-Cloud: 🎉 基于Spring Boot、Spring Cloud & Alibaba的分布式微服务架构权限管理系统,同时提供了 Vue3 的版本 1. 后端cloud版本环境搭建 jdk、mysql、maven、redis、nginx、nacos安装 安装redis(redis下…

Vue:路由管理模式

三种模式 Vue.js 的路由管理有三种模式: Hash 模式(默认):在 URL 中使用 # 符号来管理路由。例如,http://example.com/#/about。这个模式的好处是可以避免浏览器向服务器发送不必要的请求,并且不需要特殊…

深度学习:GPT1、GPT2、GPT-3

深度学习:GPT1、GPT2、GPT3的原理与模型代码解读GPT-1IntroductionFramework自监督学习微调ExperimentGPT-2IntroductionApproachConclusionGPT-3GPT-1 Introduction GPT-1(Generative Pre-training Transformer-1)是由OpenAI于2018年发布的…

Linux上用Samba建立共享文件夹并通过Linux测试

本文基于redhat 9 版本进行配置演示 一.Samba简介 二.samba挂载配置 1.服务端下载samba,samba-client,客户端下载cifs-utils 2.服务端 3.客户端 三.samba自动挂载配置 1.服务端配置不变,客户端下载autofs并开启 2.编辑配置文件 3.重…

全网最完整,接口测试总结彻底打通接口自动化大门,看这篇就够了......

目录:导读前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结(尾部小惊喜)前言 所谓接口&#xff0…