CVPR | CNN融合注意力机制,芜湖起飞!

devtools/2025/2/10 8:36:50/

**标题:**On the Integration of Self-Attention and Convolution
**论文链接:**https://arxiv.org/pdf/2111.14556
**代码链接:**https://github.com/LeapLabTHU/ACmix

创新点

1. 揭示卷积和自注意力的内在联系

文章通过重新分解卷积和自注意力模块的操作,发现它们在第一阶段(特征投影)都依赖于 1×1 卷积操作,并且这一阶段占据了大部分的计算复杂度(与通道数的平方成正比)。这一发现为整合两种模块提供了理论基础。

2. 提出 ACmix 模型

基于上述发现,作者提出了 ACmix 模型,它通过共享 1×1 卷积操作来同时实现卷积和自注意力的功能。具体来说:
**第一阶段:**输入特征通过 1×1 卷积投影,生成中间特征。
**第二阶段:**这些中间特征分别用于卷积路径(通过移位和聚合操作)和自注意力路径(计算注意力权重并聚合值)。最终,两条路径的输出通过可学习的权重加权求和,得到最终输出。

3. 改进的移位和聚合操作

文章还提出了一种改进的移位操作,通过使用 固定卷积核的分组卷积 来替代传统的张量移位操作。这种方法不仅提高了计算效率,还允许卷积核的可学习性,进一步增强了模型的灵活性。

4. 适应性路径权重

ACmix 引入了两个可学习的标量参数(α 和 β),用于动态调整卷积路径和自注意力路径的权重。这种设计不仅提高了模型的灵活性,还允许模型在不同深度上自适应地选择更适合的特征提取方式。实验表明,这种设计在模型的不同阶段表现出不同的偏好,例如在早期阶段更倾向于卷积,在后期阶段更倾向于自注意力。

整体结构

第一阶段:特征投影

在第一阶段,输入特征通过三个1×1卷积进行投影,分别生成查询(query)、键(key)和值(value)特征映射。这些特征映射随后被重塑为N块,形成一个包含3×N特征映射的中间特征集。

第二阶段:特征聚合

在第二阶段,中间特征集被分为两个路径进行处理:

  • **自注意力路径:**将中间特征集分为N组,每组包含三个特征映射(分别对应查询、键和值)。这些特征映射按照传统的多头自注意力机制进行处理,计算注意力权重并聚合值。
  • **卷积路径:**通过轻量级的全连接层生成k²个特征映射(k为卷积核大小)。这些特征映射通过移位和聚合操作,以类似传统卷积的方式处理输入特征,从局部感受野收集信息。

输出整合

最后,自注意力路径和卷积路径的输出通过两个可学习的标量参数(α和β)加权求和,得到最终的输出。

改进的移位和聚合操作

为了提高计算效率,ACmix模型采用了改进的移位操作,通过固定卷积核的分组卷积来替代传统的张量移位操作。这种方法不仅提高了计算效率,还允许卷积核的可学习性,进一步增强了模型的灵活性。

模型的灵活性和泛化能力

ACmix模型不仅适用于标准的自注意力机制,还可以与各种变体(如Patchwise Attention、Window Attention和Global Attention)结合使用。这种设计使得ACmix能够适应不同的任务需求,具有广泛的适用性。

消融实验

1. 结合两个路径的输出

消融实验探索了卷积和自注意力输出的不同组合方式对模型性能的影响。实验结果表明:

  • **卷积和自注意力的组合优于单一路径:**使用卷积和自注意力模块的组合始终优于仅使用单一路径(如仅卷积或仅自注意力)的模型。
  • **可学习参数的灵活性:**通过引入可学习的参数(如α和β)来动态调整卷积和自注意力路径的权重,ACmix能够根据网络中不同位置的需求自适应地调整路径强度,从而获得更高的灵活性和性能。

2. 组卷积核的选择

实验还对组卷积核的设计进行了验证,结果表明:

  • **用组卷积替代张量位移:**通过使用组卷积替代传统的张量位移操作,显著提高了模型的推理速度。
  • **可学习卷积核和初始化:**使用可学习的卷积核并结合精心设计的初始化方法,进一步增强了模型的灵活性,并有助于提升最终性能。

3. 不同路径的偏好

ACmix模型引入了两个可学习标量α和β,用于动态调整卷积和自注意力路径的权重。通过平行实验,观察到以下趋势:

  • **早期阶段偏好卷积:**在Transformer模型的早期阶段,卷积作为特征提取器表现更好。
  • **中间阶段混合使用:**在网络的中间阶段,模型倾向于混合使用两种路径,并逐渐增加对卷积的偏好。
  • **后期阶段偏好自注意力:**在网络的最后阶段,自注意力表现优于卷积。

4. 对模型性能的影响

这些消融实验结果表明,ACmix模型通过合理结合卷积和自注意力的优势,并优化计算路径,不仅在多个视觉任务上取得了显著的性能提升,还保持了较高的计算效率

ACmix模块的作用

1. 融合卷积和自注意力的优势

ACmix模块通过结合卷积的局部特征提取能力和自注意力的全局感知能力,实现了一种高效的特征融合策略。这种设计使得模型能够同时利用卷积的局部感受野特性和自注意力的灵活性。

2. 优化计算路径

ACmix通过优化计算路径和减少重复计算,提高了整体模块的计算效率。具体来说,它通过1×1卷积对输入特征图进行投影,生成中间特征,然后根据不同的范式(卷积和自注意力)分别重用和聚合这些中间特征。这种设计不仅减少了计算开销,还提升了模型性能。

3. 改进的位移与求和操作

在卷积路径中,ACmix采用深度可分离卷积(depthwise convolution)来替代低效的张量位移操作,从而提高了实际推理效率。

4. 动态调整路径权重

ACmix引入了两个可学习的标量参数(α和β),用于动态调整卷积和自注意力路径的权重。这种设计使得模型能够根据网络中不同位置的需求自适应地调整路径强度,从而获得更高的灵活性。

5. 广泛的应用潜力

ACmix在多个视觉任务(如图像分类、语义分割和目标检测)上均显示出优于单一机制(仅卷积或仅自注意力)的性能,展示了其广泛的应用潜力。

6. 实验验证

实验结果表明,ACmix在保持较低计算开销的同时,能够显著提升模型的性能。例如,在ImageNet分类任务中,ACmix模型在相同的FLOPs或参数数量下表现出色,并且在与竞争对手的基准比较中取得了持续的改进。此外,ACmix在ADE20K语义分割任务和COCO目标检测任务中也显示出明显的改进

代码实现

import torch
import torch.nn as nndef position(H, W, is_cuda=True):if is_cuda:loc_w = torch.linspace(-1.0, 1.0, W).cuda().unsqueeze(0).repeat(H, 1)loc_h = torch.linspace(-1.0, 1.0, H).cuda().unsqueeze(1).repeat(1, W)else:loc_w = torch.linspace(-1.0, 1.0, W).unsqueeze(0).repeat(H, 1)loc_h = torch.linspace(-1.0, 1.0, H).unsqueeze(1).repeat(1, W)loc = torch.cat([loc_w.unsqueeze(0), loc_h.unsqueeze(0)], 0).unsqueeze(0)return locdef stride(x, stride):b, c, h, w = x.shapereturn x[:, :, ::stride, ::stride]def init_rate_half(tensor):if tensor is not None:tensor.data.fill_(0.5)def init_rate_0(tensor):if tensor is not None:tensor.data.fill_(0.)class ACmix(nn.Module):def __init__(self, in_planes, out_planes, kernel_att=7, head=4, kernel_conv=3, stride=1, dilation=1):super(ACmix, self).__init__()self.in_planes = in_planesself.out_planes = out_planesself.head = headself.kernel_att = kernel_attself.kernel_conv = kernel_convself.stride = strideself.dilation = dilationself.rate1 = torch.nn.Parameter(torch.Tensor(1))self.rate2 = torch.nn.Parameter(torch.Tensor(1))self.head_dim = self.out_planes // self.headself.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1)self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1)self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=1)self.conv_p = nn.Conv2d(2, self.head_dim, kernel_size=1)self.padding_att = (self.dilation * (self.kernel_att - 1) + 1) // 2self.pad_att = torch.nn.ReflectionPad2d(self.padding_att)self.unfold = nn.Unfold(kernel_size=self.kernel_att, padding=0, stride=self.stride)self.softmax = torch.nn.Softmax(dim=1)self.fc = nn.Conv2d(3 * self.head, self.kernel_conv * self.kernel_conv, kernel_size=1, bias=False)self.dep_conv = nn.Conv2d(self.kernel_conv * self.kernel_conv * self.head_dim, out_planes,kernel_size=self.kernel_conv, bias=True, groups=self.head_dim, padding=1,stride=stride)self.reset_parameters()def reset_parameters(self):init_rate_half(self.rate1)init_rate_half(self.rate2)kernel = torch.zeros(self.kernel_conv * self.kernel_conv, self.kernel_conv, self.kernel_conv)for i in range(self.kernel_conv * self.kernel_conv):kernel[i, i // self.kernel_conv, i % self.kernel_conv] = 1.kernel = kernel.squeeze(0).repeat(self.out_planes, 1, 1, 1)self.dep_conv.weight = nn.Parameter(data=kernel, requires_grad=True)self.dep_conv.bias = init_rate_0(self.dep_conv.bias)def forward(self, x):q, k, v = self.conv1(x), self.conv2(x), self.conv3(x)scaling = float(self.head_dim) ** -0.5b, c, h, w = q.shapeh_out, w_out = h // self.stride, w // self.stride# ### att# ## positional encodingpe = self.conv_p(position(h, w, x.is_cuda))q_att = q.view(b * self.head, self.head_dim, h, w) * scalingk_att = k.view(b * self.head, self.head_dim, h, w)v_att = v.view(b * self.head, self.head_dim, h, w)if self.stride > 1:q_att = stride(q_att, self.stride)q_pe = stride(pe, self.stride)else:q_pe = peunfold_k = self.unfold(self.pad_att(k_att)).view(b * self.head, self.head_dim,self.kernel_att * self.kernel_att, h_out,w_out) # b*head, head_dim, k_att^2, h_out, w_outunfold_rpe = self.unfold(self.pad_att(pe)).view(1, self.head_dim, self.kernel_att * self.kernel_att, h_out,w_out) # 1, head_dim, k_att^2, h_out, w_outatt = (q_att.unsqueeze(2) * (unfold_k + q_pe.unsqueeze(2) - unfold_rpe)).sum(1) # (b*head, head_dim, 1, h_out, w_out) * (b*head, head_dim, k_att^2, h_out, w_out) -> (b*head, k_att^2, h_out, w_out)att = self.softmax(att)out_att = self.unfold(self.pad_att(v_att)).view(b * self.head, self.head_dim, self.kernel_att * self.kernel_att,h_out, w_out)out_att = (att.unsqueeze(1) * out_att).sum(2).view(b, self.out_planes, h_out, w_out)## convf_all = self.fc(torch.cat([q.view(b, self.head, self.head_dim, h * w), k.view(b, self.head, self.head_dim, h * w),v.view(b, self.head, self.head_dim, h * w)], 1))f_conv = f_all.permute(0, 2, 1, 3).reshape(x.shape[0], -1, x.shape[-2], x.shape[-1])out_conv = self.dep_conv(f_conv)return self.rate1 * out_att + self.rate2 * out_conv#输入 B C H W, 输出 B C H W
if __name__ == '__main__':block = ACmix(in_planes=64, out_planes=64)input = torch.rand(3, 64, 32, 32)output = block(input)print(input.size(), output.size())

http://www.ppmy.cn/devtools/157259.html

相关文章

SpringBoot3与MyBatis-Plus

4.1 介绍 MyBatis-Plus(简称 MP)是一个基于 MyBatis 的增强工具,提供通用 CRUD 操作、代码生成器、条件构造器、分页插件等功能,简化开发流程,提升效率。 4.2 特点 无侵入:只做增强不做修改,与…

树和二叉树_6

树和二叉树_6 一、leetcode-105二、题解1.引库2.代码 一、leetcode-105 从前序与中序遍历序列构造二叉树 给定两个整数数组 preorder 和 inorder ,其中 preorder 是二叉树的先序遍历, inorder 是同一棵树的中序遍历,请构造二叉树并返回其根节…

Postman接口测试:全局变量/接口关联/加密/解密

全局变量和环境变量 全局变量:在postman全局生效的变量,全局唯一 环境变量:在特定环境下生效的变量,本环境内唯一 设置: 全局变量: pm.globals.set("variable_key", "variable_value1&q…

maven详细讲解

学习目标 那什么是mavenmaven概念以及核心思想maven构建的生命周期、阶段以及目标maven仓库有哪些?maven依赖 那什么是maven?maven概念以及核心思想,maven构建的生命周期、阶段以及目标? 那什么是maven Maven是一个项目管理和构建…

gitlab个别服务无法启动可能原因

目录 一、gitlab的puma服务一直重启 1. 查看日志 2. 检查配置文件 3. 重新配置和重启 GitLab 4. 检查系统资源 5. 检查依赖和服务状态 6. 清理和优化 7. 升级 GitLab 8. 查看社区和文档 二、 gitlab个别服务无法启动可能原因 1.服务器内存或磁盘已满 2.puma端口冲突…

http cookie的作用学习

1.介绍 HTTP Cookie 是 服务器发送给客户端(浏览器)的一小段数据,它会被客户端存储,并在后续请求时自动携带,以便服务器识别用户、保持会话状态或存储用户偏好等信息。 流程: 服务器发送 Cookie 服务器…

基于SpringBoot+vue高效旅游管理系统

Spring Boot后端与Vue前端融合:构建高效旅游管理系统 目录 一、项目简介 二、开发技术与环境配置 2.1 SpringBoot框架 2.2 Java语言简介 2.3 Vue的介绍 2.4 mysql数据库介绍 2.5 B/S架构 三、系统功能实现 四、系统项目截图 登录页面 后台管理页面 用户…

SpringBoot3 + Jedis5 + Redis集群 如何通过scan方法分页获取所有keys

背景: 由于需要升级老项目代码,从SpringBoot1.5.x 升级到 SpringBoot3.3.x,框架中引用的Jedis自动升级到了 5.x;正好代码中有需要获取Redis集群的所有keys的需求存在;代码就不适用了,修改如下: POM 由于…