pytorch小记(十三):pytorch中`nn.ModuleList` 详解

ops/2025/3/20 2:46:57/

pytorch小记(十三):pytorch中`nn.ModuleList` 详解

  • PyTorch 中的 `nn.ModuleList` 详解
    • 1. 什么是 `nn.ModuleList`?
    • 2. 为什么不直接使用普通的 Python 列表?
    • 3. `nn.ModuleList` 的基本用法
      • 示例:构建一个包含两层全连接网络的模型
    • 4. 使用 `nn.ModuleList` 计算参数总数(与普通列表对比)
      • 示例代码
    • 5. `nn.ModuleList` 的其他应用
      • 示例:构建动态 MLP 模型
      • Transformers中的多头注意力机制
    • 6. 总结


PyTorch 中的 nn.ModuleList 详解

在构建深度学习模型时,经常需要管理多个网络层(例如多个 nn.Linearnn.Conv2d 等)。在 PyTorch 中,nn.ModuleList 是一个非常有用的容器,可以帮助我们存储多个子模块,并自动注册它们的参数。这对于确保所有参数能够参与训练非常重要。本文将详细介绍 nn.ModuleList 的作用、使用方法及与普通 Python 列表的区别,并给出清晰的代码示例。


1. 什么是 nn.ModuleList

nn.ModuleList 是一个类似于 Python 列表的容器,但专门用来存储 PyTorch 的子模块(也就是继承自 nn.Module 的对象)。其主要特点是:

  • 自动注册子模块:将 nn.Module 存储在 ModuleList 中后,这些模块的参数会自动被添加到父模块的参数列表中。这意味着当你调用 model.parameters() 时,这些子模块的参数也会被包含进去,从而参与梯度计算和优化。

  • 灵活管理:它可以像普通列表一样进行索引、迭代和切片操作,方便构建动态网络结构。

注意nn.ModuleList 不会像 nn.Sequential 那样自动定义前向传播(forward)流程。你需要在模型的 forward() 方法中手动遍历 ModuleList 并调用各个子模块。


2. 为什么不直接使用普通的 Python 列表?

虽然可以将 nn.Module 对象存储在普通列表中,但这样做有一个主要问题:
普通列表中的模块不会自动注册为父模块的子模块
这会导致:

  • 调用 model.parameters() 时无法获取这些模块的参数;
  • 优化器无法更新这些参数,从而影响模型训练。

而使用 nn.ModuleList 可以避免这个问题,因为它会自动将内部所有的模块注册到父模块中。


3. nn.ModuleList 的基本用法

下面通过一个简单的示例来说明如何使用 nn.ModuleList 构建一个简单的神经网络模型。

示例:构建一个包含两层全连接网络的模型

python">import torch
import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()# 创建一个 ModuleList 来存储各层self.layers = nn.ModuleList([nn.Linear(10, 20),  # 第 1 层:输入 10 个特征,输出 20 个特征nn.ReLU(),          # 激活层nn.Linear(20, 5)    # 第 2 层:输入 20 个特征,输出 5 个特征])def forward(self, x):# 手动遍历 ModuleList 中的每个模块,并依次调用 forwardfor layer in self.layers:x = layer(x)return x# 创建模型实例
model = MyModel()# 打印模型结构
print("模型结构:")
print(model)# 生成一组示例输入
input_tensor = torch.randn(3, 10)  # 3 个样本,每个样本 10 个特征# 得到模型输出
output = model(input_tensor)
print("\n模型输出:")
print(output)
模型结构:
MyModel((layers): ModuleList((0): Linear(in_features=10, out_features=20, bias=True)(1): ReLU()(2): Linear(in_features=20, out_features=5, bias=True))
)模型输出:
tensor([[ 0.3741,  0.0883,  0.3550, -0.3930,  0.5173],[ 0.2171, -0.0978, -0.0585, -0.4568,  0.3331],[ 0.1232, -0.1491,  0.2026, -0.0978,  0.5478]],grad_fn=<AddmmBackward0>)

说明

  • __init__() 方法中,我们将各个层放在了 nn.ModuleList 中。
  • forward() 方法中,我们使用了一个简单的 for 循环,依次调用 self.layers 中的每个子模块。

4. 使用 nn.ModuleList 计算参数总数(与普通列表对比)

为了进一步说明 nn.ModuleList 与普通列表的区别,我们分别计算一下两种方式下模型的参数总数。

示例代码

python">import torch.nn as nn# 使用 ModuleList 存储模型层
layers_ml = nn.ModuleList([nn.Linear(10, 20),nn.Linear(20, 5)
])# 计算 ModuleList 中的参数总数
ml_params = 0
for p in layers_ml.parameters():ml_params += p.numel()# 使用普通 Python 列表存储模型层
layers_list = [nn.Linear(10, 20),nn.Linear(20, 5)
]# 计算普通列表中的参数总数
list_params = 0
# 先遍历列表中的每个层
for layer in layers_list:# 再遍历每个层的参数for p in layer.parameters():list_params += p.numel()print("ModuleList 参数总数:", ml_params)
print("普通列表参数总数:", list_params)
ModuleList 参数总数: 325
普通列表参数总数: 325

说明

  • 第一个 for 循环遍历 layers_ml.parameters(),直接累加所有参数的元素数。
  • 第二部分中,我们先遍历普通列表中的每个 layer,再单独遍历每个层的参数。这样做使每一步都清晰易懂。

5. nn.ModuleList 的其他应用

示例:构建动态 MLP 模型

当网络结构比较复杂或层数不固定时,可以利用列表生成器动态构建 ModuleList

python">class DynamicMLP(nn.Module):def __init__(self, layer_sizes):super(DynamicMLP, self).__init__()# 使用 for 循环构造每一层,存储在 ModuleList 中layers = []  # 先用普通列表保存层for i in range(len(layer_sizes) - 1):linear_layer = nn.Linear(layer_sizes[i], layer_sizes[i + 1])layers.append(linear_layer)# 将普通列表转换为 ModuleListself.layers = nn.ModuleList(layers)def forward(self, x):# 遍历每一层(没有嵌套循环,逐个执行)for layer in self.layers:x = torch.relu(layer(x))return x# 创建一个动态 MLP:输入 10,隐藏层 20, 30,输出 5
dynamic_model = DynamicMLP([10, 20, 30, 5])
print("动态 MLP 模型:")
print(dynamic_model)# 测试模型
input_tensor = torch.randn(4, 10)  # 4 个样本,每个样本 10 个特征
output = dynamic_model(input_tensor)
print("\n动态 MLP 模型输出:")
print(output)

说明

  • __init__() 方法中,我们使用一个普通列表 layers 存储每个 nn.Linear 层,然后再将它转换为 nn.ModuleList
  • forward() 方法中,使用单独的 for 循环逐个调用每一层,并对输出应用 ReLU 激活函数。
  • 这种写法适用于层数动态变化的网络(例如 MLP、RNN、Transformer 中部分模块)。

Transformers中的多头注意力机制

python">class SingleHeadAttention(nn.Module):def __init__(self, embed_dim, head_dim):super().__init__()self.query = nn.Linear(embed_dim, head_dim)self.key = nn.Linear(embed_dim, head_dim)self.value = nn.Linear(embed_dim, head_dim)def forward(self, x):# 实现注意力计算逻辑...return attended_valuesclass MultiHeadAttention(nn.Module):def __init__(self, embed_dim, num_heads):super().__init__()self.head_dim = embed_dim // num_heads# 显式创建每个注意力头self.head1 = SingleHeadAttention(embed_dim, self.head_dim)self.head2 = SingleHeadAttention(embed_dim, self.head_dim)self.head3 = SingleHeadAttention(embed_dim, self.head_dim)# 使用ModuleList管理多个头self.heads = nn.ModuleList([self.head1,self.head2,self.head3])self.output_proj = nn.Linear(embed_dim, embed_dim)def forward(self, x):# 分别处理每个头head1_out = self.head1(x)head2_out = self.head2(x) head3_out = self.head3(x)# 拼接结果combined = torch.cat([head1_out, head2_out, head3_out], dim=-1)return self.output_proj(combined)

关键点解析:

  • 显式声明每个注意力头(避免循环)

  • 使用ModuleList统一管理注意力头

  • 在forward中分别调用每个头

  • 保持各头独立性,便于后续调试


6. 总结

  • nn.ModuleList 是专门用于存储多个子模块的容器,它会自动注册子模块,确保所有参数能参与训练。
  • 与普通 Python 列表相比,ModuleList 可以直接通过 model.parameters() 获取其中所有参数,从而方便地进行优化。
  • 使用 ModuleList 时,前向传播需要手动遍历其中的模块,这提供了更大的灵活性,但也要求开发者理解循环过程。

http://www.ppmy.cn/ops/167179.html

相关文章

TypeScript + Vue:类风格组件如何引领前端新潮流?

&#x1f680; TypeScript Vue&#xff1a;用类风格组件打造“假货比对”神器&#xff01;&#x1f31f; 2025 年&#xff0c;前端开发早已进入“类型安全 模块化”的黄金时代。TypeScript (TS) 的类风格组件正在席卷 Vue 社区&#xff0c;为开发者带来更优雅、更强大的编码体…

深度剖析淘宝拍立淘按图搜索商品API技术规范

在电商技术蓬勃发展的当下&#xff0c;淘宝的拍立淘功能以其独有的按图搜索商品特性&#xff0c;为用户打造了便捷且新颖的购物体验。这一功能背后的强大支撑 —— 淘宝拍立淘按图搜索商品 API&#xff0c;不仅革新了传统电商搜索模式&#xff0c;还为开发者和企业开拓了创新应…

学习springboot 的自动配置原理

前言 为什么要学习springboot 的自动配置原理&#xff1f; 1学习 自定义成starter 的前提 实际开发中&#xff0c;我们如果定义公共的组件给团队使用&#xff0c;为了让他们使用方便就自定义成starter。而想要学习starter ,就要先了解springboot 的自动配置原理 2 面试需要 了…

FastJson:JSON JSONObject JSONArray详解以及SimplePropertyPreFilter 的介绍

FastJson&#xff1a;JSON JSONObject JSONArray详解以及SimplePropertyPreFilter 的介绍 FastJson是阿里巴巴开发的一款专门用于Java开发的包&#xff0c;实现Json对象&#xff0c;JavaBean对&#xff0c;Json字符串之间的转换。 文章目录 FastJson&#xff1a;JSON JSONObje…

C#运算符与表达式:从入门到游戏伤害计算实践

Langchain系列文章目录 01-玩转LangChain&#xff1a;从模型调用到Prompt模板与输出解析的完整指南 02-玩转 LangChain Memory 模块&#xff1a;四种记忆类型详解及应用场景全覆盖 03-全面掌握 LangChain&#xff1a;从核心链条构建到动态任务分配的实战指南 04-玩转 LangChai…

【为什么游戏能使人上瘾】

为什么游戏能使人上瘾&#xff0c;而工作不会&#xff1f;——从神经科学、心理学与行为设计学拆解 一、多巴胺回路的“即时收割” vs “延迟满足” 游戏的神经劫持机制 即时反馈闭环&#xff1a;游戏设计遵循“行为→奖励→强化”的秒级循环。例如&#xff1a; • 击杀小怪→金…

链表·简单归并

cur->next la; //将 p指针所指向的链表节点的 next 指针&#xff08;也就是 p 节点的下一个节点的指针&#xff09;指向 l1 所指向的链表节点。简单来说&#xff0c;就是把 la 节点连接到 p 节点的后面&#xff0c;更新了链表的连接关系。 p la; //将p指针的值更新为 la …

【服务器】RAID0、RAID1、RAID5、RAID6、RAID10异同与应用

目录 ​编辑 一、RAID概述 1.1 磁盘阵列简介 1.2 功能 二、RAID级别 2.1 RAID 0&#xff08;不含校验与冗余的条带存储&#xff09; 2.2 RAID1&#xff08;不含校验的镜像存储&#xff09; 2.3 RAID 5 &#xff08;数据块级别的分布式校验条带存储&#xff09; 4、RAI…