PyTorch中保存模型的两种方式

news/2024/11/29 2:35:08/

文章目录

  • 一、状态字典(State Dictionary)
  • 二、序列化模型(Serialized Model)
  • 三、示例代码


一、状态字典(State Dictionary)

这种保存形式将模型的参数保存为一个字典,其中包含了所有模型的权重和偏置等参数。状态字典保存了模型在训练过程中学到的参数值,而不包含模型的结构。可以使用这个字典来加载模型的参数,并将其应用于相同结构的模型。
在 PyTorch 中,您可以使用 torch.save() 函数将模型的状态字典保存到文件中,例如:

torch.save(model.state_dict(), 'model.pth')

然后,可以使用 torch.load() 函数加载状态字典并将其应用于相同结构的模型:

model = MyModel()  # 创建模型对象
model.load_state_dict(torch.load('model.pth'))

这种保存形式非常适用于仅保存和加载模型的参数,而不需要保存和加载模型的结构。

二、序列化模型(Serialized Model)

这种保存形式将整个模型(包括模型的结构、参数等)保存为一个文件。序列化模型保存了模型的完整信息,可以完全恢复模型的状态,包括模型的结构、权重、偏置以及其他相关参数。
在 PyTorch 中,您可以使用 torch.save() 函数直接保存整个模型对象,例如:

torch.save(model, 'model.pth')

然后,您可以使用 torch.load() 函数加载整个序列化模型:

model = torch.load('model.pth')

这种保存形式适用于需要保存和加载完整模型信息的情况,包括模型的结构和参数。

三、示例代码

import torchclass LinearNet(torch.nn.Module):def __init__(self, input_size, output_size):super().__init__()self.net = torch.nn.Sequential(torch.nn.Linear(in_features=input_size, out_features= 5, bias=True),torch.nn.Sigmoid(),torch.nn.Linear(in_features= 5, out_features=5, bias=True),torch.nn.Sigmoid(),torch.nn.Linear(in_features=5, out_features=output_size, bias=True))def forward(self,x):return self.net(x)square_net = LinearNet(1,1)# square_net.load_state_dict(torch.load('weight.pth'))  #直接加载已经训练好的权重if __name__ == '__main__':# print(square_net(torch.tensor([3.16],dtype=torch.float32)))# save 方式1torch.save(square_net.state_dict(), "./w1.pth")my_state_dict = torch.load("./w1.pth")print("纯state_dict:\n", my_state_dict)print("type:", type(my_state_dict))# save 方式2torch.save(square_net, "./w2.pth")my_state_dict = torch.load("./w2.pth")print("\n\n模型结构:\n", my_state_dict)print("type:", type(my_state_dict))# 执行结果'''纯state_dict:OrderedDict([('net.0.weight', tensor([[ 0.0820],[-0.6923],[ 0.5066],[-0.8931],[ 0.0460]])), ('net.0.bias', tensor([ 0.1455,  0.5106,  0.2347,  0.4903, -0.6838])), ('net.2.weight', tensor([[-0.4055, -0.2721,  0.3770, -0.2285,  0.3025],[-0.0416,  0.0133, -0.3834, -0.2151,  0.1454],[ 0.0749, -0.3664, -0.1901, -0.2829,  0.3957],[-0.3567,  0.2668,  0.3343, -0.3351, -0.3808],[ 0.4375,  0.1000,  0.1185,  0.2295, -0.3997]])), ('net.2.bias', tensor([-0.2405, -0.2751,  0.1928,  0.3970, -0.0005])), ('net.4.weight', tensor([[-0.4388, -0.2654,  0.3038,  0.2008,  0.0381]])), ('net.4.bias', tensor([0.1847]))])模型结构:LinearNet((net): Sequential((0): Linear(in_features=1, out_features=5, bias=True)(1): Sigmoid()(2): Linear(in_features=5, out_features=5, bias=True)(3): Sigmoid()(4): Linear(in_features=5, out_features=1, bias=True)))'''

http://www.ppmy.cn/news/1360276.html

相关文章

“软件定义汽车”时代下的软件供应链安全

如今汽车产业智能化、网联化、电动化、共享化的“新四化”程度逐渐深入,“软件定义汽车”也被反复提及。以硬件主导的传统汽车演变为以软件主导、软硬解耦的新汽车。“中国亟需构建智能网联汽车安全可信的软件生态。”中国工程院院士沈昌祥此前曾表示,没…

有趣且重要的JS知识合集(19)前端实现图片的本地上传/截取/导出

input[file]太丑了,又不想去改button样式,那就自己实现一个上传按钮的div,然后点击此按钮时,去触发file上传的事件, 以下就是 原生js实现图片前端上传 并且按照最佳宽高比例展示图片,然后可以自定义截取图片&#xff0…

链表之“无头单向非循环链表”

目录 ​编辑 1.顺序表的问题及思考 2.链表 2.1链表的概念及结构 2.2无头单向非循环链表的实现 1.创建结构体 2.单链表打印 3.动态申请一个节点 3.单链表尾插 4.单链表头插 5.单链表尾删 6.单链表头删 7.单链表查找 8.单链表在pos位置之前插入x 9.单链表删除pos位…

一文了解LM317T的引脚介绍、参数解读

LM317T是一种线性稳压器件,它具有稳定输出电压的特性。LM317T可以通过调整其输出电阻来确保输出电压的稳定性,因此被广泛应用于各种电子设备中。 LM317T引脚图介绍 LM317T共有3个引脚,分别是: 输入引脚(输入电压V_in&…

文本编辑器markdown语法

markdown语法 1.介绍 Markdown是一种使用一定的语法将普通的文本转换成HTML标签文本的编辑语言,它的特点是可以使用普通的文本编辑器来编写,只需要按照特定的语法标记就可以得到丰富多样的HTML格式的文本。 2.标题分级 "# " -> 一级标题 &…

[概念区分] 正则表达式与正则化

正则表达式与正则化 机器学习在计算机科学和数据处理领域,关于“正则”的两个术语:正则表达式和正则化,虽然它们在名称上非常相似,但实际上它们是完全不同的概念。 正则表达式 也被称为 regex,是一种强大的工具&…

RRT算法学习及MATLAB演示

文章目录 1 前言2 算法简介3 MATLAB实现3.1 定义地图3.2 绘制地图3.3 定义参数3.4 绘制起点和终点3.5 RRT算法3.5.1 代码3.5.2 效果3.5.3 代码解读 4 参考5 完整代码 1 前言 RRT(Rapid Random Tree)算法,即快速随机树算法,是LaVa…

基于qt的图书管理系统----03核心界面设计

参考b站:视频连接 源码github:github 目录 1 添加软件图标2 打包程序3 三个管理界面设计4 代码编写4.1 加载界面4.2 点击按钮切换界面4.3 组团添加样式4.4 搭建表头4.5 表格相关操作 从别人那里下载的项目会有这个文件,里边是别人配置的路径…