Pytorch学习--神经网络--网络模型的保存与读取

devtools/2024/11/15 4:38:41/

一、网络模型的保存与读取方式1

方法讲解

在这里插入图片描述
在这里插入图片描述

保存模型

python">import torch
import torchvision
model = torchvision.models.vgg16(weights='DEFAULT')
#保存模型和参数
torch.save(model,"save_method1.pth")

读取模型

python">import torch
model = torch.load("save_method1.pth")
print(model)

输出:在这里插入图片描述

比较坑人的点

使用 torch.save 必须将该模型的架构引入到该文件中(可以使用from A import B的方式来解决),这里举一个例子来说明

保存模型

python">import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear#保存模型和参数class Mary(nn.Module):def __init__(self):super(Mary,self).__init__()self.model1 = nn.Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self,x):x = self.model1(x)return x
Yorelee = Mary()
torch.save(Yorelee,"save_method1_question.pth")

读取模型

python">import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linearmodel = torch.load("save_method1_question.pth")print(model)

报错如下

在这里插入图片描述
说明我们还要把 Mary 这个框架复制到读取模型的.py文件中

重新更正后的读取模型代码

python">import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linearclass Mary(nn.Module):def __init__(self):super(Mary,self).__init__()self.model1 = nn.Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self,x):x = self.model1(x)return xmodel = torch.load("save_method1_question.pth")print(model)
或者
python">import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear
from torch_save import Mary   #这里仅举一个例子model = torch.load("save_method1_question.pth")print(model)

二、网络模型的保存与读取方式2

保存模型参数

python">import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linearvgg_model = torchvision.models.vgg16(weights='DEFAULT')
#保存参数
torch.save(vgg_model.state_dict(),"save_method2.pth")

读取模型参数

python">import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linearvgg_model = torchvision.models.vgg16(weights='DEFAULT')
parameter = torch.load("save_method2.pth")
vgg_model.load_state_dict(parameter)
print(vgg_model)

http://www.ppmy.cn/devtools/134079.html

相关文章

单片机中的BootLoader(重要的概念讲解)

文章目录 一、链接地址和执行地址1. 链接地址(Load Address)2. 执行地址(Execution Address)链接地址与执行地址的关系实际工作流程总结二、相对跳转和绝对跳转1. 相对跳转(Relative Jump)2. 绝对跳转(Absolute Jump)3. `BX` 和 `BL` 指令总结三、散列文件1. 散列文件的…

响应式网页设计--html

一&#xff0c;HTML 文档的基本结构 一个典型的 HTML 文档包含了几个主要部分&#xff0c;基本结构如下(本文以下出现的所有代码都可以套入下面示例进行测试)&#xff1a; <!DOCTYPE html> <html lang"zh"> <head><meta charset"UTF-8&q…

Java-异步方法@Async+自定义分布式锁注解Redission例子

如果你在使用 @Async 注解的异步方法中,使用了自定义的分布式锁注解(例如 @DistributedLock),并且锁到期后第二个请求并没有执行,这可能是由于以下几个原因导致的: 锁的超时时间设置不当:锁的超时时间可能设置得太短,导致锁在业务逻辑执行完成之前就已经自 动释放。…

c++写一个死锁并且自己解锁

刷算法题&#xff1a; 第一遍&#xff1a;1.看5分钟&#xff0c;没思路看题解 2.通过题解改进自己的解法&#xff0c;并且要写每行的注释以及自己的思路。 3.思考自己做到了题解的哪一步&#xff0c;下次怎么才能做对(总结方法) 4.整理到自己的自媒体平台。 5.再刷重复的类…

PostgreSQL pg-xact(clog)目录文件缺失处理

一、 背景 前些天晚上突然收到业务反馈&#xff0c;查询DB中的一个表报错 Could not open file "pg-xact/005E": No such file or directory. 两眼一黑难道是文件损坏了...登录查看DB日志&#xff0c;还好没有其他报错&#xff0c;业务也反馈只有这一个表在从库查询报…

基于STM32的智能充电桩:集成RTOS、MQTT与SQLite的先进管理系统设计思路

一、项目概述 随着电动车的普及&#xff0c;充电桩作为关键基础设施&#xff0c;其智能化、网络化管理显得尤为重要。本项目旨在基于STM32微控制器开发一款智能充电桩&#xff0c;能够实现高效的充电监控与管理。项目通过物联网技术&#xff0c;提供实时数据监测、远程管理、用…

UE5入门教程:基础操作

UE5&#xff08;虚幻引擎5&#xff09;的基础操作涵盖了多个方面&#xff0c;包括视角操作、对象操作、窗口操作、材质编辑操作等。以下是对这些基础操作的详细介绍&#xff1a; 一、视角操作 移动视角&#xff1a;按住鼠标右键可通过WASD来移动视角位置&#xff0c;E为垂直向…

Redis集群模式之Redis Sentinel vs. Redis Cluster

在分布式系统环境中&#xff0c;Redis以其高性能、低延迟和丰富的数据结构而广受青睐。随着数据量的增长和访问需求的增加&#xff0c;单一Redis实例往往难以满足高可用性和扩展性的要求。为此&#xff0c;Redis提供了两种主要的集群模式&#xff1a;Redis Sentinel和Redis Clu…