使用亚马逊针对 PyTorch 和 MinIO 的 S3 连接器进行模型检查点处理

embedded/2025/2/12 7:47:00/

2023 年 11 月,Amazon 宣布推出适用于 PyTorch 的 S3 连接器。适用于 PyTorch 的 Amazon S3 连接器提供了专为 S3 对象存储构建的 PyTorch 数据集基元(数据集和数据加载器)的实现。它支持用于随机数据访问模式的地图样式数据集和用于流式处理顺序数据访问模式的可迭代样式数据集。适用于 PyTorch 的 S3 连接器还包括一个检查点接口,用于将检查点直接保存和加载到 S3 存储桶,而无需先保存到本地存储。如果您还没有准备好采用正式的 MLOps 工具,而只需要一种简单的方法来保存模型,那么这是一个非常好的选择。这就是我将在这篇文章中介绍的内容。S3 连接器的文档仅展示了如何将其与 Amazon S3 一起使用 - 我将在此处向您展示如何将其用于 MinIO。让我们先执行此作 - 让我们设置 S3 连接器,以便它从 MinIO 写入和读取检查点。

将 S3 连接器连接到 MinIO

将 S3 连接器连接到 MinIO 就像设置环境变量一样简单。之后,一切都会顺利进行。诀窍是以正确的方式设置正确的环境变量。

本文的代码下载使用 .env 文件来设置环境变量,如下所示。此文件还显示了我用于使用 MinIO Python SDK 直接连接到 MinIO 的环境变量。请注意,AWS_ENDPOINT_URL 需要 protocol,而 MinIO 变量不需要。

AWS_ACCESS_KEY_ID=admin
AWS_ENDPOINT_URL=http://172.31.128.1:9000
AWS_REGION=us-east-1
AWS_SECRET_ACCESS_KEY=password
MINIO_ENDPOINT=172.31.128.1:9000
MINIO_ACCESS_KEY=admin
MINIO_SECRET_KEY=password
MINIO_SECURE=false

写入和读取 Checkpoint

我从一个简单的例子开始。下面的代码段创建了一个 S3Checkpointing 对象,并使用其 writer() 方法将模型的状态字典发送到 MinIO。我还使用 Torchvision 创建了一个 ResNet-18(18 层)模型,用于演示目的。

import osfrom dotenv import load_dotenv
from s3torchconnector import S3Checkpoint
import torchvision
import torch# Load the credentials and connection information.
load_dotenv()model = torchvision.models.resnet18()
model_name = 'resnet18.pth'
bucket_name = 'checkpoints'checkpoint_uri = f's3://{bucket_name}/{model_name}'
s3_checkpoint = S3Checkpoint(os.environ['AWS_REGION'])# Save checkpoint to S3
with s3_checkpoint.writer(checkpoint_uri) as writer:torch.save(model.state_dict(), writer)

请注意,该区域有一个强制参数。从技术上讲,访问 MinIO 时没有必要,但如果为此变量选择错误的值,内部检查可能会失败。此外,您的存储桶必须存在,上述代码才能正常工作。如果 writer() 方法不存在,它将引发错误。不幸的是,无论出了什么问题,writer() 方法都会引发相同的错误。例如,如果您的存储桶不存在,您将收到如下所示的错误。如果 writer() 方法不喜欢您指定的区域,您也会收到相同的错误。希望未来的版本将提供更具描述性的错误消息。

S3Exception: Client error: Request canceled

将以前保存的模型读取到内存中的代码类似于写入 MinIO。使用 reader() 方法,而不是 writer() 方法。下面的代码显示了如何执行此作。

import osfrom dotenv import load_dotenv
from s3torchconnector import S3Checkpoint
import torchvision
import torch# Load the credentials and connection information.
load_dotenv()model_name = 'resnet18.pth'
bucket_name = 'checkpoints'checkpoint_uri = f's3://{bucket_name}/{model_name}'
s3_checkpoint = S3Checkpoint(os.environ['AWS_REGION'])# Load checkpoint from S3
with s3_checkpoint.reader(checkpoint_uri) as reader:state_dict = torch.load(reader, weights_only=True)model.load_state_dict(state_dict)

接下来,让我们看看模型训练期间检查点的一些实际注意事项。

在模型训练期间编写检查点

如果您使用大型数据集训练大型模型,请考虑在每个 epoch 后设置检查点。这些训练运行可能需要数小时甚至数天才能完成,因此在发生故障时能够从上次中断的地方继续非常重要。此外,我们假设您必须使用共享存储桶来保存来自多个团队的多个模型的模型检查点。MLOps 约定是按试验组织训练运行。例如,如果您正在研究具有四个隐藏层的架构,那么在寻找各种超参数的最佳值时,您将使用此架构进行多次运行。如果同事使用五层体系结构运行实验,则需要一种方法来防止名称冲突。这可以通过模拟如下所示的层次结构的对象路径来解决。

最后,为了确保您在每个 epoch 中获得新版本的模型,请确保在用于保存检查点的存储桶上启用版本控制。下面的训练函数使用上述路径结构在每个 epoch 后对模型进行检查点作。(可以在本文的代码下载中找到此训练函数的更强大版本。

def train_model(model: nn.Module, loader: DataLoader, training_parameters: Dict[str, Any]) -> List[float]:if training_parameters['checkpoint']:checkpoint_uri = f's3://{training_parameters["checkpoint_bucket"]} \/{training_parameters["project_name"]} \/{training_parameters["experiment_name"]} \/{training_parameters["run_id"]} \/{training_parameters["model_name"]}'s3_checkpoint = S3Checkpoint(region=os.environ['AWS_REGION'])loss_func = nn.NLLLoss()optimizer = optim.SGD(model.parameters(), lr=training_parameters['lr'], momentum=training_parameters['momentum'])# Epoch loopcompute_time_by_epoch = []for epoch in range(training_parameters['epochs']):# Batch loopfor images, labels in loader:# Flatten MNIST images into a 784 long vector.# shape = [32, 784]images = images.view(images.shape[0], -1)# Training passoptimizer.zero_grad()output = model(images)loss = loss_func(output, labels)loss.backward()optimizer.step()# Save checkpoint to S3if training_parameters['checkpoint']:with s3_checkpoint.writer(checkpoint_uri) as writer:torch.save(model.state_dict(), writer)

请注意,模型名称不包含指示纪元的子字符串。如前所述,我使用了启用了版本控制的存储桶 - 换句话说,版本号表示纪元。这种方法的优点在于,您无需知道引用最新模型的 epoch 数。在上述训练代码运行了 10 个 epoch 后,我的检查点存储桶如下面的屏幕截图所示。

此培训演示可被视为 DIY MLOps 解决方案的开始。

结论

适用于 PyTorch 的 S3 连接器易于使用,工程师在使用时编写的数据访问代码行数会更少。在本文中,我展示了如何将其配置为使用环境变量连接到 MinIO。配置完成后,工程师可以分别使用 writer() 和 reader() 方法将检查点写入和读取 MinIO。在本文中,我展示了如何配置 S3 Connect 以连接到 MinIO。我还演示了 S3Checkpoint 类及其 reader() 和 writer() 方法的基本用法。最后,我展示了一种在实际训练函数中针对启用了版本的检查点存储桶使用这些检查点功能的方法。在这篇文章中,我没有介绍在分布式训练期间检查点所需的技术和工具,这可能有点棘手。分布式训练期间的检查点设置会有所不同,具体取决于您使用的框架(PyTorch、Ray 或 DeepSpeed 等)和您正在进行的分布式训练类型:数据并行(每个工作程序都有模型的完整副本)或模型并行(每个工作程序只有一个模型分片)。在以后的文章中,我将介绍其中的一些技术。


http://www.ppmy.cn/embedded/161542.html

相关文章

香港中文大学 Adobe 推出 MotionCanvas:开启用户掌控的电影级图像视频创意之旅。

简介: 亮点直击 将电影镜头设计引入图像到视频的合成过程中。 推出了MotionCanvas,这是一种简化的视频合成系统,用于电影镜头设计,提供整体运动控制,以场景感知的方式联合操控相机和对象的运动。 设计了专门的运动条…

安卓开发,底部导航栏

1、创建导航栏图标 使用系统自带的矢量图库文件,鼠标右键点击res->New->Vector Asset 修改 Name , Clip art 和 Color 再创建一个 同样的方法再创建四个按钮 2、添加百分比布局依赖 app\build.gradle.kts 中添加百分比布局依赖,并点击Sync Now …

如何解决ChatGPT API响应慢的问题

随着人工智能技术的快速发展,OpenAI的ChatGPT API已广泛应用于多种场景中,从客户服务到内容创作,甚至在教育、娱乐等领域也有着重要的应用。然而,很多开发者和使用者会遇到一个共同的问题——ChatGPT API响应速度较慢,…

kafka在初始化集群配置当中有哪些重要参数?

在初始化 Kafka 集群配置时,有一些重要的参数需要正确设置,以确保集群的性能、可靠性和可用性。这些参数分为不同的类别,包括 broker 配置、topic 配置和消费者配置。以下是一些关键参数及其作用: Broker 配置 broker.id 描述&a…

Unity使用iTextSharp导出PDF-02基础结构及设置中文字体

基础结构 1.创建一个Document对象 2.使用PdfWriter创建PDF文档 3.打开文档 4.添加内容,调用文档Add方法添加内容时,内容写入到输出流中 5.关闭文档 using UnityEngine; using iTextSharp.text; using System.IO; using iTextSharp.text.pdf; using Sys…

Unity3D实现显示模型线框(shader)

系列文章目录 unity工具 文章目录 系列文章目录👉前言👉一、效果展示👉二、第一种方式👉二、第二种方式👉壁纸分享👉总结👉前言 在 Unity 中显示物体线框主要基于图形渲染管线和特定的渲染模式。 要显示物体的线框,通常有两种常见的方法:一种是利用内置的渲染…

未来替代手机的产品,而非手机的本身

替代手机的产品包括以下几种: 可穿戴设备:智能手表、智能眼镜等可穿戴设备可以提供类似手机的功能,如通话、信息推送、浏览网页等。 虚拟现实(VR)技术:通过佩戴VR头显,用户可以进行语音通话、发…

maven web项目如何定义filter

在 Maven Web 项目中定义一个 Servlet 过滤器(Filter),需要遵循 Java Servlet 规范,并利用 Maven 来管理项目结构和依赖。下面是如何在 Maven Web 项目中定义和配置一个过滤器的基本步骤: 1. 创建过滤器类 首先&…