【Block总结】MAB,多尺度注意力块|即插即用

server/2025/2/3 6:17:37/

文章目录

  • 一、论文信息
  • 二、创新点
  • 三、方法
    • MAB模块解读
      • 1、MAB模块概述
      • 2、MAB模块组成
      • 3、MAB模块的优势
  • 四、效果
  • 五、实验结果
  • 六、总结
  • 代码

一、论文信息

  • 标题: Multi-scale Attention Network for Single Image Super-Resolution
  • 作者: Yan Wang, Yusen Li, Gang Wang, Xiaoguang Liu
  • 机构: 南开大学
  • 发表会议: CVPR 2024 Workshops
  • 论文链接: arXiv
  • GitHub代码库: icandle/MAN
    在这里插入图片描述

二、创新点

  • 多尺度大核注意力(MLKA): 结合了多尺度机制与大核卷积,能够有效捕捉不同尺度的信息,避免了传统方法中常见的“块状伪影”问题。

  • 门控空间注意力单元(GSAU): 通过引入门控机制,优化了空间注意力的计算,去除了不必要的线性层,从而提高了信息聚合的效率和准确性。

  • 灵活的网络结构: 通过堆叠不同数量的MLKA和GSAU模块,构建出多种复杂度的网络,以实现性能与计算量之间的平衡。

三、方法

  1. 网络架构: Multi-scale Attention Network (MAN)由三个主要模块组成:

    • 浅层特征提取模块(SF): 负责初步的特征提取。
    • 深层特征提取模块(DF): 基于多个多尺度注意力块(MAB),进一步提取丰富的特征。
    • 高质量图像重建模块: 将提取的特征用于最终的图像重建。
  2. 多尺度注意力块(MAB):

    • MLKA模块: 结合大核注意力、多个尺度机制和门控聚合,建立不同尺度之间的相关性。
    • GSAU模块: 整合空间注意力和门控机制,简化前馈网络。
  3. MLKA的功能:

    • 大核注意力: 通过分解卷积建立长距离关系。
    • 多尺度机制: 增强固定LKA以学习全尺度信息的注意力图。
    • 门控聚合: 动态调整注意力图以避免伪影。

MAB模块解读

在这里插入图片描述

1、MAB模块概述

MAB(Multi-scale Attention Block)是Multi-scale Attention Network (MAN)中的核心组件,旨在通过结合多尺度大核注意力(MLKA)和门控空间注意力单元(GSAU)来提升图像超分辨率的性能。MAB模块的设计旨在有效捕捉图像中的局部和全局特征,同时避免传统卷积网络中常见的“块状伪影”问题。

2、MAB模块组成

MAB模块主要由以下两个部分构成:

  1. 多尺度大核注意力(MLKA):

    • 功能: MLKA通过引入多尺度机制,结合大核卷积,能够在不同尺度上提取丰富的特征信息。
    • 结构:
      • 首先,MLKA使用点卷积(Point-wise convolution)调整通道数。
      • 然后,将特征分成三组,每组使用不同尺寸的大核卷积(如7×7、21×21、35×35)进行处理,膨胀率分别设置为(2,3,4)。
      • 为了避免膨胀卷积带来的“块状伪影”,MLKA引入了门控聚合机制,通过逐元素乘法将深度卷积的输出与对应组的LKA输出相结合,从而动态调整注意力图的输出。
  2. 门控空间注意力单元(GSAU):

    • 功能: GSAU旨在增强特征表示能力,通过结合空间注意力和门控机制,优化信息聚合过程。
    • 结构:
      • GSAU通常由两个分支组成,其中一个分支使用深度卷积对特征进行加权,另一个分支则通过空间自注意力机制捕捉空间上下文信息。
      • 这种设计减少了不必要的线性层,降低了计算复杂度,同时增强了特征的表达能力。

3、MAB模块的优势

  • 多尺度特征提取: 通过MLKA,MAB能够在多个尺度上提取特征,增强了模型对不同图像细节的敏感性。

  • 减少伪影: 通过门控聚合机制,MAB有效地减少了由于膨胀卷积引起的块状伪影,提升了图像重建的质量。

  • 高效的特征表示: GSAU的引入使得模型能够更好地聚合空间信息,提升了特征的表达能力,进而提高了超分辨率的效果。

四、效果

  • 性能提升: 实验结果表明,MAN在多个超分辨率基准测试中表现优异,能够与当前最先进的模型(如SwinIR)相媲美,同时在计算效率上也有显著改善。

  • 避免伪影: 通过MLKA和GSAU的结合,模型有效减少了图像重建中的伪影现象,提升了视觉效果。

五、实验结果

  • 基准测试: 论文中使用了多个数据集(如Set5、Set14、BSD100、Urban100、Manga109)进行测试,结果显示MAN在PSNR和SSIM指标上均优于传统的超分辨率模型,尤其是在高倍数放大(如×4)时表现突出。
数据集上采样因子MAN PSNR (dB)与SwinIR比较
Set5×238.42相当
×334.91略低
×432.87良好
Set14×234.44相近
×330.92略低
×429.09良好
BSD100×232.53相近
×329.65略低
×427.90良好
Urban100×233.80相近
×334.45略低
×433.73良好
Manga109×240.02相近
×335.21略低
×431.22良好

六、总结

Multi-scale Attention Network (MAN)通过结合多尺度大核注意力和门控空间注意力机制,成功提升了单幅图像超分辨率重建的性能和效率。该研究不仅解决了传统方法中的一些局限性,还为未来的超分辨率模型设计提供了新的思路。MAN在多个基准测试中的优异表现,证明了其在实际应用中的潜力,尤其是在需要高质量图像重建的场景中。

代码

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as Fclass LayerNorm(nn.Module):r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.The ordering of the dimensions in the inputs. channels_last corresponds to inputs withshape (batch_size, height, width, channels) while channels_first corresponds to inputswith shape (batch_size, channels, height, width)."""def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):super().__init__()self.weight = nn.Parameter(torch.ones(normalized_shape))self.bias = nn.Parameter(torch.zeros(normalized_shape))self.eps = epsself.data_format = data_formatif self.data_format not in ["channels_last", "channels_first"]:raise NotImplementedErrorself.normalized_shape = (normalized_shape,)def forward(self, x):if self.data_format == "channels_last":return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)elif self.data_format == "channels_first":u = x.mean(1, keepdim=True)s = (x - u).pow(2).mean(1, keepdim=True)x = (x - u) / torch.sqrt(s + self.eps)x = self.weight[:, None, None] * x + self.bias[:, None, None]return xclass SGAB(nn.Module):def __init__(self, n_feats, drop=0.0, k=2, squeeze_factor=15, attn='GLKA'):super().__init__()i_feats = n_feats * 2self.Conv1 = nn.Conv2d(n_feats, i_feats, 1, 1, 0)self.DWConv1 = nn.Conv2d(n_feats, n_feats, 7, 1, 7 // 2, groups=n_feats)self.Conv2 = nn.Conv2d(n_feats, n_feats, 1, 1, 0)self.norm = LayerNorm(n_feats, data_format='channels_first')self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)def forward(self, x):shortcut = x.clone()# Ghost Expandx = self.Conv1(self.norm(x))a, x = torch.chunk(x, 2, dim=1)x = x * self.DWConv1(a)x = self.Conv2(x)return x * self.scale + shortcutclass GroupGLKA(nn.Module):def __init__(self, n_feats, k=2, squeeze_factor=15):super().__init__()i_feats = 2 * n_featsself.n_feats = n_featsself.i_feats = i_featsself.norm = LayerNorm(n_feats, data_format='channels_first')self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)# Multiscale Large Kernel Attentionself.LKA7 = nn.Sequential(nn.Conv2d(n_feats // 3, n_feats // 3, 7, 1, 7 // 2, groups=n_feats // 3),nn.Conv2d(n_feats // 3, n_feats // 3, 9, stride=1, padding=(9 // 2) * 4, groups=n_feats // 3, dilation=4),nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0))self.LKA5 = nn.Sequential(nn.Conv2d(n_feats // 3, n_feats // 3, 5, 1, 5 // 2, groups=n_feats // 3),nn.Conv2d(n_feats // 3, n_feats // 3, 7, stride=1, padding=(7 // 2) * 3, groups=n_feats // 3, dilation=3),nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0))self.LKA3 = nn.Sequential(nn.Conv2d(n_feats // 3, n_feats // 3, 3, 1, 1, groups=n_feats // 3),nn.Conv2d(n_feats // 3, n_feats // 3, 5, stride=1, padding=(5 // 2) * 2, groups=n_feats // 3, dilation=2),nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0))self.X3 = nn.Conv2d(n_feats // 3, n_feats // 3, 3, 1, 1, groups=n_feats // 3)self.X5 = nn.Conv2d(n_feats // 3, n_feats // 3, 5, 1, 5 // 2, groups=n_feats // 3)self.X7 = nn.Conv2d(n_feats // 3, n_feats // 3, 7, 1, 7 // 2, groups=n_feats // 3)self.proj_first = nn.Sequential(nn.Conv2d(n_feats, i_feats, 1, 1, 0))self.proj_last = nn.Sequential(nn.Conv2d(n_feats, n_feats, 1, 1, 0))def forward(self, x, pre_attn=None, RAA=None):shortcut = x.clone()x = self.norm(x)x = self.proj_first(x)a, x = torch.chunk(x, 2, dim=1)a_1, a_2, a_3 = torch.chunk(a, 3, dim=1)a = torch.cat([self.LKA3(a_1) * self.X3(a_1), self.LKA5(a_2) * self.X5(a_2), self.LKA7(a_3) * self.X7(a_3)],dim=1)x = self.proj_last(x * a) * self.scale + shortcutreturn x# MABclass MAB(nn.Module):def __init__(self, n_feats):super().__init__()self.LKA = GroupGLKA(n_feats)self.LFE = SGAB(n_feats)def forward(self, x, pre_attn=None, RAA=None):# large kernel attentionx = self.LKA(x)# local feature extractionx = self.LFE(x)return xif __name__ == "__main__":dim=96 # 通道要被3整除# 如果GPU可用,将模块移动到 GPUdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 输入张量 (batch_size, channels,height, width)x = torch.randn(2,dim,40,40).to(device)# 初始化 MAB模块block = MAB(dim)print(block)block = block.to(device)# 前向传播output = block(x)print("输入:", x.shape)print("输出:", output.shape)

http://www.ppmy.cn/server/164538.html

相关文章

项目开发实践——基于SpringBoot+Vue3实现的在线考试系统(九)(完结篇)

文章目录 一、成绩查询模块实现1、学生成绩查询功能实现1.1 页面设计1.2 前端页面实现1.3 后端功能实现2、成绩分段查询功能实现2.1 页面设计2.2 前端页面实现2.3 后端功能实现二、试卷练习模块实现三、我的分数模块实现1、 页面设计2、 前端页面实现3、 后端功能实现四、交流区…

飞桨PaddleNLP套件中使用DeepSeek r1大模型

安装飞桨PaddleNLP 首先安装最新的PaddleNLP3.0版本: pip install paddlenlp3.0.0b3 依赖库比较多,可能需要较长时间安装。 安装好后,看看版本: import paddlenlp paddlenlp.__version__ 输出: 3.0.0b3.post2025…

AI大模型开发原理篇-4:神经概率语言模型NPLM

神经概率语言模型(NPLM)概述 神经概率语言模型(Neural Probabilistic Language Model, NPLM) 是一种基于神经网络的语言建模方法,它将传统的语言模型和神经网络结合在一起,能够更好地捕捉语言中的复杂规律…

Spring Boot项目如何使用MyBatis实现分页查询

写在前面:大家好!我是晴空๓。如果博客中有不足或者的错误的地方欢迎在评论区或者私信我指正,感谢大家的不吝赐教。我的唯一博客更新地址是:https://ac-fun.blog.csdn.net/。非常感谢大家的支持。一起加油,冲鸭&#x…

第六篇:事务与并发控制

第六篇:事务与并发控制 目标读者: 本篇文章适合中级数据库学习者,特别是那些希望理解数据库事务管理与并发控制机制的开发者或数据库管理员。通过掌握事务的原理与控制方法,你将能够设计高效且可靠的数据库应用,确保…

Vue 3 中的响应式系统:ref 与 reactive 的对比与应用

Vue 3 的响应式系统是其核心特性之一,它允许开发者以声明式的方式构建用户界面。Vue 3 引入了两种主要的响应式 API:ref 和 reactive。本文将详细介绍这两种 API 的用法、区别以及在修改对象属性和修改整个对象时的不同表现,并提供完整的代码…

mysql重学(一)mysql语句执行流程

思考 一条查询语句如何执行?mysql语句中若列不存在,则在哪个阶段报错一条更新语句如何执行?redolog和binlog的区别?为什么要引入WAL什么是Changbuf?如何工作写缓冲一定好吗?什么情况会引发刷脏页删除语句会…

【Linux指令/信号总结】粘滞位 重定向 系统调用 信号产生 信号处理

文章目录 1.>2. cat3.系统命令bash和shell和kernel权限只被认证一次粘滞位引入前提知识场景解释为什么普通用户(无w权限)可以删除文件?为什么普通用户通过sudo设置文件权限为000后仍能删除文件? 结论 粘滞位是干什么的&#xf…