网络模型的保存与读取

embedded/2024/9/23 23:03:28/

文章目录

    • 一、模型的保存
    • 二、文件的加载
    • 三、模型加载时容易犯的陷阱

一、模型的保存

方式1:torch.save(vgg16, “vgg16_method1.pth”)

import torch
import torchvision.modelsvgg16 = torchvision.models.vgg16(pretrained=False)
torch.save(vgg16, "vgg16_method1.pth")

如果运行报错:UserWarning: Arguments other than a weight enum or None for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing weights=None或者The parameter ‘pretrained‘ is deprecated since 0.13 and may be removed in the future

原因是在 PyTorch 的 torchvision 库中,从版本 0.13 开始,pretrained 参数已经被弃用,取而代之的是 weights 参数。这个改变是为了提供更丰富的预训练模型选择。当你尝试使用 vgg16(pretrained=False) 时,你收到了一个警告,告诉你 pretrained 参数已经不再被使用,并且建议你使用 weights 参数。

要解决这个问题,你应该使用 weights 参数来代替 pretrained。

修正代码:

import torch
import torchvision.modelsvgg16 = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)
torch.save(vgg16, "vgg16_method1.pth")

运行代码:
在这里插入图片描述
可以看到多了一个新文件vgg16_method1.pth

该方式1保存的网路模型不仅保存了网络模型的一种结构,它也保存了模型当中的一些参数

方式2:把模型的参数保存成字典(dict)形式

import torch
import torchvision.modelsvgg16 = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)
torch.save(vgg16.state_dict(),"vgg16_method2.pth")

运行结果:
在这里插入图片描述
方式1与方式2对比:

方式1保存的是模型的结构+模型的参数,方式2保存的只是模型的参数(官方推荐的保存方式)

官方推荐的原因是当保存一个大的模型时候,方式2所用的空间更小

我们可以查看一下两种保存方式的文件大小:
在这里插入图片描述
因为vgg这个模型本身就不大,所以文件大小差距并不明显,但方式2足足小了7kb!要是在大模型下节省空间这点会尤其明显。

二、文件的加载

代码:

import torchmodel = torch.load("vgg16_method1.pth")
print(model)

运行结果:
在这里插入图片描述

通过将save与load的文件debug运行:
在这里插入图片描述
能够发现两者都是一样的,说明被完整加载出来。

通过上述步骤可以看到模型中的参数也一同保存下来了。

加载方式2保存的模型:

import torchmodel = torch.load("vgg16_method2.pth")
print(model)

运行结果:
在这里插入图片描述
可以看到方式2形式是一个个字典形式.

方式2从字典形式想要恢复网络模型结构则需要:

import torch
import torchvision# 创建一个VGG16模型实例,参数pretrained=False表示不加载预训练的权重。
vgg16 = torchvision.models.vgg16(pretrained=False)# 加载之前保存的模型权重,这些权重保存在名为"vgg16_method2.pth"的文件中。
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))# 打印出模型的结构,这样可以看到模型的各个层和参数。
print(vgg16)

运行结果:
在这里插入图片描述
可以看到把模型参数成功加载出来了

三、模型加载时容易犯的陷阱

保存一个自己写的网络模型

import torch
import torchvision.models
from torch import nnclass Sen(nn.Module):def __init__(self):super(Sen, self).__init__()self.conv1 = nn.Conv2d(3, 64, kernel_size=3)def forward(self, x):x = self.conv1(x)return xsen = Sen()
torch.save(sen, "sen_method1.pth")

运行结果:
在这里插入图片描述

因为用方式1进行保存,故使用方式1的方法进行加载:

import torch
import torchvisionmodel = torch.load("sen_method1.pth")
print(model)

运行结果:
在这里插入图片描述
可以看到发生了报错,报错的意思是加载的时候没有找到Sen这个类

解决方法是将类复制到加载代码中:

import torch
import torchvisionclass Sen(nn.Module):def __init__(self):super(Sen, self).__init__()self.conv1 = nn.Conv2d(3, 64, kernel_size=3)def forward(self, x):x = self.conv1(x)return xmodel = torch.load("sen_method1.pth")
print(model)

注意不需要写sen = Sen()这一代码

运行代码:
在这里插入图片描述

也就是用自己写的网络模型不同于现有的网络模型,需要进行导入才能正常加载出来!

或者也可以用import的方法加载自己写的网络模型,那么就不需要老是复制粘贴

通过from model_save import *加载:

import torch
from model_save import *model = torch.load("sen_method1.pth")
print(model)

运行结果:
在这里插入图片描述


http://www.ppmy.cn/embedded/115799.html

相关文章

在线文档搜索服务测试报告

目录 1. 项目背景: 2. 项目功能: 3. 测试计划: 1. 项目背景: 1.1 在线搜索服务的前端主要一下几个功能, 分别是进入搜索引擎界面(有提示输入关键词信息); 进行输入关键词的界面, 以及显示有关关键词的文档url, 点击跳转至目标文档的界面; 1.2 该在线搜索服务的文档可以实现用…

指针修仙之实现qsort

文章目录 回调函数什么是回调函数回调函数的作用 库函数qsort使用qsort函数排序整形使用qsort函数排序结构体 qsort函数模拟实现说明源码and说明 回调函数 什么是回调函数 回调函数就是⼀个通过函数指针调⽤的函数。 如果你把函数的指针(地址)作为参数…

如何进入电脑BIOS

前言 在日常使用电脑的过程中,有时我们需要进入BIOS(基本输入输出系统)来调整设置,比如更改启动顺序、调整系统日期时间或是优化硬件配置。BIOS是计算机启动时最先运行的程序之一,它位于主板上的一个ROM芯片中。下面&…

【监控】【Nginx】使用 Prometheus + Grafana 监控 Nginx

目录 一、什么是 Prometheus 和 Grafana?二 、准备工作步骤 1:安装 Prometheus1. 下载并解压 Prometheus2. 编辑 Prometheus 配置(prometheus.yml)3. 启动 Prometheus 步骤 2:安装 Grafana1. 安装 Grafana2. 启动 Graf…

Python ORM 框架 SQLModel 快速入门教程

创建模型 import sqlmodel import typingclass Hero(sqlmodel.SQLModel, tableTrue):id: typing.Optional[int] sqlmodel.Field(defaultNone, primary_keyTrue)name: strreal_name: strage: typing.Optional[int] None创建表 import sqlmodel import typingclass Hero(sqlm…

win11 此应用无法在你的电脑上运行 若要找到适用于你的电脑的版本,请咨询软件发布者

在Windows 11上遇到“此应用无法在你的电脑上运行”的问题,通常意味着该应用程序与Windows 11不兼容,或者你的系统设置阻止了应用程序的运行。以下是一些解决这个问题的步骤: 操作系统不支持 某些应用程序可能尚未更新以支持Windows 11&…

人工智能学习思路(新生新手小白的指引手册-超详细版)

一、前言 该内容仅作为个人笔记使用,希望看到的各位能有所获,博主有误的地方,各位可以在评论区有所指正 二、正文 1、0基础小白(连计算机是啥都不知道) 首先对于计算机这块都没怎么涉猎的新生来说,首先…

Linux入门学习:Git

文章目录 1. 创建仓库2. 仓库克隆3. 上传文件4. 相关问题4.1 git进程阻塞4.2 git log4.3 上传的三个步骤在做什么4.4 配置邮箱/用户名 本文介绍如何在Linux操作系统下简单使用git,对自己的代码进行云端保存。 1. 创建仓库 🔹这里演示gitee的仓库创建。…