【Block总结】SCSA,探索空间与通道注意力之间的协同效应|即插即用

ops/2025/1/31 16:12:02/

论文信息

该论文于2025年1月27日发布,探讨了空间注意力和通道注意力的协同作用,提出了一种新的空间与通道协同注意力模块(SCSA)。该模块由可共享多语义空间注意力(SMSA)和渐进通道自注意力(PCSA)组成,旨在提升视觉任务中的特征提取能力。

  • 论文链接:https://arxiv.org/pdf/2407.05128

  • GitHub链接:https://github.com/HZAI-ZJNU/SCSA
    在这里插入图片描述

创新点

  • 多语义空间注意力(SMSA):整合多种语义信息,通过渐进压缩策略将空间先验信息注入通道自注意力中。
  • 渐进通道自注意力(PCSA):基于通道单头自注意力机制,增强特征交互,缓解多语义信息之间的差异。
  • 协同机制:通过空间注意力引导通道注意力的学习,提升模型的整体性能。

方法

SCSA的实现方法包括以下几个步骤:

  1. 特征分解:将输入特征分解为多个独立的子特征,以便高效提取多语义空间信息。
  2. 轻量级卷积:在每个子特征内应用不同大小的深度一维卷积,捕获不同的语义空间结构。
  3. 空间注意力图生成:通过组归一化处理不同的子特征,生成空间注意力图。
  4. 通道自注意力计算:利用渐进式压缩和单头自注意力机制,计算通道间的相似性并缓解语义差异。
    在这里插入图片描述

SCSA与其他注意力机制的具体改进

SCSA(Spatial and Channel Synergistic Attention)是一种新型的注意力机制,旨在结合空间注意力和通道注意力的优势,以提升深度学习模型在视觉任务中的表现。与传统的注意力机制相比,SCSA在多个方面进行了改进。

具体改进

  1. 多语义空间信息的利用

    • SCSA通过可共享的多语义空间注意力(SMSA)模块,充分利用了输入图像中的多语义空间信息。这一模块采用多尺度深度共享的1D卷积,能够捕捉到丰富的空间特征,从而增强局部和全局特征的表示能力[1][2]。
  2. 通道特征的精细化处理

    • SCSA中的渐进式通道自注意力(PCSA)模块,通过输入感知的自注意力机制,能够有效地精炼通道特征。这一机制不仅减轻了多语义信息之间的语义差异,还确保了通道特征的稳健整合,从而提升了模型的整体性能[1][2]。
  3. 协同效应的引入

    • SCSA通过将空间注意力和通道注意力模块并行组合,利用它们之间的协同效应。空间注意力帮助模型聚焦于重要的空间区域,而通道注意力则强调重要的特征通道。两者的结合使得模型能够同时关注最具信息量的空间位置和特征通道,从而实现更优的决策[1][2]。
  4. 性能提升

    • 在多个基准测试中,SCSA表现出色,超越了现有的最先进注意力机制。例如,在ImageNet-1K分类、MSCOCO目标检测和ADE20K分割任务中,SCSA均展示了显著的性能提升,尤其在低光照和小目标场景下的表现尤为突出[2][1]。
  5. 处理语义差异的能力

    • SCSA有效地处理了由于多语义信息引起的语义差异和交互问题。通过精细化的通道特征处理,SCSA能够更好地整合不同特征通道的信息,提升了模型在复杂场景下的泛化能力[2]。

SCSA通过整合空间和通道注意力的优势,显著提升了特征提取的能力,并在多个视觉任务中取得了优异的表现。其在多语义信息利用、通道特征精细化处理、协同效应引入及性能提升等方面的具体改进,使其在深度学习领域中成为一种具有潜力的注意力机制。

效果

实验结果表明,SCSA在多个视觉任务中表现优异,超越了现有的最先进注意力机制。具体效果包括:

  • 图像分类:在ImageNet-1K上,SCSA实现了最高的Top-1准确率。
  • 目标检测:在MSCOCO上,SCSA在不同检测器上均表现出色,尤其在小目标和低光照场景中。
  • 语义分割:在ADE20K上,SCSA显著提高了mIoU,证明了其在细粒度任务中的有效性。

实验结果

研究团队在七个基准数据集上进行了广泛的实验,包括:

  • 分类:ImageNet-1K
  • 目标检测:MSCOCO、Pascal VOC、VisDrone、ExDark
  • 分割:ADE20K、MSCOCO

实验结果显示,SCSA在各个任务中均优于其他即插即用的注意力机制,展现出强大的泛化能力。

总结

SCSA模块通过有效整合空间和通道注意力的优势,显著提升了特征提取的能力,并在多个视觉任务中取得了优异的表现。该研究为未来的深度学习模型设计提供了新的思路,尤其是在处理复杂视觉任务时,SCSA的引入可能会成为一种重要的工具。

代码

import typing as timport torch
import torch.nn as nn
from einops import rearrange
__all__ = ['SCSA']class SCSA(nn.Module):def __init__(self,dim: int,head_num: int,window_size: int = 7,group_kernel_sizes: t.List[int] = [3, 5, 7, 9],qkv_bias: bool = False,fuse_bn: bool = False,down_sample_mode: str = 'avg_pool',attn_drop_ratio: float = 0.,gate_layer: str = 'sigmoid',):super(SCSA, self).__init__()self.dim = dimself.head_num = head_numself.head_dim = dim // head_numself.scaler = self.head_dim ** -0.5self.group_kernel_sizes = group_kernel_sizesself.window_size = window_sizeself.qkv_bias = qkv_biasself.fuse_bn = fuse_bnself.down_sample_mode = down_sample_modeassert self.dim // 4, 'The dimension of input feature should be divisible by 4.'self.group_chans = group_chans = self.dim // 4self.local_dwc = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[0],padding=group_kernel_sizes[0] // 2, groups=group_chans)self.global_dwc_s = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[1],padding=group_kernel_sizes[1] // 2, groups=group_chans)self.global_dwc_m = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[2],padding=group_kernel_sizes[2] // 2, groups=group_chans)self.global_dwc_l = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[3],padding=group_kernel_sizes[3] // 2, groups=group_chans)self.sa_gate = nn.Softmax(dim=2) if gate_layer == 'softmax' else nn.Sigmoid()self.norm_h = nn.GroupNorm(4, dim)self.norm_w = nn.GroupNorm(4, dim)self.conv_d = nn.Identity()self.norm = nn.GroupNorm(1, dim)self.q = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)self.k = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)self.v = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)self.attn_drop = nn.Dropout(attn_drop_ratio)self.ca_gate = nn.Softmax(dim=1) if gate_layer == 'softmax' else nn.Sigmoid()if window_size == -1:self.down_func = nn.AdaptiveAvgPool2d((1, 1))else:if down_sample_mode == 'recombination':self.down_func = self.space_to_chans# dimensionality reductionself.conv_d = nn.Conv2d(in_channels=dim * window_size ** 2, out_channels=dim, kernel_size=1, bias=False)elif down_sample_mode == 'avg_pool':self.down_func = nn.AvgPool2d(kernel_size=(window_size, window_size), stride=window_size)elif down_sample_mode == 'max_pool':self.down_func = nn.MaxPool2d(kernel_size=(window_size, window_size), stride=window_size)def forward(self, x: torch.Tensor) -> torch.Tensor:"""The dim of x is (B, C, H, W)"""# Spatial attention priority calculationb, c, h_, w_ = x.size()# (B, C, H)x_h = x.mean(dim=3)l_x_h, g_x_h_s, g_x_h_m, g_x_h_l = torch.split(x_h, self.group_chans, dim=1)# (B, C, W)x_w = x.mean(dim=2)l_x_w, g_x_w_s, g_x_w_m, g_x_w_l = torch.split(x_w, self.group_chans, dim=1)x_h_attn = self.sa_gate(self.norm_h(torch.cat((self.local_dwc(l_x_h),self.global_dwc_s(g_x_h_s),self.global_dwc_m(g_x_h_m),self.global_dwc_l(g_x_h_l),), dim=1)))x_h_attn = x_h_attn.view(b, c, h_, 1)x_w_attn = self.sa_gate(self.norm_w(torch.cat((self.local_dwc(l_x_w),self.global_dwc_s(g_x_w_s),self.global_dwc_m(g_x_w_m),self.global_dwc_l(g_x_w_l)), dim=1)))x_w_attn = x_w_attn.view(b, c, 1, w_)x = x * x_h_attn * x_w_attn# Channel attention based on self attention# reduce calculationsy = self.down_func(x)y = self.conv_d(y)_, _, h_, w_ = y.size()# normalization first, then reshape -> (B, H, W, C) -> (B, C, H * W) and generate q, k and vy = self.norm(y)q = self.q(y)k = self.k(y)v = self.v(y)# (B, C, H, W) -> (B, head_num, head_dim, N)q = rearrange(q, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),head_dim=int(self.head_dim))k = rearrange(k, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),head_dim=int(self.head_dim))v = rearrange(v, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),head_dim=int(self.head_dim))# (B, head_num, head_dim, head_dim)attn = q @ k.transpose(-2, -1) * self.scalerattn = self.attn_drop(attn.softmax(dim=-1))# (B, head_num, head_dim, N)attn = attn @ v# (B, C, H_, W_)attn = rearrange(attn, 'b head_num head_dim (h w) -> b (head_num head_dim) h w', h=int(h_), w=int(w_))# (B, C, 1, 1)attn = attn.mean((2, 3), keepdim=True)attn = self.ca_gate(attn)return attn * xif __name__ == "__main__":# 如果GPU可用,将模块移动到 GPUdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 输入张量 (batch_size, height, width,channels)x = torch.randn(1,32,40,40).to(device)# 初始化 HWD 模块dim=32block = SCSA(dim=32, head_num=8, window_size=7)print(block)block = block.to(device)# 前向传播output = block(x)print("输入:", x.shape)print("输出:", output.shape)

输出结果:
在这里插入图片描述


http://www.ppmy.cn/ops/154519.html

相关文章

Python 函数魔法书:基础、范例、避坑、测验与项目实战

Python 函数魔法书:基础、范例、避坑、测验与项目实战 内容简介 本系列文章是为 Python3 学习者精心设计的一套全面、实用的学习指南,旨在帮助读者从基础入门到项目实战,全面提升编程能力。文章结构由 5 个版块组成,内容层层递进…

【计算机视觉】目标跟踪应用

一、简介 目标跟踪是指根据目标物体在视频当前帧图像中的位置,估计其在下一帧图像中的位置。视频帧由t到t1的检测,虽然也可以使用目标检测获取,但实际应用中往往是不可行的,原因如下: 目标跟踪的目的是根据目标在当前…

Java 性能优化与新特性

Java学习资料 Java学习资料 Java学习资料 一、引言 Java 作为一门广泛应用于企业级开发、移动应用、大数据等多个领域的编程语言,其性能和特性一直是开发者关注的重点。随着软件系统的规模和复杂度不断增加,对 Java 程序性能的要求也越来越高。同时&a…

sprnigboot集成Memcached

安装Memcached 下载地址 32位系统 1.2.5版本:http://static.jyshare.com/download/memcached-1.2.5-win32-bin.zip 32位系统 1.2.6版本:http://static.jyshare.com/download/memcached-1.2.6-win32-bin.zip 32位系统 1.4.4版本:http://stati…

MongoDB中常用的几种高可用技术方案及优缺点

MongoDB 的高可用性方案主要依赖于其内置的 副本集 (Replica Set) 和 Sharding 机制。下面是一些常见的高可用性技术方案: 1. 副本集 (Replica Set) 副本集是 MongoDB 提供的主要高可用性解决方案,确保数据在多个节点之间的冗余存储和自动故障恢复。副…

基于Docker搭建Sentinel Dashboard

从官网下载sentinel jar文件在与sentinel-dashboard-1.8.8.jar同一目录创建Dockerfile文件构建docker镜像文件创建镜像tag包提交镜像至镜像仓库下面就可以部署sentinel-dashboard容器了验证sentinel-dashboard控制台是否可用Sentinel 是一个开源的分布式流量控制与熔断框架,由…

【腾讯云】腾讯云docker搭建单机hadoop

这里写目录标题 下载jdk hadoop修改hadoop配置编写Dockerfile构建镜像运行镜像创建客户端 下载jdk hadoop wget --no-check-certificate https://repo.huaweicloud.com/java/jdk/8u151-b12/jdk-8u151-linux-x64.tar.gz wget --no-check-certificate https://repo.huaweicloud.…

物联网智能项目之——智能家居项目的实现!

成长路上不孤单😊😊😊😊😊😊 【14后😊///计算机爱好者😊///持续分享所学😊///如有需要欢迎收藏转发///😊】 今日分享关于物联网智能项目之——智能家居项目…