YOLOV5改进系列(2)——CA注意力机制

embedded/2024/12/23 3:03:13/

一、CA注意力机制

1.1 CA注意力机制介绍

Coordinate Attention(CA)是一种用于移动网络的轻量级注意力机制,旨在在不增加计算成本的前提下增强特征表达能力。与传统的通道注意力机制(如Squeeze-and-Excitation)不同,CA在通道注意力中嵌入了位置信息,从而能更好地在空间上进行关注。

传统通道注意力的不足

常见的注意力机制如SENet(Squeeze-and-Excitation Network)通过2D全局池化提取全局信息,仅关注通道间关系,忽视了重要的位置信息。BAM和CBAM等机制在尝试结合位置信息,但多采用卷积操作,无法有效建模长距离依赖。

CA注意力的创新

CA通过将2D通道注意力分解为两个1D编码过程,分别沿垂直和水平方向聚合特征。这样一来,模型不仅能捕获长距离的空间依赖关系,还能保留细粒度的位置信息,帮助模型在需要定位物体结构的任务中表现更优。

CA的结构和工作流程

CA模块主要由两部分组成:

  1. 坐标信息嵌入:通过1D池化操作分别在水平方向和垂直方向上聚合特征,生成两个包含位置信息的方向感知特征图。
  2. 生成坐标注意力:将这两个方向特征图结合,生成两个方向上的注意力权重,分别作用于输入特征图的水平方向和垂直方向,最终得到增强的输出特征。

1.2 CA注意力机制和CBAM和SE对比

这张图展示了CA(Coordinate Attention)注意力机制与传统的SE(Squeeze-and-Excitation)通道注意力机制和CBAM(Convolutional Block Attention Module)机制的结构对比:

  1. SE注意力机制 (图a)

    • SE模块通过全局平均池化(Global Average Pool)将输入特征图(C×H×W)压缩为通道向量(C×1×1),丢失了空间信息。
    • 该通道向量经过两个全连接层和非线性激活后生成权重,然后通过sigmoid函数调整权重值。
    • 最终,该权重用于对输入特征图进行通道维度的加权(re-weight),实现通道间注意力的增强。
  2. CBAM注意力机制 (图b)

    • CBAM在通道注意力的基础上加入了空间注意力。
    • 首先,CBAM使用全局平均池化和全局最大池化(GAP + GMP)生成通道注意力权重。
    • 然后通过一个7×7卷积提取空间信息,得到空间注意力图。
    • 最后,通道注意力和空间注意力共同作用于输入特征图,实现更精确的特征加权。
  3. CA(Coordinate Attention)注意力机制 (图c)

    • CA通过1D池化操作将输入特征图分别在水平方向和垂直方向上进行池化(X Avg Pool和Y Avg Pool),生成两个方向上的特征图(C×1×W和C×H×1),保留了空间位置信息。
    • 然后,这两个特征图被连接(concat)后,通过卷积操作提取特征,进一步生成两个方向上的注意力权重。
    • 最终,CA将这两个权重分别应用在输入特征图的水平方向和垂直方向,实现方向感知和位置敏感的注意力增强。

关键点

这张图突出了CA与SE和CBAM的不同之处:CA保留了空间位置信息,并通过两个1D池化操作沿不同方向获取长距离依赖,而传统的SE和CBAM由于2D全局池化或卷积操作,会导致部分空间信息丢失。

1.3 Coordinate Attention Blocks

坐标注意力机制通过两步操作编码了通道关系和具有精确位置信息的长程依赖性:坐标信息嵌入和坐标注意力生成。接下来,我们将详细描述这一过程。

3.2.1 坐标信息嵌入

全局池化通常用于通道注意力机制中以全局编码空间信息,但这种操作将全局的空间信息压缩成一个通道描述符,因此难以保留位置信息。而在视觉任务中,位置信息对捕捉空间结构至关重要。为了促使注意力模块在空间上捕获具有精确位置信息的长程交互,我们将全局池化分解为一对一维(1D)特征编码操作,如公式(1)所示。具体来说,给定输入 X,我们分别使用两个不同空间范围的池化核 (H, 1) 和 (1, W),来分别沿水平和垂直坐标方向编码每个通道。这样,对于第c个通道在高度h处的输出可以表示为:

同样地,第c个通道在宽度w处的输出可以表示为:

上述两种变换分别在两个空间方向上聚合特征,生成一对具有方向感知的特征图。这与通道注意力方法中的压缩操作不同,后者只生成一个特征向量。这两种变换使我们的注意力模块能够在一个空间方向上捕捉长程依赖,并在另一个空间方向上保留精确的位置信息,这有助于网络更准确地定位感兴趣的对象。

通俗的讲,你可以把坐标注意力想象成一个“导航系统”。普通的注意力机制就像你看着一幅地图,但不太清楚具体的路在哪里。而坐标注意力不仅能让你看到地图,还能准确告诉你要去的目的地在地图的哪个横坐标(东西方向)和纵坐标(南北方向)。这样一来,模型就能更准确地找到图像中的“目标”,比如对象的位置或重要的细节。

1. 为何要用坐标信息?

在之前的一些方法中,模型往往只能处理整体的图像信息,而没有清楚地知道图像中某些物体的准确位置。这就像我们在看地图时,只有地图上所有地方的概览,但不知道特定地点在哪。坐标注意力帮助模型不仅知道图片中的内容,还能明确它们的位置。

2. 具体如何操作?

首先,我们把图像中的信息分成两部分来看待——横向(左右方向)纵向(上下方向)。可以想象这就像在读表格,既可以从行来看每一排的内容,也可以从列来看每一列的内容。然后,为图像中的每个像素(可以理解为一个小点)计算它在行和列上的信息。这样做的好处是,模型可以同时关注图片中物体的横向和纵向位置信息。

3. 如何生成注意力?

将横向和纵向的信息组合起来,然后用一个简单的数学公式处理。这一步类似于把所有的细节信息放进一个“计算器”进行处理,生成一个新的“注意力权重”。这些“权重”告诉模型哪些位置更重要,哪些位置不需要那么多关注。最后,模型根据这些权重来调整它对图像的处理方式,突出重点区域。

4. 为什么有效?

这样处理之后,模型不再是简单地处理整个图像,而是可以“聪明地”找出图片中重要的位置,并特别关注这些区域。比如在一张人脸的照片里,模型可以自动关注到眼睛、鼻子、嘴巴这些重要的部分,从而更好地识别图像。

二、PyTorch实现

1.在yolov5/models/common.py最下面添加代码

class h_sigmoid(nn.Module):def __init__(self, inplace=True):super(h_sigmoid, self).__init__()self.relu = nn.ReLU6(inplace=inplace)def forward(self, x):return self.relu(x + 3) / 6class h_swish(nn.Module):def __init__(self, inplace=True):super(h_swish, self).__init__()self.sigmoid = h_sigmoid(inplace=inplace)def forward(self, x):return x * self.sigmoid(x)class CoordinateAttention(nn.Module):def __init__(self, inc, outc, reduction=32):super(CoordinateAttention, self).__init__()self.pool_h = nn.AdaptiveAvgPool2d((None,1)) #横向池化self.pool_w = nn.AdaptiveMaxPool2d((1,None)) #纵向池化#Reduction层 减小特征图的通道数reduced_channel = max(8,inc//reduction)self.conv1 = nn.Conv2d(inc,reduced_channel,kernel_size=1,stride=1,padding=0)self.bn1 = nn.BatchNorm2d(reduced_channel)self.act = h_swish()#分别在水平和垂直方向卷积self.conv_h = nn.Conv2d(reduced_channel,outc,kernel_size=1,stride=1,padding=0)self.conv_w = nn.Conv2d(reduced_channel, outc, kernel_size=1, stride=1, padding=0)self.sigmoid = nn.Sigmoid()def forward(self, x):identity = x #保留输入特征图作为最后的输出n,c,h,w = x.size()x_h = self.pool_h(x) # N*C*H*1x_w = self.pool_w(x).permute(0,1,3,2) # N*C*1*W -> N*C*W*1# 拼接后进行卷积和激活y = torch.cat([x_h,x_w],dim=2)  # N*C*(H+W)*1y = self.conv1(y)y = self.bn1(y)y = self.act(y)# 分别提取横向和纵向特征x_h,x_w = torch.split(y,[h,w],dim=2)x_w = x_w.permute(0,1,3,2)# 生成横向和纵向注意力权重a_h = self.conv_h(x_h).sigmoid()a_w = self.conv_w(x_w).sigmoid()# 输出特征图与注意力权重相乘out = identity*a_w*a_hreturn out

2.在yolov5/models/yolo.py的parse_model添加代码

3.在yolov5/models/新建yolov5s_CA.yaml文件

4.修改yolo.py调用

三、验证是否可以使用


http://www.ppmy.cn/embedded/124952.html

相关文章

前端学习第一天笔记 HTML5 CSS初学以及VSCODE中的常用快捷键

前端学习笔记 VsCode常用快捷键列表HTML5标题标签标签之段落、换行、水平线标签之图片图片路径详解标签之超文本链接标签之文本列表标签之有序列表列表标签之无序列表标签之表格表格之合并单元格Form表单表单元素文本框 密码框 块元素与行内元素(内联元素&#xff0…

Spring Boot 集成 Flowable UI 实现请假流程 Demo

​ 博客主页: 南来_北往 系列专栏:Spring Boot实战 在现代企业应用中,工作流管理是一个至关重要的部分。通过使用Spring Boot和Flowable,可以方便地构建和管理工作流。本文将详细介绍如何在Spring Boot项目中集成Flowable UI&#xff0c…

正态分布的极大似然估计一个示例,详细展开的方程求解步骤

此示例是 什么是极大似然估计 中的一个例子,本文的目的是给出更加详细的方程求解步骤,便于数学基础不好的同学理解。 目标 假设我们有一组样本数据 x 1 , x 2 , … , x n x_1, x_2, \dots, x_n x1​,x2​,…,xn​,它们来自一个正态分布 N…

前端的全栈混合之路Meteor篇:分布式数据协议DDP深度剖析

本文属于进阶篇,并不是太适合新人阅读,但纯粹的学习还是可以的,因为后续会实现很多个ddp的版本用于web端、nodejs端、安卓端和ios端,提前预习和复习下。ddp协议是一个C/S架构的协议,但是客户端也同时可以是服务端。 什…

【stm32】寄存器(stm32技术手册下载链接)

1、资料下载 RM0008_STM32F101xx,STM32F102xx,STM32F103xx,STM32F105xx和STM32F107xx单片机参考手册 | STMCU中文官网 2、代码 设置PB7 //设置PB7 #define SDA_IN() {GPIOB->CRL&0X0FFFFFFF;GPIOB->CRL|(u32)8<<28;} #define SDA_OUT() {GPIOB->…

C语言期中自测试卷

选择题 1、若有变量定义int a; double b; 要输入数据存放在a和b中&#xff0c;则下面正确的输入数据的语句为&#xff1a; 【 正确答案: C】 A. scanf("%d%f",a,b); B. scanf("%d%f",&a,&b); C. scanf("%d%lf",&a,&b); D. scan…

四、Python基础语法(数据类型转换)

数据类型转换就是将一种类型的数据转换为另外一种类型的数据&#xff0c;数据类型转换不会改变原数据&#xff0c;是产生一个新的数据。 变量 要转换为的类型(原数据) -> num int(28) 一.int()将其他类型转换为整型 1.整数类型的字符串转换为整型 num1 28 print(type…

spring揭秘25-springmvc05-过滤器与拦截器区别(补充)

文章目录 【README】【1】springmvc拦截器回顾【1.1】定义与应用【1.2】拦截器作用范围 【2】servlet过滤器回顾【2.1】过滤器定义与应用【2.2】过滤器作用范围 【3】springmvc拦截器与servlet过滤器区别&#xff08;重要*&#xff09;【3.1】拦截方法调用代码实现 【README】 …