sklearn 转换器和预估器

news/2024/11/23 9:53:18/

        刚学习sklearn时,没分清转换器的fit()和模型训练的fit(),还以为是一个,结果学完了回过头来,才发现这些差异。再此记录一下。

一、 sklearn 转换器和预估器

  1. 转换器(Transformers)

    • 定义:转换器是一种可以对数据进行某种转换的对象。例如,标准化、归一化、PCA等都是转换器的例子。
    • 主要方法
      • fit(X, y=None):在数据集X上训练转换器,这可以让转换器学习数据的一些统计特性。
      • transform(X):使用学习到的转换在新的数据集X上执行转换。
      • fit_transform(X, y=None):这是fittransform的组合。它首先在X上训练转换器,然后在同一个数据上执行转换。
    • 用途:转换器主要用于数据预处理,比如缺失值填充、特征缩放、编码分类特征等。
  2. 预估器(Estimators)

    • 定义:预估器是一种可以估计某些参数的对象。在sklearn中,几乎所有的学习算法都是预估器,包括分类、回归、聚类等。
    • 主要方法
      • fit(X, y):用数据X和标签y训练预估器。
      • predict(X):对新的数据集X进行预测。
      • score(X, y):评估预估器在数据X和标签y上的性能。
    • 用途:预估器主要用于执行实际的学习任务,如分类、回归、聚类等。

总结

  • 转换器主要用于改变数据的形式或结构,而预估器用于基于数据进行预测或决策。
  • 转换器和预估器都有fit方法,用于从数据中学习参数。
  • 转换器用于数据预处理阶段,预估器则用于模型的训练和预测阶段。

二、转换器中的fit和模型训练中的fit区别

尽管它们都被称为fit,但它们在转换器和模型(预估器)之间的作用有所不同。

  1. 转换器中的fit

    • 作用:在转换器中,fit方法用于学习数据的某些特性。例如,如果使用标准化转换器,fit方法会计算特征的均值和标准差。
    • 目的:目的是理解数据的结构和分布,以便可以将相同的转换应用于训练数据和未来的数据。
    • 输出fit方法通常不返回转换后的数据,而是将学习到的参数存储在转换器对象中。
  2. 模型训练中的fit(预估器的fit

    • 作用:在预估器(例如分类器或回归器)中,fit方法用于从带标签的数据中学习模型的参数。例如,线性回归中的fit方法将找到最佳拟合线的斜率和截距。
    • 目的:目的是找到可以用于预测未来数据的模型参数。
    • 输出:与转换器不同,预估器的fit方法将修改对象本身,使其准备好进行预测。

总的来说,转换器的fit方法与预估器的fit方法的主要区别在于它们的目标和作用:

  • 转换器的fit主要用于理解数据的特性,并为将来的转换做准备。
  • 预估器的fit用于学习模型的参数,以便对新数据进行预测。

特征工程:标准化
transfer = StandardScaler()
x_train = transfer.fit_transform(x_train)
x_test = transfer.transform(x_test)

实际中,我们经常写上述代码,fit_transform其实可以理解为先fit后transfrom,先训练出模型,然后根据模型进行转换。本质上上述代码等价于:

transfer = StandardScaler()

transfer.fit(x_train)
x_train = transfer.transform(x_train)
x_test = transfer.transform(x_test)

三、案例分析

案例:房价预测

假设有一个包含各种特征的房屋数据集,例如面积、卧室数量、地段等,以及房价的标签。目标是根据这些特征预测房价。

步骤1:数据预处理(使用转换器)

首先,需要对一些特征进行标准化,使其具有均值为0和标准差为1。

from sklearn.preprocessing import StandardScaler# 假设 X_train 是训练特征
scaler = StandardScaler()
scaler.fit(X_train)  # 用fit方法学习训练数据的均值和标准差# 使用转换器对训练和测试数据进行转换
X_train_scaled = scaler.transform(X_train)
X_test_scaled = scaler.transform(X_test)  # 注意:使用训练数据的均值和标准差转换测试数据

在这里,fit方法用于学习训练数据的均值和标准差。然后,使用transform方法将这些参数应用于训练和测试数据。

步骤2:训练模型(使用预估器)

使用预估器,例如线性回归模型,对房价进行预测。

from sklearn.linear_model import LinearRegression# 创建线性回归模型
model = LinearRegression()# 使用标准化后的训练数据和房价标签来训练模型
model.fit(X_train_scaled, y_train)# 使用训练后的模型预测测试集的房价
y_pred = model.predict(X_test_scaled)

在这里,预估器的fit方法用于从训练数据中学习模型参数,以便可以用于对新数据进行预测。

总结

  • 转换器的fit用于学习数据的特性(例如均值和标准差),以便以后可以对新数据应用相同的转换。
  • 预估器的fit用于从带标签的训练数据中学习模型参数,以便可以用于预测。

完整案例:波士顿房价预测

步骤1:加载数据

from sklearn.datasets import load_bostonboston = load_boston()
X = boston.data
y = boston.target

步骤2:划分训练集和测试集

from sklearn.model_selection import train_test_splitX_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

步骤3:创建预处理

from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipelinepreprocessing_pipeline = Pipeline([('scaler', StandardScaler()),
])
preprocessing_pipeline.fit(X_train)
X_train_scaled = preprocessing_pipeline.transform(X_train)
X_test_scaled = preprocessing_pipeline.transform(X_test)

步骤4:选择和配置预估器

from sklearn.linear_model import LinearRegressionmodel = LinearRegression()

步骤5:拟合预估器

model.fit(X_train_scaled, y_train)

步骤6:预测

y_pred = model.predict(X_test_scaled)

步骤7:评估

from sklearn.metrics import mean_squared_errormse = mean_squared_error(y_test, y_pred)
print(f'Mean Squared Error: {mse}')

参考:黑马机器学习视频。


http://www.ppmy.cn/news/1001490.html

相关文章

翻转卡片游戏(力扣)

题目 在桌子上有 n 张卡片,每张卡片的正面和背面都写着一个正数(正面与背面上的数有可能不一样)。 我们可以先翻转任意张卡片,然后选择其中一张卡片。 如果选中的那张卡片背面的数字 x 与任意一张卡片的正面的数字都不同&#…

javascript的微任务和宏任务,以及其执行顺序

一、JS是单线程的,但是为什么会有同步异步的存在呢?因为JS中有非常关键的一块,Event Loop。 Event Loop是一个程序结构,用于等待和发送消息和事件。 简单的说,就是在程序中(不一定是浏览器)中跑…

侧边栏的打开与收起

侧边栏的打开与收起 <template><div class"box"><div class"sideBar" :class"showBox ? : controller-box-hide"><div class"showBnt" click"showBox!showBox"><i class"el-icon-arrow-r…

【雕爷学编程】MicroPython动手做(28)——物联网之Yeelight 2

知识点&#xff1a;什么是掌控板&#xff1f; 掌控板是一块普及STEAM创客教育、人工智能教育、机器人编程教育的开源智能硬件。它集成ESP-32高性能双核芯片&#xff0c;支持WiFi和蓝牙双模通信&#xff0c;可作为物联网节点&#xff0c;实现物联网应用。同时掌控板上集成了OLED…

Linux armbian 如何防止暴力破解登陆root

对于 Armbian 或者其他基于 Linux 的系统&#xff0c;以下是一些常用的防止暴力破解登录 root 账户的方法&#xff1a; 禁止 root 用户远程登录&#xff1a;修改 /etc/ssh/sshd_config 文件&#xff0c;找到 PermitRootLogin 一行&#xff0c;将其修改为 PermitRootLogin no&a…

【SQL】-【计算两个varchar类型的timestamp的毫秒差】

背景 TRANSTAMP3、TRANSTAMP2在Oracle数据库中的类型为varchar&#xff0c;但实际保存的值是时间戳timestamp类型&#xff0c;现在要计算二者的毫秒差 Oracle或MySQL extract(second from (to_timestamp(TRANSTAMP3,yyyy-mm-dd hh24:mi:ss.ff) - to_timestamp(TRANSTAMP2,yyy…

opencv-29 Otsu 处理(图像分割)

Otsu 处理 Otsu 处理是一种用于图像分割的方法&#xff0c;旨在自动找到一个阈值&#xff0c;将图像分成两个类别&#xff1a;前景和背景。这种方法最初由日本学者大津展之&#xff08;Nobuyuki Otsu&#xff09;在 1979 年提出 在 Otsu 处理中&#xff0c;我们通过最小化类别内…

PHP从入门到精通—PHP开发入门-PHP概述、PHP开发环境搭建、PHP开发环境搭建、第一个PHP程序、PHP开发流程

每开始学习一门语言&#xff0c;都要了解这门语言和进行开发环境的搭建。同样&#xff0c;学生开始PHP学习之前&#xff0c;首先要了解这门语言的历史、语言优势等内容以及了解开发环境的搭建。 PHP概述 认识PHP PHP最初是由Rasmus Lerdorf于1994年为了维护个人网页而编写的一…