公共检查点(checkpoints)+探针(Probe)详解

server/2024/11/16 22:02:17/

一、概念介绍

        “公共检查点”(checkpoints)是指在模型训练过程中保存的模型参数和状态。这些检查点通常在模型训练完成后或者在特定的训练阶段被保存下来,以便后续可以重新加载模型并继续训练或者用于模型评估。

        其包括(1)模型权重:模型的所有参数,包括权重和偏置;(2)优化器状态:优化器的状态,包括动量、学习率等;(3)训练状态:当前的训练轮数(epoch)、批次(batch)编号等;(4)其他元数据:如学习率调度器的状态、自定义指标等。

二、具体应用

        《Probing the 3D Awareness of Visual Foundation Models》这篇论文通过一系列实验,使用特定任务的探针和零样本推理程序来分析这些模型的3D感知能力,并发现当前模型存在一些限制,其中对于公共检查点和探针起到了详细的应用。

        这篇论文的相关分析如下:

        《Probing the 3D Awareness of Visual Foundation Models》论文解析——单图像表面重建

        《Probing the 3D Awareness of Visual Foundation Models》论文解析——多视图一致性

        在论文 "Probing the 3D Awareness of Visual Foundation Models" 中,获取公共检查点的过程遵循了以下步骤

  1. 选择模型:研究者首先确定了一组他们想要评估的视觉基础模型,这些模型包括了不同监督信号下训练的模型,例如分类、语言监督、自监督等。
  2. 公开可用性:研究者选择了那些公开可用的检查点,这些检查点是模型训练过程中保存的参数和状态的快照。这些检查点通常由模型的原始训练者发布,可以在模型的官方代码库、研究论文的补充材料或者专门的模型共享平台上找到。
  3. 比较架构和规模:为了公平比较,研究者尝试选择在架构和训练规模上可比的检查点。这意味着他们会选择相似的模型大小和训练数据量的检查点,以便评估时控制变量。
  4. 下载和使用:研究者从公开的资源下载这些检查点,并在他们的实验设置中使用这些预训练的模型。这些检查点允许研究者评估模型的冻结特征,而不需要自己从头开始训练模型。
  5. 实验分析:研究者通过特定的任务特定的探针(probes)或零样本(zero-shot)推断方法来评估这些冻结特征的模型,从而分析模型的3D意识。

三、为什么需要Checkpoint?

        在机器学习深度学习中,checkpoints(检查点)是一种重要的技术,用于保存训练过程中的关键信息。以下是详细介绍为什么需要checkpoints的几个关键原因:

1. 防止训练过程中断

        训练深度学习模型通常需要大量的时间和计算资源。在训练过程中,可能会遇到各种意外情况,如硬件故障、电力中断或程序错误等,这些都可能导致训练任务意外中断。使用checkpoints可以在训练过程中定期保存模型的状态,这样即使训练被中断,也可以从最近的checkpoint恢复,而不是从头开始,从而节省时间和资源。

2. 早停(Early Stopping)

        在训练过程中,我们通常希望模型在验证集上的性能达到最佳。通过设置checkpoints,我们可以在每个epoch后评估模型的性能,并在性能不再提升时停止训练,这称为早停。这样做可以防止模型过拟合,并节省不必要的计算资源。

3. 模型选择和超参数调整并且避免过拟合

        在训练多个模型或尝试不同的超参数时,checkpoints允许我们保存每个模型的中间状态。这样,我们可以比较不同模型的性能,并选择最佳的模型进行进一步的训练或部署,而无需重新训练。深度学习模型在训练过程中可能会过拟合到训练数据。通过在训练过程中保存checkpoints,我们可以在不同的训练阶段恢复模型,选择在验证集上表现最佳的模型,从而减少过拟合的风险。

4. 实验复现与模型版本控制

在科学研究和实验中,复现他人的结果是非常重要的。使用checkpoints可以确保实验的可复现性,因为它们保存了模型的权重和优化器状态,使得其他研究者可以加载相同的checkpoints并复现实验结果。在实际应用中,可能需要部署多个版本的模型以支持A/B测试或逐步推出新模型。Checkpoints可以帮助管理和部署这些不同版本的模型。

5. 资源管理与模型微调

        深度学习训练通常需要大量的计算资源。在多任务环境中,checkpoints允许系统在资源紧张时暂停训练任务,并在资源可用时恢复训练,从而更有效地管理计算资源。在迁移学习或微调预训练模型时,checkpoints可以保存预训练的模型权重。这样,在微调过程中,我们可以从预训练的权重开始,而不是从头开始训练,这可以显著加快训练速度并提高模型性能。

四、Checkpoint模型在Stable Diffusion中的应用

        在Stable Diffusion中,Checkpoint模型被广泛应用。由于Stable Diffusion是一种生成模型,需要训练大量的数据和时间,因此,使用Checkpoint模型可以有效地避免训练过程中的意外导致的资源浪费。同时,Checkpoint模型还可以帮助我们更好地管理训练过程,我们可以根据需要选择加载不同的检查点,从而实现不同的训练效果。

        在实践中,使用Checkpoint模型需要注意以下几点:

  1. 选择合适的检查点保存频率:检查点保存频率过低可能会导致训练过程中的意外导致大量的资源浪费,而保存频率过高则会占用大量的存储空间。因此,我们需要根据实际情况选择合适的保存频率。
  2. 管理好检查点文件:在训练过程中,会产生大量的检查点文件,这些文件需要妥善管理。我们可以使用版本控制工具(如Git)来管理这些文件,以便在需要时能够方便地找到和加载正确的检查点文件。
  3. 选择合适的恢复策略:当需要恢复训练时,我们需要选择合适的恢复策略。一般来说,我们可以选择加载最近的检查点文件,也可以选择加载在某个特定性能点上的检查点文件,这需要根据实际情况来决定。

五、示例代码分析

        在深度学习中,使用checkpoints通常意味着在训练过程中保存模型的权重和优化器的状态,以便在需要时可以恢复训练。以下是使用TensorFlow和PyTorch这两个流行的深度学习框架的checkpoints示例代码。

        1.TensorFlow Checkpoints 示例

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import ModelCheckpoint# 构建一个简单的模型
model = Sequential([Dense(64, activation='relu', input_shape=(100,)),Dense(10, activation='softmax')
])model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# 设置ModelCheckpoint回调函数
checkpoint_cb = ModelCheckpoint('best_model.h5', save_best_only=True)# 训练模型并保存checkpoints
model.fit(x_train, y_train, epochs=10, validation_data=(x_val, y_val), callbacks=[checkpoint_cb])# 加载最佳模型
model.load_weights('best_model.h5')

         在这个例子中,ModelCheckpoint 回调函数会在每个epoch结束后保存模型。save_best_only=True 表示只保存在验证集上性能最好的模型。

        2.PyTorch Checkpoints 示例

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms# 定义一个简单的网络
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(100, 50)self.fc2 = nn.Linear(50, 10)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 实例化网络、损失函数和优化器
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)# 加载数据
train_loader = DataLoader(datasets.MNIST('', train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=32, shuffle=True)# 训练网络
for epoch in range(2):  # 这里的2代表epoch的数量for i, data in enumerate(train_loader):inputs, labels = dataoptimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 保存checkpoints
torch.save({'epoch': epoch,'model_state_dict': net.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,
}, 'checkpoint.pth')# 加载checkpoints
checkpoint = torch.load('checkpoint.pth')
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        在这个PyTorch的例子中,我们定义了一个简单的神经网络,并在训练过程中保存了模型的状态字典和优化器的状态字典。这样,我们可以在以后从这个状态恢复训练。


http://www.ppmy.cn/server/142475.html

相关文章

ESLint 使用教程(七):ESLint还能校验JSON文件内容?

系列文章 ESLint 使用教程(一):从零配置 ESLint ESLint 使用教程(二):一步步教你编写 Eslint 自定义规则 ESLint 使用教程(三):12个ESLint 配置项功能与使用方式详解 ES…

Android CALL按键同步切换通话界面上免提和听筒的图标显示

按一下call按键,进行切换图标,分别显示为免提和听筒模式! /frameworks/base/services/core/java/com/android/server/policy/PhoneWindowManager.java case KeyEvent.KEYCODE_CALL: { //*/ add custom key. if("com.freeme.factory.in…

git常用命令+搭vscode使用

1.克隆远程代码 git clone http:xxx git clone ssh:xxx clone的url 中 https和 ssh是有区别的: git中SSH和HTTP连接有什么区别-CSDN博客 当然https拉下来的代码每次pull /push都需要验证一次自己的账户和密码,可以config进行配置不用每次手敲: 解决VScode中每次git pu…

MFC中Picture Control控件显示照片的几种方式

目前使用CImage和CBitmap两个类,还有是将CImage转CBitmap显示。 MFC界面拖拽一个button按钮和一个Picture Control控件。 1.CImage显示。这种方式显示图片会有颜色不对的情况 void Cpicture_test_controlDlg::OnBnClickedButton1() {// TODO: 在此添加控件通知处…

新增支持Elasticsearch数据源,支持自定义在线地图风格,DataEase开源BI工具v2.10.2 LTS发布

2024年11月11日,人人可用的开源BI工具DataEase正式发布v2.10.2 LTS版本。 这一版本的功能变动包括:数据源方面,新增了对Elasticsearch数据源的支持;图表方面,对地图类和表格类图表进行了功能增强和优化,增…

centos7安装Chrome使用selenium-wire

背景:在centos7中运行selenium-wire爬虫,系统自带的Firefox浏览器不兼容,运行报错no attribute ‘set_preference’,应该是selenium-wire和Firefox的驱动不兼容 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上…

报错 No available slot found for the embedding model

报错内容 Server error: 503 - [address0.0.0.0:12781, pid304366] No available slot found for the embedding model. We recommend to launch the embedding model first, and then launch the LLM models. 目前GPU占用情况如下 解决办法: 关闭大模型, 先把 embedding mode…

微信小程序之路由跳转传数据及接收

跳转并传id或者对象 1.home/index.wxml <!--点击goto方法 将spu_id传过去--> <view class"item" bind:tap"goto" data-id"{{item.spu_id}}"> 结果: 2.home/index.js goto(event){// 路由跳转页面,并把id传传过去//获取商品idlet i…