Pytorch单、多GPU和CPU训练模型保存和加载

news/2025/1/11 20:31:58/

Pytorch多GPU训练模型保存和加载

在多GPU训练中,模型通常被包装在torch.nn.DataParallel或torch.nn.parallel.DistributedDataParallel中,这会在模型的参数名前加上module前缀。因此,在保存模型时,需要使用model.module.state_dict()来获取模型的状态字典,以确保保存的参数名与模型定义中的参数名一致。(本质上原来的model还是存在的,参数也会同步更新)

  1. 多GPU训练模型保存
    在多GPU训练时,模型通常被包装在torch.nn.DataParallel或torch.nn.parallel.DistributedDataParallel中,这会在模型的参数名前加上module前缀。因此,在保存模型时,需要使用model.module.state_dict()来获取模型的状态字典,以确保保存的参数名与模型定义中的参数名一致。

  2. 单GPU或CPU加载模型
    当在单GPU或CPU上加载模型时,如果直接使用model.state_dict()保存的模型,由于缺少module前缀,会导致参数名不匹配,从而无法正确加载模型。因此,在保存多GPU训练的模型时,应该使用model.module.state_dict()来保存模型的状态字典,这样在单GPU或CPU上加载模型时,可以直接加载,不会出现参数名不匹配的问题。

  3. 示例代码
    以下是一个示例代码,展示了如何在多GPU训练时保存模型,并在单GPU或CPU上加载模型:

import torch
import torch.nn as nn
import os
os.os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"	#设置GPU编号
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 假设这是你的模型定义
class YourModel(nn.Module):def __init__(self):super(YourModel, self).__init__()self.fc = nn.Linear(10, 2)def forward(self, x):return self.fc(x)# 创建模型实例
model = YourModel()# 将模型移动到多GPU上
if torch.cuda.device_count() > 1:model = nn.DataParallel(model)model = model.to(device)
else:model = model.to(device)
······
# 假设这是你的训练代码,训练完成后保存模型
if torch.cuda.device_count() > 1:torch.save(model.module.state_dict(), 'model.pth')
else:torch.save(model.state_dict(), 'model.pth')# 在单、多GPU或CPU上加载模型
model = YourModel()
if torch.cuda.device_count() > 1:model = torch.nn.DataParallel(model)
model.load_state_dict(torch.load('model.pth'))
model = model.to(device)

2 在多GPU训练得到的模型加载时,通常需要考虑以下几个步骤:

  1. 模型保存
    在多GPU训练时,模型通常被包装在torch.nn.DataParallel或torch.nn.parallel.DistributedDataParallel中。因此,在保存模型时,需要确保保存的是模型的state_dict而不是整个模型对象。例如:
if torch.cuda.device_count() > 1:torch.save(model.module.state_dict(), 'model.pth')
else:torch.save(model.state_dict(), 'model.pth')
  1. 模型加载
    在加载模型时,首先需要创建模型的实例,然后使用load_state_dict方法来加载保存的权重。如果模型是在多GPU环境下训练的,那么在加载时也应该使用torch.nn.DataParallel或torch.nn.parallel.DistributedDataParallel来包装模型。例如:
model = YourModel()
if torch.cuda.device_count() > 1:model = torch.nn.DataParallel(model)
model.load_state_dict(torch.load('model.pth'))
model = model.to('cuda')
  1. 注意事项
    在加载模型时,需要注意以下几点:

如果模型是在多GPU环境下训练的,那么在加载时也应该使用相同数量的GPU,或者使用torch.nn.DataParallel来包装模型,即使只有一个GPU可用。
如果模型是在分布式训练环境下训练的,那么在加载时也应该使用torch.nn.parallel.DistributedDataParallel来包装模型。
如果模型是在混合精度训练(如使用了torch.cuda.amp)下训练的,那么在加载模型后,应该恢复之前的精度设置。

3 为了避免模型保存和加载出错

在多GPU训练的模型使用了torch.nn.DataParallel来包装模型,但本质上原来的model是依然存在的,且参数会同步更新:

  1. torch.nn.DataParallel 的工作原理
    torch.nn.DataParallel 是 PyTorch 提供的一个类,用于在多个 GPU 上并行训练模型。它的工作原理如下:
    模型复制:DataParallel 会在每个 GPU 上创建模型的副本。
    数据分发:输入数据会被分发到各个 GPU 上。
    前向传播:每个 GPU 上的模型副本会独立进行前向传播计算。
    梯度收集:所有 GPU 上的梯度会被收集并汇总到主 GPU 上。
    参数更新:主 GPU 上的优化器会根据汇总后的梯度更新模型参数,然后将更新后的参数同步回其他 GPU。
  2. 模型参数更新
    当你使用 model_train = torch.nn.DataParallel(model) 后,model_train 实际上是一个包装了原始模型 model 的对象。虽然 model_train 是多GPU并行的版本,但它的参数更新是通过主 GPU 上的优化器完成的,并且这些更新会同步回原始模型 model
    因此,model 的参数确实会被更新。具体来说:
    前向传播和反向传播:在 train_model 函数中,model_train 用于前向传播和反向传播。
    参数更新:优化器 optimizer 使用的是 model.parameters(),即原始模型的参数。在每次迭代中,优化器会根据汇总后的梯度更新这些参数。
    参数同步:更新后的参数会自动同步到 model_train 中的各个 GPU 副本。
    因此可以使用如下代码,加载模型和保存模型:
import torch
import torch.nn as nn
import os
os.os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"	#设置GPU编号
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 假设这是你的模型定义
class YourModel(nn.Module):def __init__(self):super(YourModel, self).__init__()self.fc = nn.Linear(10, 2)def forward(self, x):return self.fc(x)# 创建模型实例
model = YourModel()# 将模型移动到多GPU上,单GPU依然适用
if torch.cuda.device_count() > 1:model_train = nn.DataParallel(model)model_train = model_train.to(device)
else:model_train = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)#注意这是model的参数
······
output = model_train(input)	# 多卡时训练的输入和输出,注意这是model_train# 假设这是你的训练代码,训练完成后保存模型
torch.save(model.state_dict(), 'model.pth')	#注意这是model
  • 再在单/多GPU或CPU上加载模型,都不会报错,因为这里的model不是包装体,不带module
model = YourModel()
if torch.cuda.device_count() > 1:model = torch.nn.DataParallel(model)
model.load_state_dict(torch.load('model.pth',map_location = device))
model = model.to(device)

http://www.ppmy.cn/news/1561887.html

相关文章

vulnhub靶场【DC系列】之6

前言 靶机:DC-6,IP地址为192.168.10.10 攻击:kali,IP地址为192.168.10.2 都采用VMWare,网卡为桥接模式 对于文章中涉及到的靶场以及工具,我放置在网盘中,链接:https://pan.quark…

碰一碰发视频的剪辑功能开发的细节源码搭建,支持OEM

在短视频盛行的今天,为碰一碰发视频增添剪辑功能,能极大提升用户创作的灵活性与趣味性。下面将详细阐述这一功能从技术选型到源码搭建的全过程。 一、技术选型 前端 框架:选择 React 作为前端框架,其基于组件化的开发模式&#x…

借助免费GIS工具箱轻松实现las点云格式到3dtiles格式的转换

在当今数字化浪潮下,地理信息系统(GIS)技术日新月异,广泛渗透到城市规划、地质勘探、文化遗产保护等诸多领域。而 GISBox 作为一款功能强大且易用的 GIS 工具箱,以轻量级、免费使用、操作便捷等诸多优势,为…

2024信息安全网络安全等安全意识(附培训PPT下载)

信息安全和网络安全是现代社会中至关重要的领域,它们涉及保护数据、系统和网络免受未经授权的访问、破坏和滥用。以下是一些关键的安全意识和概念: 信息安全意识 数据保护:意识到个人和组织数据的敏感性和价值,采取措施保护数据…

【MySQL实战】Centos安装MySQL

在CentOS上安装MySQL以及进行性能分析:2种方式,第一种直接装;第二种用docker安装: 直接安装MySQL 首先,更新系统软件包列表: sudo yum update然后,安装MySQL服务器: sudo yum in…

小程序学习08—— 系统参数获取和navBar组件样式动态设置

一 系统信息的概念 uni-app提供了异步(uni.getSystemInfo)和同步(uni.getSystemInfoSync)的2个API获取系统信息。 success 返回参数说明: 参数分类说明statusBarHeight手机状态栏的高度system操作系统名称及版本。。。 二 自定义navbar 2.1 获取系统参数 代码展示…

机器学习:从基础到前沿

引言 在当今这个数据爆炸的时代,机器学习已经成为了一项至关重要的技术。它赋予了计算机从数据中学习和做出决策的能力,从而在各行各业中发挥着越来越重要的作用。从医疗诊断到自动驾驶,从金融风险评估到个性化推荐系统,机器学习…

云原生架构:构建高效、可扩展的微服务系统

摘要 随着云计算技术的快速发展,云原生架构(Cloud Native)已经成为构建现代应用程序的主流趋势。云原生架构强调以容器、微服务、DevOps和持续集成/持续部署(CI/CD)为核心,以提高系统的可扩展性、弹性和灵活性。本文将探讨云原生架构的核心概念,并提供一个基于微服务的…