【Block总结】Shuffle Attention,新型的Shuffle注意力|即插即用

server/2025/2/5 5:09:32/

一、论文信息

  • 标题: SA-Net: Shuffle Attention for Deep Convolutional Neural Networks

  • 论文链接: arXiv

  • 代码链接: GitHub
    在这里插入图片描述

二、创新点

Shuffle Attention(SA)模块的主要创新在于高效结合了通道注意力和空间注意力,同时通过通道重排技术降低计算复杂度。具体创新点包括:

  • 通道分组: 将输入特征图的通道维度分成多个组,允许并行处理。
  • 通道重排: 通过打乱通道顺序,增强模型对通道特征的表达能力。
  • 融合注意力机制: 同时计算通道和空间注意力,提升特征表示的能力。
    在这里插入图片描述

三、方法

Shuffle Attention模块的实现步骤如下:

  1. 特征分组: 将输入特征图的通道数通过参数G进行分组,得到多个子特征图。

  2. 通道注意力:

    • 对每组特征使用全局平均池化(GAP),生成通道级统计数据。
    • 通过学习的权重和偏置调整每组特征的通道重要性,并使用Sigmoid激活函数应用于特征图。
  3. 空间注意力:

    • 使用组归一化(GroupNorm)计算每组特征的空间注意力。
    • 经过Sigmoid激活后,将空间注意力应用于特征图的空间维度。
  4. 通道重排: 在计算完通道和空间注意力后,使用通道重排操作打乱通道顺序,以增强特征表达能力。

  5. 输出: 返回经过重排后的输出特征图。

Shuffle Attention与传统注意力机制的优势比较

Shuffle Attention(SA)是一种新型的注意力机制,旨在提高深度卷积神经网络的性能。与传统的注意力机制相比,Shuffle Attention在多个方面展现出显著的优势,尤其是在计算效率和模型性能方面。

优势如下:

  1. 高效的计算性能
    Shuffle Attention通过将输入特征图的通道维度分成多个组,并对每个组进行并行处理,从而显著降低了计算复杂度。传统的注意力机制通常需要在全通道上进行计算,导致计算量大且效率低下。SA的设计使得在保持性能的同时,减少了参数量和计算量。例如,在ResNet50上,SA的参数量从300M降至25.56M,计算量从4.12 GFLOPs降至2.76 GFLOPs[2]。

  2. 融合通道和空间注意力
    Shuffle Attention同时结合了通道注意力和空间注意力,能够更全面地捕捉特征之间的依赖关系。传统的注意力机制往往是将这两种注意力机制分开处理,未能充分利用它们之间的相互关系。SA通过“通道重排”操作,促进了不同组之间的信息交流,从而提升了特征的表达能力和模型的整体性能。

  3. 增强特征表达能力
    Shuffle Attention通过通道重排和特征分组的方式,增强了模型对特征的表达能力。传统注意力机制在处理特征时,可能会忽略某些重要的通道或空间信息,而SA通过并行处理和重排,确保了所有特征都能得到充分利用,从而提高了模型的准确性。

  4. 适应性强
    Shuffle Attention的设计使其能够灵活适应不同的网络架构和任务需求。它可以作为一个轻量级的模块,方便地集成到现有的深度学习模型中,而不需要对整个网络结构进行大幅修改。这种灵活性使得SA在各种计算机视觉任务中表现出色,包括图像分类、目标检测和实例分割等。

四、效果

Shuffle Attention模块在多个计算机视觉任务中表现出色,尤其是在图像分类和目标检测任务中。通过有效的特征提取和信息融合,SA模块能够在保持较低计算复杂度的同时,显著提高模型的性能。

五、实验结果

在ImageNet-1k、MS COCO等数据集上的实验结果表明:

  • 准确率提升: SA模块在与ResNet等主干网络结合时,Top-1准确率提升超过1.34%。
  • 计算复杂度降低: 相比于传统的注意力机制,SA模块在参数量和计算量上均显著减少,例如在ResNet50上,参数量从300M降至25.56M,计算量从4.12 GFLOPs降至2.76 GFLOPs。

六、总结

Shuffle Attention模块通过创新的特征分组和通道重排机制,有效地结合了通道和空间注意力,显著提升了深度卷积神经网络的性能。该模块不仅在多个基准测试中表现优异,还展示了在实际应用中的潜力,尤其是在资源受限的环境中。未来的研究可以进一步探索SA模块在更复杂任务中的应用效果。

代码

代码有错误,我做了修改。代码如下:


import torch
import torch.nn.functional
import torch.nn.functional as F
from torch import nn
from torch.nn.parameter import Parameterclass sa_layer(nn.Module):"""Constructs a Channel Spatial Group module.Args:k_size: Adaptive selection of kernel size"""def __init__(self, channel, groups=64):super(sa_layer, self).__init__()self.groups = groupsself.avg_pool = nn.AdaptiveAvgPool2d(1)self.cweight = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))self.cbias = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))self.sweight = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))self.sbias = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))self.sigmoid = nn.Sigmoid()self.gn = nn.GroupNorm(channel // (2 * groups), channel // (2 * groups))@staticmethoddef channel_shuffle(x, groups):b, c, h, w = x.shapex = x.reshape(b, groups, -1, h, w)x = x.permute(0, 2, 1, 3, 4)# flattenx = x.reshape(b, -1, h, w)return xdef forward(self, x):b, c, h, w = x.shapex = x.reshape(b * self.groups, -1, h, w)x_0, x_1 = x.chunk(2, dim=1)# channel attentionxn = self.avg_pool(x_0)xn = self.cweight * xn + self.cbiasxn = x_0 * self.sigmoid(xn)print(xn)# spatial attentionxs = self.gn(x_1)xs = self.sweight * xs + self.sbiasxs = x_1 * self.sigmoid(xs)# concatenate along channel axisout = torch.cat([xn, xs], dim=1)out = out.reshape(b, -1, h, w)out = self.channel_shuffle(out, 2)return outif __name__ == "__main__":dim=64# 如果GPU可用,将模块移动到 GPUdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 输入张量 (batch_size, height, width,channels)x = torch.randn(2,dim,40,40).to(device)# 初始化 sa_layer模块block = sa_layer(dim,8)print(block)block = block.to(device)# 前向传播output = block(x)print("输入:", x.shape)print("输出:", output.shape)

输出结果:
在这里插入图片描述


http://www.ppmy.cn/server/165068.html

相关文章

在Windows下安装Ollama并体验DeepSeek r1大模型

在Windows下安装Ollama并体验DeepSeek r1大模型 Ollama在Windows下安装 Ollama官网:Ollama GitHub 下载Windows版Ollama软件:Release v0.5.7 ollama/ollama GitHub 下载ollama-windows-amd64.zip这个文件即可。可以说Windows拥抱开源真好&#xf…

本地部署DeepSeek开源多模态大模型Janus-Pro-7B实操

本地部署DeepSeek开源多模态大模型Janus-Pro-7B实操 Janus-Pro-7B介绍 Janus-Pro-7B 是由 DeepSeek 开发的多模态 AI 模型,它在理解和生成方面取得了显著的进步。这意味着它不仅可以处理文本,还可以处理图像等其他模态的信息。 模型主要特点:Permalink…

2021版小程序开发5——小程序项目开发实践(1)

2021版小程序开发5——小程序项目开发实践(1) 学习笔记 2025 使用uni-app开发一个电商项目; Hbuidler 首选uni-app官方推荐工具:https://www.dcloud.io/hbuilderx.htmlhttps://dev.dcloud.net.cn/pages/app/list 微信小程序 管理后台:htt…

GESP2023年12月认证C++六级( 第三部分编程题(1)闯关游戏)

参考程序代码&#xff1a; #include <cstdio> #include <cstdlib> #include <cstring> #include <algorithm> #include <string> #include <map> #include <iostream> #include <cmath> using namespace std;const int N 10…

React+Cesium基础教程(003):加载3D建筑物和创建标签

文章目录 03-加载3D建筑物和标签方式一方式二完整代码03-加载3D建筑物和标签 方式一 添加来自 OpenStreetMap 的建筑物模型,让场景更加丰富和真实: viewer.scene.primitives.add(new Cesium.createOsmBuildings() );方式二 使用 Cesium ion 资源:

android java系统弹窗的基础模板

1、资源文件 app\src\main\res\layout下增加custom_pop_layout.xml 定义弹窗的控件资源。 <?xml version"1.0" encoding"utf-8"?> <androidx.constraintlayout.widget.ConstraintLayout xmlns:android"http://schemas.android.com/apk/…

发布 VectorTraits v3.1(支持 .NET 9.0,支持 原生AOT)

文章目录 发布 VectorTraits v3.1&#xff08;支持 .NET 9.0&#xff0c;支持 原生AOT&#xff09;支持 .NET 9.0中断性变更 支持 原生AOT原生AOT的范例使用IlcInstructionSet参数 TraitsOutput类增加IsDynamicCodeCompiled/IsDynamicCodeSupported信息的输出为了支持原生AOT, …

详解Kafka并行计算架构

引言 在高流量的复杂场景下&#xff0c;Kafka 凭借卓越的性能表现脱颖而出&#xff0c;始终维持着极高的吞吐率和高效的消息消费能力&#xff0c;在众多消息队列产品中独树一帜。其稳定且强大的性能&#xff0c;不仅保障了海量数据的快速处理&#xff0c;还为各类业务的高效运行…