Chapter5.4 Loading and saving model weights in PyTorch

embedded/2025/1/18 5:07:01/

5 Pretraining on Unlabeled Data

5.4 Loading and saving model weights in PyTorch

  • 训练LLM的计算成本很高,因此能够保存和加载LLM的权重至关重要。

  • 在PyTorch中,推荐的方式是通过将torch.save函数应用于.state_dict()方法来保存模型权重,即所谓的state_dict

    python">torch.save(model.state_dict(),"model.pth")
    

    我们可以将模型权重加载到新的 GPTModel 模型实例中

    python">model = GPTModel(GPT_CONFIG_124M)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.load_state_dict(torch.load("model.pth", map_location=device, weights_only=True))
    model.eval();
    
  • 自适应优化器(如 AdamW)为每个模型权重存储额外的参数。AdamW 使用历史数据动态调整每个模型参数的学习率。如果没有这些参数,优化器会重置,模型可能会学习效果不佳,甚至无法正确收敛,这意味着模型将失去生成连贯文本的能力。使用 torch.save,我们可以保存模型和优化器的 state_dict 内容,如下所示

    python">torch.save({"model_state_dict": model.state_dict(),"optimizer_state_dict": optimizer.state_dict(),}, "model_and_optimizer.pth"
    )
    

    然后,我们可以通过以下方式恢复模型和优化器状态:首先通过 torch.load 加载保存的数据,然后使用 load_state_dict 方法:

    python">checkpoint = torch.load("model_and_optimizer.pth", weights_only=True)model = GPTModel(GPT_CONFIG_124M)
    model.load_state_dict(checkpoint["model_state_dict"])optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005, weight_decay=0.1)
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    model.train();
    


http://www.ppmy.cn/embedded/154860.html

相关文章

Redis监控系统:基于Redis Exporter的性能指标可视化

Redis监控系统:基于Redis Exporter的性能指标可视化 一、什么是Redis监控系统 监控系统是用于监控Redis数据库运行状态和性能指标的系统工具。通过对Redis数据库的监控可以及时发现问题,分析性能瓶颈,优化系统运行效率。 二、Redis Exporter简介 是一个用…

内存分区模型

栈区(stack) 函数调用。编译器自动分配释放,存放 函数的参数值和局部变量等 堆区(heap) 动态分配的变量。由程序员动态分配和释放,使用new 和 delete 全局/静态存储区(Data Segment & …

SpringBoot3+Vue3开发台球计时系统

项目介绍 台球计时系统可以帮助我们自动计算开台时间(从开始到结束的时间段)、自动计算开台费用、结账后生成订单记录进行留存、也可以导出订单记录。 主要功能包含:球桌管理、开台、结账、查看占用明细、查看球台订单、订单管理、查看订单…

SpringBoot 基于 Redisson 分布式锁实现

1. 分布式锁 1.1 为什么要使用分布式锁 以下是使用分布式锁的一些主要原因: 保持数据一致性: 在分布式系统中,数据一致性是至关重要的。使用分布式锁可以防止并发更新导致的数据不一致问题,确保数据在所有节点之间保持一致避免…

【开源分享】nlohmann C++ JSON解析库

文章目录 1. Nlohmann JSON 库介绍2. 编译和使用2.1 获取库2.2 包含头文件2.3 使用示例2.4 编译 3. 优势4. 缺点5. 总结参考 1. Nlohmann JSON 库介绍 Nlohmann JSON 是一个用于 C 的现代 JSON 库,由 Niels Lohmann 开发。它以易用性和高性能著称,支持 …

Canvas简历编辑器-选中绘制与拖拽多选交互方案

Canvas简历编辑器-选中绘制与拖拽多选交互方案 在之前我们聊了聊如何基于Canvas与基本事件组合实现了轻量级DOM,并且在此基础上实现了如何进行管理事件以及多层级渲染的能力设计。那么此时我们就依然在轻量级DOM的基础上,关注于实现选中绘制与拖拽多选交…

基于Java的百度AOI数据解析与转换的实现方法

目录 前言 一、AOI数据结构简介 1、官网的实例接口 2、响应参数介绍 二、Java对AOI数据的解析 1、数据解析流程图 2、数据解析实现 3、AOI数据解析成果 三、总结 前言 在当今信息化社会,地理信息数据在城市规划、交通管理、商业选址等领域扮演着越来越重要的…

采用海豚调度器+Doris开发数仓保姆级教程(满满是踩坑干货细节,持续更新)

目录 一、采用海豚调度器+Doris开发平替CDH Hdfs + Yarn + Hive + Oozie的理由。 1. 架构复杂性 2. 数据处理性能 3. 数据同步与更新 4. 资源利用率与成本 6. 生态系统与兼容性 7. 符合信创或国产化要求 二、ODS层接入数据 接入kafka实时数据 踩坑的问题细节 三、海…