训练集,验证集,测试集的作用

embedded/2024/9/23 14:27:34/

训练集 (Training Set), 验证集 (Validation Set) 和测试集 (Test Set) 是机器学习和深度学习模型开发过程中不可或缺的部分。它们的主要作用和区别如下:

  1. 训练集

    • 作用:用于训练模型,调整模型的参数(如神经网络的权重)。
    • 示例:如果你在训练一个猫狗分类器,训练集中包含大量标记为“猫”或“狗”的图片。模型通过这些数据学习如何区分猫和狗。
  2. 验证集

    • 作用:用于调参和选择最佳模型。通过验证集,我们可以评估模型在未见过的数据上的表现,防止过拟合。
    • 示例:在训练猫狗分类器时,验证集中的数据也标记为“猫”或“狗”,但这些数据不用于训练,而是用于在训练过程中评估模型性能。
  3. 测试集

    • 作用:用于评估最终模型的性能。测试集的结果代表了模型在实际应用中的表现。
    • 示例:在猫狗分类器中,测试集包含的图片同样标记为“猫”或“狗”,但这些数据既不用于训练,也不用于调参,而是用于最终评估模型。

为什么要分为这三个集?假设我们不分开数据集,将所有数据用于训练,那么模型可能会记住训练数据,而无法泛化到新数据(即过拟合)。验证集和测试集的引入能够帮助我们检测这种情况,并选择或调整模型以提高其泛化能力。

下面是一个使用PyTorch实现简单神经网络并进行训练、验证和测试的示例代码,每行都有详细注释:

python">import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, TensorDataset# 假设我们有一些数据
data = torch.randn(1000, 20)  # 1000个样本,每个样本20个特征
labels = torch.randint(0, 2, (1000,))  # 二分类任务,标签为0或1# 创建一个TensorDataset
dataset = TensorDataset(data, labels)# 将数据集划分为训练集、验证集和测试集
train_size = int(0.7 * len(dataset))  # 70%的数据用于训练
val_size = int(0.15 * len(dataset))  # 15%的数据用于验证
test_size = len(dataset) - train_size - val_size  # 剩余的15%用于测试train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])# 创建DataLoader以便于批量训练和评估
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)# 定义一个简单的神经网络
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(20, 64)self.fc2 = nn.Linear(64, 32)self.fc3 = nn.Linear(32, 1)def forward(self, x):x = torch.relu(self.fc1(x))x = torch.relu(self.fc2(x))x = torch.sigmoid(self.fc3(x))return x# 实例化神经网络,定义损失函数和优化器
model = SimpleNN()
criterion = nn.BCELoss()  # 二分类任务使用的损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
num_epochs = 10
for epoch in range(num_epochs):model.train()  # 设置模型为训练模式for batch_data, batch_labels in train_loader:outputs = model(batch_data).squeeze()  # 前向传播loss = criterion(outputs, batch_labels.float())  # 计算损失optimizer.zero_grad()  # 清空梯度loss.backward()  # 反向传播optimizer.step()  # 更新参数# 在验证集上评估模型model.eval()  # 设置模型为评估模式val_loss = 0.0with torch.no_grad():  # 禁用梯度计算for batch_data, batch_labels in val_loader:outputs = model(batch_data).squeeze()  # 前向传播loss = criterion(outputs, batch_labels.float())  # 计算损失val_loss += loss.item()  # 累加损失print(f'Epoch {epoch+1}/{num_epochs}, Training Loss: {loss.item():.4f}, Validation Loss: {val_loss/len(val_loader):.4f}')# 在测试集上最终评估模型
model.eval()  # 设置模型为评估模式
test_loss = 0.0
with torch.no_grad():  # 禁用梯度计算for batch_data, batch_labels in test_loader:outputs = model(batch_data).squeeze()  # 前向传播loss = criterion(outputs, batch_labels.float())  # 计算损失test_loss += loss.item()  # 累加损失print(f'Test Loss: {test_loss/len(test_loader):.4f}')

这个示例展示了如何在PyTorch中划分数据集并训练、验证和测试一个简单的神经网络模型。通过这种方式,我们可以确保模型在不同的数据集上有良好的表现,从而提高模型的泛化能力。


http://www.ppmy.cn/embedded/58936.html

相关文章

[计算机网络] VPN技术

VPN技术 1. 概述 虚拟专用网络(VPN)技术利用互联网服务提供商(ISP)和网络服务提供商(NSP)的网络基础设备,在公用网络中建立专用的数据通信通道。VPN的主要优点包括节约成本和提供安全保障。 优…

用SmartSql从数据库表中导出文档

在 SmartSql 中从数据库表中导出文档通常意味着将表结构和数据导出为文档格式,比如 Word、PDF、HTML 或者 Markdown。这通常涉及到以下步骤: 连接到数据库: 打开 SmartSql 客户端,并确保已成功连接到你的目标数据库。你需要提供正…

3D Web开发新篇章:threelab探索之旅

3D Web开发新篇章:threelab探索之旅 随着数字技术的飞速发展,三维图形技术已经渗透到我们生活的每一个角落,从在线游戏到数字艺术,再到虚拟现实体验。今天,我们将探索一个全新的学习平台——threelab.cn,它…

逻辑回归不是回归吗?那为什么叫回归?

RNN 逻辑回归不是回归吗?那为什么叫回归?逻辑回归的基本原理逻辑函数(Sigmoid函数)二元分类 为什么叫做“回归”?逻辑回归的应用场景总结 逻辑回归不是回归吗?那为什么叫回归? 逻辑回归&#x…

提交表单form之后发送表单内容到指定邮箱(单php文件实现)

提交各种表单之后,自动将表单的内容通过邮件api接口的形式自动发送到指定的邮箱。步骤如下: 1.在aoksend注册一个账号。 2.绑定一个自己的域名。做域名解析之后验证。验证通过后自动提交审核。等待审核通过。 3.设置一个邮件模板。aoksend内置了一些优…

Java并发编程之多线程实现方法

Java实现多线程的方式有比较多,但究其本质,最终都是在执行Thread的run方法,这个后文再作解释。下面先看看各种实现方式。 实现 Runnable 接口 public class RunnableThread implements Runnable{Overridepublic void run() {System.out.pr…

Django ModelForm用法详解 —— Python

Django ModelForm是一种自动生成表单的工具,它是以模型为基础,在模型类上定义的表单。在使用Django ModelForm时,我们只需要指定模型类作为表单数据的基础,就可以自动地生成表单。下面是Django ModelForm用法的完整攻略。 创建Mo…

【Flask从入门到精通:第二课:flask加载项目配置的二种方式、路由的基本定义和终端运行】

flask加载项目配置的二种方式 # 1. 导入flask核心类 from flask import Flask# 2. 初始化web应用程序的实例对象 app Flask(__name__)"""第一种:flask项目加载站点配置的方式""" # app.config["配置项"] 配置项值 # app…