【Block总结】Shuffle Attention,新型的Shuffle注意力|即插即用

devtools/2025/2/5 4:06:27/

一、论文信息

  • 标题: SA-Net: Shuffle Attention for Deep Convolutional Neural Networks

  • 论文链接: arXiv

  • 代码链接: GitHub
    在这里插入图片描述

二、创新点

Shuffle Attention(SA)模块的主要创新在于高效结合了通道注意力和空间注意力,同时通过通道重排技术降低计算复杂度。具体创新点包括:

  • 通道分组: 将输入特征图的通道维度分成多个组,允许并行处理。
  • 通道重排: 通过打乱通道顺序,增强模型对通道特征的表达能力。
  • 融合注意力机制: 同时计算通道和空间注意力,提升特征表示的能力。
    在这里插入图片描述

三、方法

Shuffle Attention模块的实现步骤如下:

  1. 特征分组: 将输入特征图的通道数通过参数G进行分组,得到多个子特征图。

  2. 通道注意力:

    • 对每组特征使用全局平均池化(GAP),生成通道级统计数据。
    • 通过学习的权重和偏置调整每组特征的通道重要性,并使用Sigmoid激活函数应用于特征图。
  3. 空间注意力:

    • 使用组归一化(GroupNorm)计算每组特征的空间注意力。
    • 经过Sigmoid激活后,将空间注意力应用于特征图的空间维度。
  4. 通道重排: 在计算完通道和空间注意力后,使用通道重排操作打乱通道顺序,以增强特征表达能力。

  5. 输出: 返回经过重排后的输出特征图。

Shuffle Attention与传统注意力机制的优势比较

Shuffle Attention(SA)是一种新型的注意力机制,旨在提高深度卷积神经网络的性能。与传统的注意力机制相比,Shuffle Attention在多个方面展现出显著的优势,尤其是在计算效率和模型性能方面。

优势如下:

  1. 高效的计算性能
    Shuffle Attention通过将输入特征图的通道维度分成多个组,并对每个组进行并行处理,从而显著降低了计算复杂度。传统的注意力机制通常需要在全通道上进行计算,导致计算量大且效率低下。SA的设计使得在保持性能的同时,减少了参数量和计算量。例如,在ResNet50上,SA的参数量从300M降至25.56M,计算量从4.12 GFLOPs降至2.76 GFLOPs[2]。

  2. 融合通道和空间注意力
    Shuffle Attention同时结合了通道注意力和空间注意力,能够更全面地捕捉特征之间的依赖关系。传统的注意力机制往往是将这两种注意力机制分开处理,未能充分利用它们之间的相互关系。SA通过“通道重排”操作,促进了不同组之间的信息交流,从而提升了特征的表达能力和模型的整体性能。

  3. 增强特征表达能力
    Shuffle Attention通过通道重排和特征分组的方式,增强了模型对特征的表达能力。传统注意力机制在处理特征时,可能会忽略某些重要的通道或空间信息,而SA通过并行处理和重排,确保了所有特征都能得到充分利用,从而提高了模型的准确性。

  4. 适应性强
    Shuffle Attention的设计使其能够灵活适应不同的网络架构和任务需求。它可以作为一个轻量级的模块,方便地集成到现有的深度学习模型中,而不需要对整个网络结构进行大幅修改。这种灵活性使得SA在各种计算机视觉任务中表现出色,包括图像分类、目标检测和实例分割等。

四、效果

Shuffle Attention模块在多个计算机视觉任务中表现出色,尤其是在图像分类和目标检测任务中。通过有效的特征提取和信息融合,SA模块能够在保持较低计算复杂度的同时,显著提高模型的性能。

五、实验结果

在ImageNet-1k、MS COCO等数据集上的实验结果表明:

  • 准确率提升: SA模块在与ResNet等主干网络结合时,Top-1准确率提升超过1.34%。
  • 计算复杂度降低: 相比于传统的注意力机制,SA模块在参数量和计算量上均显著减少,例如在ResNet50上,参数量从300M降至25.56M,计算量从4.12 GFLOPs降至2.76 GFLOPs。

六、总结

Shuffle Attention模块通过创新的特征分组和通道重排机制,有效地结合了通道和空间注意力,显著提升了深度卷积神经网络的性能。该模块不仅在多个基准测试中表现优异,还展示了在实际应用中的潜力,尤其是在资源受限的环境中。未来的研究可以进一步探索SA模块在更复杂任务中的应用效果。

代码

代码有错误,我做了修改。代码如下:


import torch
import torch.nn.functional
import torch.nn.functional as F
from torch import nn
from torch.nn.parameter import Parameterclass sa_layer(nn.Module):"""Constructs a Channel Spatial Group module.Args:k_size: Adaptive selection of kernel size"""def __init__(self, channel, groups=64):super(sa_layer, self).__init__()self.groups = groupsself.avg_pool = nn.AdaptiveAvgPool2d(1)self.cweight = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))self.cbias = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))self.sweight = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))self.sbias = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))self.sigmoid = nn.Sigmoid()self.gn = nn.GroupNorm(channel // (2 * groups), channel // (2 * groups))@staticmethoddef channel_shuffle(x, groups):b, c, h, w = x.shapex = x.reshape(b, groups, -1, h, w)x = x.permute(0, 2, 1, 3, 4)# flattenx = x.reshape(b, -1, h, w)return xdef forward(self, x):b, c, h, w = x.shapex = x.reshape(b * self.groups, -1, h, w)x_0, x_1 = x.chunk(2, dim=1)# channel attentionxn = self.avg_pool(x_0)xn = self.cweight * xn + self.cbiasxn = x_0 * self.sigmoid(xn)print(xn)# spatial attentionxs = self.gn(x_1)xs = self.sweight * xs + self.sbiasxs = x_1 * self.sigmoid(xs)# concatenate along channel axisout = torch.cat([xn, xs], dim=1)out = out.reshape(b, -1, h, w)out = self.channel_shuffle(out, 2)return outif __name__ == "__main__":dim=64# 如果GPU可用,将模块移动到 GPUdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 输入张量 (batch_size, height, width,channels)x = torch.randn(2,dim,40,40).to(device)# 初始化 sa_layer模块block = sa_layer(dim,8)print(block)block = block.to(device)# 前向传播output = block(x)print("输入:", x.shape)print("输出:", output.shape)

输出结果:
在这里插入图片描述


http://www.ppmy.cn/devtools/156169.html

相关文章

基于微信小程序高校课堂教学管理系统 课堂管理系统微信小程序(源码+文档)

目录 一.研究目的 二.需求分析 三.数据库设计 四.系统页面展示 五.免费源码获取 一.研究目的 困扰管理层的许多问题当中,高校课堂教学管理也是不敢忽视的一块。但是管理好高校课堂教学又面临很多麻烦需要解决,如何在工作琐碎,记录繁多的情况下将高校课堂教学的当前情况反…

webpack-编译原理

webpack 编译过程 文章目录 webpack 编译过程初始化编译创建 chunk构建所有依赖模块产生 chunk assets合并 chunk assets 输出总过程 webpack 的作用是将源代码编译(构建、打包)成最终代码。 整个过程大致分为三个步骤 初始化编译输出 初始化 此阶段&a…

宝塔安装完redis 如何访问

1,配置bind和密码 我前面在宝塔中安装完成redis,在我的电脑上访问。发现连接不上去。 2,手动杀死一次redis在重启 #执行一下命令 ps -ef | grep 6379 强制杀死进程 125117 是进程号 #杀死进程 kill -9 125117 3,重启redis 重启…

【Validator】自定义字段、结构体补充及自定义验证,go案例讲解ReportError和errors.As在其中的使用

自定义字段名称的显示 RegisterTagNameFunc,自定义字段名称的显示,以便于从字段标签(tag)中提取更有意义的名称。 代码示例:自定义字段名称 package mainimport ("fmt""reflect""strings&q…

在LINUX上安装英伟达CUDA Toolkit

下载安装包 wget https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda-repo-rhel8-12-8-local-12.8.0_570.86.10-1.x86_64.rpm 安装RPM包 sudo rpm -i cuda-repo-rhel8-12-8-local-12.8.0_570.86.10-1.x86_64.rpm sudo dnf clean all sudo dnf…

Git进阶之旅:Git 多人合作

项目克隆: git clone 仓库地址:把远程项目克隆到本地形成一个本地的仓库 克隆下来的仓库和远程仓库的名称一致 注意:git clone 远程仓库地址 远程仓库名:把远程仓库克隆下来,并自定义仓库名 多人协作: …

Flutter使用Flavor实现切换环境和多渠道打包

在Android开发中通常我们使用flavor进行多渠道打包,flutter开发中同样有这种方式,不过需要在原生中配置 具体方案其实flutter官网个了相关示例(https://docs.flutter.dev/deployment/flavors),我这里记录一下自己的操作 Android …

【2】阿里面试题整理

[1]. 说一下Java与C的区别。 Java和C是两种在软件开发领域应用非常广泛的语言,但它们的设计理念和应用场景有所不同。 Java是一种基于JVM的解释型语言,具有跨平台性,使用自动垃圾回收机制,这使得开发者可以更专注于业务逻辑&…