PyTorch构建模型网络结构的6种方式

news/2024/9/18 15:03:07/ 标签: pytorch, 人工智能, python

PyTorch提供了多种方式来构建模型的网络结构,我尝试总结一下,有如下6种常见方式(可能还有我没注意到的,欢迎补充)。我们平时写代码并不一定需要掌握全部方式,但是多了解一些,对于阅读理解别人的代码显然是有帮助的。

1,继承nn.Module类

这是构建自定义模型最基础也是最常见的方法。通过继承torch.nn.Module类,并在子类中定义__init__方法来初始化模型的各个层,以及在forward方法中定义数据的前向传播路径。

python">import torch  
import torch.nn as nn  
import torch.nn.functional as F  # 继承nn.Module  
class SimpleCNN(nn.Module):  def __init__(self, num_classes=10):  super(SimpleCNN, self).__init__()  # 定义卷积层  self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1)  self.relu1 = nn.ReLU()  self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  self.conv2 = nn.Conv2d(32, 64, 3, padding=1)  self.relu2 = nn.ReLU()  self.pool2 = nn.MaxPool2d(2, 2)  # 定义全连接层  self.fc1 = nn.Linear(64 * 8 * 8, 128)  self.relu3 = nn.ReLU()  self.fc2 = nn.Linear(128, num_classes)  def forward(self, x):  x = self.conv1(x)  x = self.relu1(x)  x = self.pool1(x)  x = self.conv2(x)  x = self.relu2(x)  x = self.pool2(x)  # 展平特征图  x = x.view(-1, 64 * 8 * 8)  # 全连接层  x = self.fc1(x)  x = self.relu3(x)  x = self.fc2(x)  return x  # 实例化模型并打印结构  
model = SimpleCNN(num_classes=10)  
print(model)  # 假设有一个输入张量,测试模型  
input_tensor = torch.randn(1, 3, 32, 32)  
output = model(input_tensor)  
print(output.shape)  # 应该是[1, 10],表示10个类别的输出

优点:

1)高度灵活:允许用户定义任意复杂的前向传播逻辑,并可以轻松地插入自定义的操作或层。

2)功能强大:通过继承nn.Module,用户可以充分利用PyTorch提供的各种功能,如参数管理、模型保存/加载、GPU加速等。

缺点:

1)在模型层数多结构复杂时,只使用nn.Module类来编写会显得凌乱,后期难以维护

2)相比nn.Sequential,代码量更多,需要定义一个类,并且手动编写前向传播部分

2,使用nn.Sequential

对于顺序连接的层,可以使用nn.Sequential来简化模型的构建。nn.Sequential接受一个层列表作为输入,并自动定义前向传播。

python">import torch  
import torch.nn as nn  
import torch.nn.functional as F  # 定义一个简单的CNN模型,使用nn.Sequential  
class SimpleCNN(nn.Sequential):  def __init__(self, num_classes=10):  super(SimpleCNN, self).__init__()  # 添加卷积层  self.add_module('conv1', nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1))  self.add_module('relu1', nn.ReLU())  self.add_module('pool1', nn.MaxPool2d(kernel_size=2, stride=2))  self.add_module('conv2', nn.Conv2d(32, 64, 3, padding=1))  self.add_module('relu2', nn.ReLU())  self.add_module('pool2', nn.MaxPool2d(2, 2))  # 添加全连接层,注意需要先flatten特征图  self.add_module('flatten', nn.Flatten())  self.add_module('fc1', nn.Linear(64 * 8 * 8, 128))  # 修正了这里的输入维度  self.add_module('relu3', nn.ReLU())  self.add_module('fc2', nn.Linear(128, num_classes))  # 实例化模型并打印结构  
model = SimpleCNN(num_classes=10)  
print(model)  # 假设有一个输入张量  
input_tensor = torch.randn(1, 3, 32, 32)  
output = model(input_tensor)  
print(output.shape)  # 应该是[1, 10],表示10个类别的输出

注:其中add.module是nn.Sequential中的一个方法用于向 nn.Sequential 容器中添加一个模块(即一个层或一个子网络)。当你创建一个 nn.Sequential 实例时,你可以通过调用 self.add_module 方法来逐个添加你想要的层。这个方法接受两个参数:name和module,module是一个 nn.Module 的实例,表示要添加的层或子网络。add.module不仅可以用在初始化方法中,还可以动态添加网络结构。

优点:

1)简单直观,代码量少,易于维护

缺点:

1)灵活性不足:对于需要复杂前向传播逻辑或者非线性层次结构(如跳跃结构或者分支结构)的模型,只用nn.Sequential不方便

2)调试不便:因为所有层都被封装在Sequential中

3)自定义操作受限:不方便插入自定义的逻辑和操作

3,结合使用nn.Module和nn.Sequential

使用nn.Module构建网络,但其中每个block都用nn.Sequential构建的方式,实际上结合了nn.Sequential和nn.Module两者的特点

python">import torch  
import torch.nn as nn  
import torch.nn.functional as F  # 定义一个简单的CNN模型,结合使用nn.Module和nn.Sequential  
class SimpleCNN(nn.Module):  def __init__(self, num_classes=10):  super(SimpleCNN, self).__init__()  # 使用nn.Sequential定义卷积层部分  self.features = nn.Sequential(  nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1),  nn.ReLU(),  nn.MaxPool2d(kernel_size=2, stride=2),  nn.Conv2d(32, 64, 3, padding=1),  nn.ReLU(),  nn.MaxPool2d(2, 2)  )  # 定义全连接层部分  self.fc1 = nn.Linear(64 * 8 * 8, 128)  self.relu3 = nn.ReLU()  self.fc2 = nn.Linear(128, num_classes)  def forward(self, x):  # 通过卷积层部分  x = self.features(x)  # 展平特征图  x = x.view(-1, 64 * 8 * 8)  # 通过全连接层部分  x = self.fc1(x)  x = self.relu3(x)  x = self.fc2(x)  return x  # 实例化模型并打印结构  
model = SimpleCNN(num_classes=10)  
print(model)  # 假设有一个输入张量  
input_tensor = torch.randn(1, 3, 32, 32)  
output = model(input_tensor)  
print(output.shape)  # 应该是[1, 10],表示10个类别的输出

对于案例里这种过于简单的模型,这种混合方式的优势还不明显。但是对于复杂模型,这种组合的实现方式相比起前两种更常见。

4,使用nn.ModuleDict

nn.ModuleDict 是 PyTorch 中的一个类,它继承自 nn.Module,用于存储模块(modules)的字典。与普通的 Python 字典不同,nn.ModuleDict 中的模块会被自动注册为参数,这样它们就可以被识别为模型的一部分,并且在调用 .parameters() 或 .to(device) 等方法时,这些模块中的参数也会被包含在内。

nn.ModuleDict 的用法很简单,你可以像使用普通字典一样使用它,但是键(key)必须是字符串,值(value)必须是 nn.Module 的实例。

python">import torch  
import torch.nn as nn  
import torch.nn.functional as F  
from collections import OrderedDict  # 定义一个简单的CNN模型,使用nn.OrderedDict来组织层  
class SimpleCNN(nn.Module):  def __init__(self, num_classes=10):  super(SimpleCNN, self).__init__()  # 使用nn.OrderedDict定义所有层#不需要注册  self.layers = nn.ModuleDict({  'conv1': nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1),  'relu1': nn.ReLU(),  'pool1': nn.MaxPool2d(kernel_size=2, stride=2),  'conv2': nn.Conv2d(32, 64, 3, padding=1),  'relu2': nn.ReLU(),  'pool2': nn.MaxPool2d(2, 2),  'flatten': nn.Identity(),  # 使用nn.Identity作为占位符,实际展平操作在forward中实现  'fc1': nn.Linear(64 * 8 * 8, 128),  'relu3': nn.ReLU(),  'fc2': nn.Linear(128, num_classes)  })  def forward(self, x):  # 顺序通过所有层  for name,layer in self.layers.items():  x = layer(x)  if name == 'pool2':  # 在池化后检查是否需要展平  x = x.view(-1, 64 * 8 * 8)  # 展平操作print(x.shape)return x  # 实例化模型并打印结构  
model = SimpleCNN(num_classes=10)  
print(model)  # 假设有一个输入张量  
input_tensor = torch.randn(1, 3, 32, 32)  
output = model(input_tensor)  
print(output.shape)  # 应该是[1, 10],表示10个类别的输出

优点:

1)会自动注册其中的模块,但是这一点也可能成为缺点,主要看需求。如果调用的是pytoch自带的层,用自动注册更省事。

缺点:

1)会自动注册其中的模块同样可能成为一个确定,如果需要自己编写层,就需要手动注册。

2)出错后定位问题位置和调试相对更困难:前三种实现方式里都有层次化的结构,而用nn.ModuleDict来实现会缺乏这样的结构信息。另外nn.ModuleDict允许动态修改和删除模块,这会增加出错几率和调试难度。

5,使用nn.OrderedDict

当需要对层进行命名以便后续访问时,还可以使用collections.OrderedDict

python">import torch  
import torch.nn as nn  
import torch.nn.functional as F  class SimpleCNN(nn.Module):  def __init__(self, num_classes=10):  super(SimpleCNN, self).__init__()  # 使用nn.OrderedDict定义所有层的顺序  self.layers = nn.OrderedDict([  ('conv1', nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1)),  ('relu1', nn.ReLU()),  ('pool1', nn.MaxPool2d(kernel_size=2, stride=2)),  ('conv2', nn.Conv2d(32, 64, 3, padding=1)),  ('relu2', nn.ReLU()),  ('pool2', nn.MaxPool2d(2, 2)),  ('flatten', nn.Flatten()),  # 使用nn.Flatten进行展平操作  ('fc1', nn.Linear(64 * 8 * 8, 128)),  ('relu3', nn.ReLU()),  ('fc2', nn.Linear(128, num_classes))  ])  # 将layers中的模块注册到当前Module  for name, module in self.layers.items():  self.add_module(name, module)  def forward(self, x):  # 顺序通过所有层  for name, layer in self.layers.items():  x = layer(x)  if name == 'flatten':  # 在展平层后打印形状  print(x.shape)  return x  # 实例化模型并打印结构  
model = SimpleCNN()  
print(model)  # 假设有一个输入张量  
input_tensor = torch.randn(1, 3, 32, 32)  
output = model(input_tensor)  
print(output.shape)  # 应该是[1, 10],表示10个类别的输出

在PyTorch中,即使不用nn.OrderedDict也可以给每一层命名。但使用OrderedDict来管理层的顺序和注册仍然有其优势,特别是在构建复杂模型或需要动态修改模型结构时。

另外,我们可以注意到nn.OrderedDict和nn.ModuleDict都是用字典的来存储模块,那么他们的区别是什么呢?

最主要的区别有两点:1)nn.OrderedDict是严格保持元素顺序的,而nn.Module在原本的实现里是不保持元素顺序的,但是在python3.7及之后的版本里已经改为保持元素的插入顺序。

2)nn.OrderedDict其实不是pytoch中特有的类,而是python中的类,所以它是不会自动注册模块的。需要手动注册。

优缺点:

相比nn.ModuleDict,nn.OrderedDict不会自动注册模块,这个特点视情况可能成为优点或缺点。

6,调用预训练好的模块拼装成网络模型

这种方式在深度学习领域非常常见,特别是在迁移学习和微调(fine-tuning)的场景中。

示例代码如下,我们取预训练的resnet18前两个block,和我们自己实现的分类头组合成一个全新的模型。

python">import torch  
import torch.nn as nn  
from torchvision import models  class MyPartialResNet18(nn.Module):  def __init__(self, num_classes=10):  super(MyPartialResNet18, self).__init__()  # 加载预训练的resnet18模型  self.resnet18 = models.resnet18(pretrained=True)  # 冻结整个resnet18的参数  for param in self.resnet18.parameters():  param.requires_grad = False  # 只保留前两个block  self.features = nn.Sequential(  self.resnet18.conv1,  self.resnet18.bn1,  self.resnet18.relu,  self.resnet18.maxpool,  self.resnet18.layer1, self.resnet18.layer2  )self.gmp=  nn.AdaptiveMaxPool2d((1, 1))self.flatten = nn.Flatten()# 自定义一个全连接层  self.fc = nn.Linear(self.resnet18.layer2[-1].conv2.out_channels, num_classes)  def forward(self, x):  x = self.features(x)  x = self.gmp(x)x = self.flatten(x)x = self.fc(x)  return x  # 创建模型实例  
model = MyPartialResNet18(num_classes=10)  # 打印模型结构  
print(model)  # 假设你有一个输入tensor x  
x = torch.randn(1, 3, 224, 224)  
# 输出模型的预测  
output = model(x)  
print(output.shape)  # 应该是[1, 10],表示10个类别的预测

优点:可以快速实现,可以利用现有的预训练参数

缺点:缺乏灵活性

总之以上方法都有各自优缺点,各位按照自己的实际需求选择。


http://www.ppmy.cn/news/1516753.html

相关文章

服务器主动推送的方法

目录 1.长轮询(Long Polling)2.WebSockets3.Server-Sent Events(SSE)4.HTTP2 Server Push 服务器如何主动推送数据 在传统的网络通信中,客户端(如浏览器)通常需要通过向服务器发起请求来获取数据…

快速判断一个项目是Spring MVC框架还是Spring Boot框架

1. 查看项目的启动类 Spring Boot: 通常有一个主类,包含 SpringBootApplication 注解,并且有一个 main 方法来启动应用程序。 SpringBootApplication public class Application {public static void main(String[] args) {SpringApplication.run(Appli…

作业训练三编程题13. 导弹防御系统

【问题描述】 某国为了防御敌国的导弹袭击,开发出一种导弹拦截系统。但是这种导弹拦截系统有一个缺陷:虽然它的第一发炮弹能够到达任意的高度,但是以后每一发炮弹都不能高于前一发的高度。某天,雷达捕捉到敌国的导弹来袭&#xf…

python脚本请求数量达到上限,http请求重试问题例子解析

在使用Python的requests库进行HTTP请求时,可能会遇到请求数量达到上限,导致Max retries exceeded with URL的错误。这通常发生在网络连接不稳定、服务器限制请求次数、或请求参数设置错误的情况下。以下是一些解决该问题的策略: 增加重试次数…

并发服务器开发基础

一、服务器模型 1. 单循环服务器: 单循环服务器在同一时刻只能处理一个客户端的请求。由于其结构简单,适合低负载的场景,但在并发请求增加时可能导致性能问题。 2. 并发服务器模型: 并发服务器可以同时响应多个客户端…

本地环境注入jupyter:无法在jupyter选择已经创建的conda环境?快来看下解决办法(jupyter notebook选择已创建环境)

1、WinR打开本机cmd命令行 2、运行 conda activate 本地已创建的环境名称 3、运行 conda install ipykernel 4、运行 python -m ipykernel install --user --name 本地环境名称 --display-name "在jupyter上显示的环境名称" 就可以在jupyter notebook中看到环…

谷粒商城实战笔记-250-商城业务-消息队列-RabbitMQ安装-Docker

一,docker安装RabbitMq RabbitMQ 是一个开源的消息代理软件,广泛用于实现异步通信和应用程序解耦。 使用 Docker 容器化技术可以简化 RabbitMQ 的安装和部署过程。 以下是使用 Docker 安装 RabbitMQ 的详细步骤。 步骤 1: 安装 Docker 如果您的系统…

如何解决:Failed to start jenkins.service: Unit not found.

当在 CentOS 上尝试启动 Jenkins 服务时,出现 Failed to start jenkins.service: Unit not found 的错误,这通常表示 Jenkins 服务未安装或未正确配置。请按照以下步骤进行排查和解决: 解决步骤 检查 Jenkins 是否已安装: 确认 J…

如何使用ssm实现旅游网站的设计与实现

TOC ssm150旅游网站的设计与实现jsp 绪论 1.1 研究背景 当前社会各行业领域竞争压力非常大,随着当前时代的信息化,科学化发展,让社会各行业领域都争相使用新的信息技术,对行业内的各种相关数据进行科学化,规范化管…

ComfyUI 常用的节点

总的来说,如果可以的话 最佳实践是直接访问每个节点仓库,仔细阅读作者提供的文档和说明。然后,手动执行 git clone 来获取仓库的代码。 接着,你可以通过手动执行 pip install -r requirements.txt 来安装每个项目的依赖。这种方法…

【Linux】第十八章 Reactor模式

文章目录 Reactor模式epoll ET服务器(Reactor模式)设计思路Epoller.hppSock.hppProtocol.hppService.hppTcpServer.hpp-重点Connection类TcpServer类服务器框架TcpServer构造AddConnection函数SetNonBlock 函数Accepter函数IsExists函数TcpRecver函数Tcp…

[oeasy]python031_[趣味拓展]unix起源_Ken_Tompson_Ritchie_multics

[趣味拓展]unix起源_Ken_Tompson_Ritchie_multics 🥋 回忆上次内容 上次 动态设置了 断点 断点 可以把代码 切成一段一段的可以 更快地调试 调试的目的 是 去除 bug 别害怕 bug 一步步 总能找到 bug这 就是 程序员基本功 调试 debug 在bug出现的时候 甚至…

docker-harbor私有仓库部署和管理

harbor:开源的企业级的docker仓库软件 仓库:私有仓库 公有仓库 (公司内部一般都是私有仓库) habor 是有图形化的,页面UI展示的一个工具,操作起来很直观。 harbor每个组件都是由容器构建的,所…

高效的数据恢复软件介绍给大家!

数据丢失可太烦人了,在工作中我们经常会遇到数据丢失的情况,那么你知道数据丢失怎么找回来吗?当然找的回来啦!需要用上高效且有用的数据恢复工具。那么,今天就要给大家介绍两个好用的数据恢复工具,可以将您…

5个常见问答 | 1+X证书《大数据应用开发(Python)》

1、 1X大数据应用开发(Python)哪些人群可以考? 全日制在读的中高职学校、应用型本科、本科层次职业教育试点学校院校的学生,有意向从事与证书相关岗位的社会人士都可考取该证书。 2、1X大数据应用开发(Python&am…

网络udp及ipc内存共享

大字符串找小字符串 调试 1. 信号处理函数注册:•一旦使用 signal 函数注册了信号处理函数,该函数就会一直有效,直到程序结束或者显式地取消注册。2. 注册多次的影响:•如果多次注册同一信号的处理函数,最后一次注册的…

【手写数据库内核组件】0303 数据缓存池(二) 缓存块使用前需要固定,缓存加载与无效,无锁的替换算法

0303 数据缓存池(二) ​专栏内容: postgresql使用入门基础手写数据库toadb并发编程个人主页:我的主页 管理社区:开源数据库 座右铭:天行健,君子以自强不息;地势坤,君子以厚德载物. 文章目录 0303 数据缓存池(二)一、概述 二、缓存块操作原理 2.1 缓存块的读写访问 2.2 无…

C学习(数据结构)-->实现链式结构二叉树

目录 一、链式二叉树结构 二、实现 1、申请新结点 2、前、中、后序遍历 1)前序遍历 例: 2)中序遍历 3)后序遍历 3、结点个数 1)二叉树结点个数 例:​编辑 2)二叉树叶子结点个数 3&…

网络排名变差算法在充电桩计量可信度评价中的应用AcrelCloud-9000安科瑞充电柱收费运营云平台

摘要:网络排名变差算法是指根据充电交易流水数据构造桩车网络,利用复杂网络的投票智慧而非传统的物理实验来获得对量值的信心。将排名变差算法用于桩车网络计算中,旨在检定合格的充电桩对其他充电桩排名变化的影响,这种影响以电动…

计算机毕业设计选题推荐-OA办公管理系统-Java/Python项目实战

✨作者主页:IT毕设梦工厂✨ 个人简介:曾从事计算机专业培训教学,擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Py…