【Block总结】高效多尺度注意力EMA,超越SE、CBAM、SA、CA等注意力|即插即用

devtools/2025/2/3 17:58:39/

论文信息

标题: Efficient Multi-Scale Attention Module with Cross-Spatial Learning

作者: Daliang Ouyang, Su He, Guozhong Zhang, Mingzhu Luo, Huaiyong Guo, Jian Zhan, Zhijie Huang

论文链接: https://arxiv.org/pdf/2305.13563v2

GitHub链接: https://github.com/YOLOonMe/EMA-attention-module
在这里插入图片描述

创新点

该论文提出了一种新颖的高效多尺度注意力模块(EMA),旨在通过跨空间学习来提升特征表示的效果,同时降低计算开销。EMA模块的设计重点在于:

  • 信息保留: 在每个通道上保留信息,确保特征的完整性。
  • 计算效率: 通过重塑部分通道为批处理维度,减少计算负担。
  • 多尺度学习: 结合多尺度特征,增强模型对不同尺度信息的捕捉能力。

方法

EMA模块的核心方法包括:

  1. 通道重塑: 将部分通道重塑为批处理维度,并将通道维度分组为多个子特征,以实现更高效的信息处理。

  2. 跨维度交互: 通过跨维度交互,聚合两个并行分支的输出特征,捕获像素级的成对关系。

  3. 并行子网络: 设计多尺度并行子网络,以建立短期和长期依赖关系,从而增强特征表示能力。

在这里插入图片描述

EMA模块的信息保留与计算效率平衡

信息保留机制

EMA(Efficient Multi-Scale Attention)模块通过以下几种方式实现信息的有效保留:

  1. 通道重塑: EMA模块将部分通道重塑为批处理维度,并将通道维度分组为多个子特征。这种设计确保了每个通道的信息能够被有效保留,同时避免了通道维度的削减,从而增强了特征的表达能力[1][3]。

  2. 跨维度交互: 在EMA模块中,两个并行分支的输出特征通过跨维度交互进行聚合。这种交互机制能够捕捉到像素级的成对关系,从而进一步提升特征的丰富性和准确性[2][3]。

  3. 多尺度并行子网络: EMA模块采用了多尺度并行子网络结构,结合了1x1和3x3卷积核的特征处理。这种结构能够有效捕获不同尺度的信息,确保在特征提取过程中不会丢失重要信息[2][3]。

计算效率提升

在计算效率方面,EMA模块通过以下方式优化了计算过程:

  1. 减少计算开销: 通过将部分通道重塑为批处理维度,EMA模块能够在不显著增加计算成本的情况下,保持高效的信息处理。这种方法使得模型在处理大规模数据时更加高效[1][2]。

  2. 并行处理: EMA模块的设计允许多个子网络并行处理特征,这不仅提高了计算效率,还减少了模型的顺序处理需求,从而加快了整体计算速度[3]。

  3. 适度的模型尺寸: EMA模块的设计确保了模型的尺寸适中,适合在移动终端等资源受限的环境中部署。这种设计使得EMA模块在保持性能的同时,能够有效降低计算资源的消耗[3][2]。

EMA模块通过创新的设计实现了信息保留与计算效率的平衡。其通道重塑、跨维度交互和多尺度并行处理的策略,不仅确保了特征信息的完整性,还显著提高了计算效率。这使得EMA模块在计算机视觉任务中表现出色,尤其是在小目标检测和图像分类等应用中,展现了其广泛的应用潜力和实际意义。

效果

实验结果表明,EMA模块在多个计算机视觉任务中表现优异,尤其是在小目标检测和图像分类任务中,相较于传统的注意力机制(如ECA、CBAM、CA),EMA模块显著提高了特征表示的清晰度和准确性。

实验结果

在广泛的消融研究和实验中,EMA模块在以下数据集上进行了评估:

  • CIFAR-100
  • ImageNet-1k
  • MS COCO
  • VisDrone2019

实验结果显示,EMA模块在这些基准测试中均取得了优于现有方法的性能,尤其在小目标检测任务中,表现出明显的优势。

总结

Efficient Multi-Scale Attention Module with Cross-Spatial Learning通过创新的设计和有效的实现,成功地提升了计算机视觉任务中的特征表示能力,同时降低了计算复杂度。该模块的提出为未来的研究提供了新的思路,尤其是在需要高效处理大规模数据的应用场景中,EMA模块展现了其广泛的应用潜力。

代码

import torch
from torch import nnclass EMA(nn.Module):def __init__(self, channels, c2=None, factor=32):super(EMA, self).__init__()self.groups = factorassert channels // self.groups > 0self.softmax = nn.Softmax(-1)self.agp = nn.AdaptiveAvgPool2d((1, 1))self.pool_h = nn.AdaptiveAvgPool2d((None, 1))self.pool_w = nn.AdaptiveAvgPool2d((1, None))self.gn = nn.GroupNorm(channels // self.groups, channels // self.groups)self.conv1x1 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=1, stride=1, padding=0)self.conv3x3 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=3, stride=1, padding=1)def forward(self, x):b, c, h, w = x.size()group_x = x.reshape(b * self.groups, -1, h, w)  # b*g,c//g,h,wx_h = self.pool_h(group_x)x_w = self.pool_w(group_x).permute(0, 1, 3, 2)hw = self.conv1x1(torch.cat([x_h, x_w], dim=2))x_h, x_w = torch.split(hw, [h, w], dim=2)x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())x2 = self.conv3x3(group_x)x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))x12 = x2.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hwx21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))x22 = x1.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hwweights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(b * self.groups, 1, h, w)return (group_x * weights.sigmoid()).reshape(b, c, h, w)if __name__ == "__main__":# 如果GPU可用,将模块移动到 GPUdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 输入张量 (batch_size, channels, height, width)x = torch.randn(1,32,40,40).to(device)# 初始化 pconv 模块dim=32block = EMA(dim,factor=8)print(block)block = block.to(device)# 前向传播output = block(x)print("输入:", x.shape)print("输出:", output.shape)

输出结果:

在这里插入图片描述


http://www.ppmy.cn/devtools/155789.html

相关文章

02 使用 海康SDK 对人脸识别设备读取事件

前言 最近朋友的需求, 是需要使用 海康sdk 连接海康设备, 进行数据的获取, 比如 进出车辆, 进出人员 这一部分是 对接海康人脸设备 获取相关事件, 并进行入库 的相关处理 测试用例 主要的处理如下 1. 设备登陆, 不同的设备可能兼容的 登陆方式不一样, 我这里设备需要使用…

【C++语言】卡码网语言基础课系列----14. 链表的基础操作II

文章目录 练习题目链表的基础操作II具体代码实现 小白寄语诗词共勉 练习题目 链表的基础操作II 题目描述: 请编写一个程序,实现以下操作: 构建一个单向链表,链表中包含一组整数数据,输出链表中的第 m 个元素&#xf…

【网络】传输层协议TCP(重点)

文章目录 1. TCP协议段格式2. 详解TCP2.1 4位首部长度2.2 32位序号与32位确认序号(确认应答机制)2.3 超时重传机制2.4 连接管理机制(3次握手、4次挥手 3个标志位)2.5 16位窗口大小(流量控制)2.6 滑动窗口2.7 3个标志位 16位紧急…

gentoo linux中安装希沃白板5

一、下载“希沃白板5” 下载地址:https://easinote.seewo.com/linux 根据自己的电脑选择合适的版本下载。这儿下载的是UOS版的X86架构。 gentoo中会自动下载到目录“~/下载 ”中。将下载的文件复制到/usr/local/src/easinote中。 二、在gentoo中安…

Qt之数据库的使用一

qt creator6.8 主要功能从数据库中读取数据,使用tableView进行显示。 qt框架中包含m/v结构 m指的是model(模型),v指的是view(视图)。这样可以使界面和数据分离开来。每当数据更新时,不会影响界面组件。 软件运行界面如下 程序分析window.…

/etc/shadow配置文件的一些符号意义说明

* 该用户永久性不能登录系统 ! 账号锁定 !! 密码锁定 图例: (如何进行账户锁定和密码锁定:账户锁定与密码锁定以及解锁-CSDN博客)

L1-006 连续因子*

1.题意 一个正整数 N 的因子中可能存在若干连续的数字。例如 630 可以分解为 3567,其中 5、6、7 就是 3 个连续的数字。给定任一正整数 N,要求编写程序求出最长连续因子的个数,并输出最小的连续因子序列。 这两句话非常重要也非常难理解&…

使用 PyTorch 实现逻辑回归:从数据到模型保存与加载

在机器学习中,逻辑回归是一种经典的分类算法,广泛应用于二分类问题。本文将通过一个简单的示例,展示如何使用 PyTorch 框架实现逻辑回归模型,从数据准备到模型训练、保存和加载,最后进行预测。 1. 数据准备 逻辑回归…