YOLO11改进|注意力机制篇|引入全局上下文注意力机制GCA

devtools/2024/10/11 6:33:26/

在这里插入图片描述

目录

    • 一、【】注意力机制
      • 1.1【GCA】注意力介绍
      • 1.2【GCA】核心代码
    • 二、添加【GCA】注意力机制
      • 2.1STEP1
      • 2.2STEP2
      • 2.3STEP3
      • 2.4STEP4
    • 三、yaml文件与运行
      • 3.1yaml文件
      • 3.2运行成功截图

一、【】注意力机制

1.1【GCA】注意力介绍

在这里插入图片描述

下图是【GCA】的结构图,让我们简单分析一下运行过程和优势

处理过程

  • 输入特征图
  • 输入的特征图大小为 𝐶×𝐻×𝑊,其中 𝐶 是通道数𝐻和 𝑊分别是特征图的高度和宽度。
  • Context Modeling(上下文建模):
  • 首先,输入特征图通过一个 1×1 卷积层(𝑊𝑘),这一步将每个像素点的特征压缩并降低维度,卷积输出的大小仍为 𝐶×𝐻×𝑊。
  • 之后,将特征图进行降维操作,应用 Softmax 函数,生成全局的上下文注意力图。这一步的结果是对特征图的全局加权,使得每个位置能够捕捉到整个图像范围内的全局上下文信息。注意力图应用于输入特征图,进行全局加权调整,得到全局上下文的增强特征图。
  • Transform(特征变换):
  • 第一步卷积操作(1×1 Conv):首先应用另一个 1×1 卷积层(𝑊1 )来重新映射全局上下文的特征,这步操作不会改变特征图的空间维度。
  • LayerNorm 和 ReLU:然后通过 Layer Normalization(层归一化)和 ReLU 激活函数来标准化并引入非线性变换,从而增强特征的表达能力。
  • 第二步卷积操作(1×1 Conv):接着,经过一个第二次的 1×1 卷积层(𝑊2)来进一步处理和压缩特征。
  • 残差连接:
  • 经过全局上下文增强的特征与原始输入特征通过 残差连接(Skip Connection) 进行叠加,这种操作保留了原始特征的同时,融入了全局上下文的增强信息,从而提高模型的鲁棒性和特征表达能力。
  • 输出特征图:
  • 输出特征图仍然是 𝐶×𝐻×𝑊,与输入保持相同的形状,但已经通过全局上下文增强和特征变换的处理。
    优势
  • 全局上下文捕捉:GC 模块的最大优势在于能够通过 Softmax 注意力机制对整个特征图进行全局建模。相比传统的局部卷积操作,该模块可以捕捉到图像全局的上下文信息,有助于识别跨越远距离的依赖关系,尤其是在处理复杂场景时,能够提高目标识别的精度。
  • 计算效率高:
    虽然该模块引入了注意力机制,但它通过 1×1 卷积来进行维度的压缩,降低了计算复杂度。相比标准的多头自注意力机制(如 Transformer),GC 模块在计算复杂度上更加友好,非常适合嵌入在卷积神经网络中。
  • 残差连接增强特征表达:
    残差连接使得原始特征得以保留,确保全局上下文信息的增强不会损失原有的特征表达能力。这种机制可以缓解梯度消失的问题,促进更深层次的网络结构训练。
  • 通用性强:
    由于 GC 模块使用的是标准的卷积和简单的注意力机制,它可以方便地嵌入到各种神经网络结构中,例如 ResNet、VGG 等,用于提升模型的全局信息感知能力。
  • 适应性好:
    该模块通过 Softmax 注意力机制学习全局上下文的加权方式,能够动态适应不同输入特征,从而使模型在不同场景下有更好的泛化能力和适应性。
    在这里插入图片描述

1.2【GCA】核心代码

import torch  # 导入 PyTorch
from torch import nn  # 从 PyTorch 导入神经网络模块# 定义 ContextBlock 类,继承自 nn.Module
class ContextBlock(nn.Module):def __init__(self, inplanes, ratio, pooling_type='att', fusion_types=('channel_add',)):super(ContextBlock, self).__init__()# 验证 fusion_types 是否有效valid_fusion_types = ['channel_add', 'channel_mul']# 检查 pooling_type 是否在有效选项 ['avg', 'att'] 中assert pooling_type in ['avg', 'att']# 确认 fusion_types 是列表或元组assert isinstance(fusion_types, (list, tuple))# 确认 fusion_types 中的所有元素都在 valid_fusion_types 中assert all([f in valid_fusion_types for f in fusion_types])# 确保至少有一个 fusion 类型被指定assert len(fusion_types) > 0, 'at least one fusion should be used'self.inplanes = inplanes  # 输入通道数self.ratio = ratio  # 缩减比例self.planes = int(inplanes * ratio)  # 缩减后的通道数self.pooling_type = pooling_type  # 池化类型('avg' 或 'att')self.fusion_types = fusion_types  # 融合类型# 如果池化类型为 'att',定义注意力池化的卷积层和 softmaxif pooling_type == 'att':self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)self.softmax = nn.Softmax(dim=2)  # 在维度 2 上做 softmaxelse:# 如果池化类型为 'avg',定义自适应平均池化层self.avg_pool = nn.AdaptiveAvgPool2d(1)# 定义 'channel_add' 融合类型的卷积层序列if 'channel_add' in fusion_types:self.channel_add_conv = nn.Sequential(nn.Conv2d(self.inplanes, self.planes, kernel_size=1),  # 1x1 卷积nn.LayerNorm([self.planes, 1, 1]),  # 层归一化nn.ReLU(inplace=True),  # 激活函数 ReLUnn.Conv2d(self.planes, self.inplanes, kernel_size=1)  # 1x1 卷积,恢复到原通道数)else:self.channel_add_conv = None# 定义 'channel_mul' 融合类型的卷积层序列if 'channel_mul' in fusion_types:self.channel_mul_conv = nn.Sequential(nn.Conv2d(self.inplanes, self.planes, kernel_size=1),  # 1x1 卷积nn.LayerNorm([self.planes, 1, 1]),  # 层归一化nn.ReLU(inplace=True),  # 激活函数 ReLUnn.Conv2d(self.planes, self.inplanes, kernel_size=1)  # 1x1 卷积,恢复到原通道数)else:self.channel_mul_conv = None# 定义空间池化方法def spatial_pool(self, x):batch, channel, height, width = x.size()  # 获取输入张量的形状if self.pooling_type == 'att':  # 如果池化类型为 'att'input_x = x.view(batch, channel, height * width)  # 展平 H 和 Winput_x = input_x.unsqueeze(1)  # 增加一个维度context_mask = self.conv_mask(x)  # 应用 1x1 卷积层context_mask = context_mask.view(batch, 1, height * width)  # 展平 H 和 Wcontext_mask = self.softmax(context_mask)  # 在 H * W 上应用 softmaxcontext_mask = context_mask.unsqueeze(-1)  # 增加一个维度context = torch.matmul(input_x, context_mask)  # 计算加权和context = context.view(batch, channel, 1, 1)  # 恢复形状else:context = self.avg_pool(x)  # 如果池化类型为 'avg',直接应用平均池化return context  # 返回上下文张量# 定义前向传播方法def forward(self, x):context = self.spatial_pool(x)  # 获取上下文信息out = x  # 初始化输出if self.channel_mul_conv is not None:channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))  # 应用通道乘法融合out = out * channel_mul_term  # 输出乘以融合结果if self.channel_add_conv is not None:channel_add_term = self.channel_add_conv(context)  # 应用通道加法融合out = out + channel_add_term  # 输出加上融合结果return out  # 返回最终输出# 测试代码块
if __name__ == "__main__":in_tensor = torch.ones((1, 64, 128, 128))  # 创建一个全1的输入张量,形状为 (1, 64, 128, 128)cb = ContextBlock(inplanes=64, ratio=0.25, pooling_type='att')  # 创建 ContextBlock 实例out_tensor = cb(in_tensor)  # 传递输入张量进行前向传播print(in_tensor.shape)  # 打印输入张量的形状print(out_tensor.shape)  # 打印输出张量的形状

二、添加【GCA】注意力机制

2.1STEP1

首先找到ultralytics/nn文件路径下新建一个Add-module的python文件包【这里注意一定是python文件包,新建后会自动生成_init_.py】,如果已经跟着我的教程建立过一次了可以省略此步骤,随后新建一个GCA.py文件并将上文中提到的注意力机制的代码全部粘贴到此文件中,如下图所示在这里插入图片描述

2.2STEP2

在STEP1中新建的_init_.py文件中导入增加改进模块的代码包如下图所示在这里插入图片描述

2.3STEP3

找到ultralytics/nn文件夹中的task.py文件,在其中按照下图添加在这里插入图片描述

2.4STEP4

定位到ultralytics/nn文件夹中的task.py文件中的def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)函数添加如图代码,【如果不好定位可以直接ctrl+f搜索定位】

在这里插入图片描述

三、yaml文件与运行

3.1yaml文件

以下是添加【GCA】注意力机制在Backbone中的yaml文件,大家可以注释自行调节,效果以自己的数据集结果为准

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'# [depth, width, max_channels]n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPss: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPsm: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPsl: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPsx: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs# YOLO11n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4- [-1, 2, C3k2, [256, False, 0.25]]- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8- [-1, 2, C3k2, [512, False, 0.25]]- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16- [-1, 2, C3k2, [512, True]]- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32- [-1, 2, C3k2, [1024, True]]- [-1, 1, SPPF, [1024, 5]] # 9- [-1, 2, C2PSA, [1024]] # 10# YOLO11n head
head:- [-1, 1, nn.Upsample, [None, 2, "nearest"]]- [[-1, 6], 1, Concat, [1]] # cat backbone P4- [-1, 2, C3k2, [512, False]] # 13- [-1, 1, nn.Upsample, [None, 2, "nearest"]]- [[-1, 4], 1, Concat, [1]] # cat backbone P3- [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)- [-1,1,ContextBlock,[256]]- [-1, 1, Conv, [256, 3, 2]]- [[-1, 13], 1, Concat, [1]] # cat head P4- [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 10], 1, Concat, [1]] # cat head P5- [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)- [[17, 20, 23], 1, Detect, [nc]] # Detect(P3, P4, P5)

以上添加位置仅供参考,具体添加位置以及模块效果以自己的数据集结果为准

3.2运行成功截图

在这里插入图片描述

OK 以上就是添加【GCA】注意力机制的全部过程了,后续将持续更新尽情期待

在这里插入图片描述


http://www.ppmy.cn/devtools/124030.html

相关文章

基于RAMS的台风苏拉(Saola)模拟预报深入分析引言

台风苏拉(Saola)是近年来对我国沿海地区造成严重影响的热带气旋之一。准确模拟和预报苏拉的路径和强度,对于防灾减灾具有重要意义。区域大气模拟系统(RAMS)作为一款功能强大的数值天气预报模型,能够提供精细…

云原生化 - 工具镜像(简约版)

在微服务和云原生环境中,容器化的目标之一是尽可能保持镜像小型化以提高启动速度和减少安全风险。然而,在实际操作中,有时候需要临时引入一些工具来进行调试、监控或问题排查。Kubernetes提供了临时容器(ephemeral containers&…

智能听诊器:宠物健康管理的革命

智能听诊器不仅仅是一个简单的监测工具,它代表了宠物健康管理的一次革命。通过收集和分析宠物的生理数据,智能听诊器能够帮助宠物主人和医生更好地理解宠物的健康需求,从而提供更加个性化的护理方案。 智能听诊器通过高精度的传感器&#xf…

方法重写与多态

方法重写 1.在子类和父类直接 2.方法名相同 3.参数个数和类型相同 4.返回类型相同或是其父类 5.访问权限不能严于父类 package com.hz.ch04.test01;public abstract class Pet {private String name;private int love;private int health;public String getName() {retur…

【iOS原生代码-音频播放】AVAudioPlayer 本地音频设置姊妹篇:如何将多个音频分别指定设置为左、右声道

AVAudioPlayer 本地音频设置姊妹篇:将多个音频分别指定设置为左、右声道 设备/引擎:Mac(11.6)/Mac Mini 开发工具:Xcode(15.0.1) 开发语言:Objective-c/c 开发需求:将…

软件编程课主要是学什么 讲解介绍

软件编程课程主要涵盖以下几个方面的内容: 编程语言: 学习一种或多种编程语言,如‌ Java、‌Python、‌C等。这些语言是编写代码 的基础,每种语言都有其特定的语法和规则, 需要遵循这些规则来创建可执行的程序。‌ …

Vscode+Pycharm+Vue.js+WEUI+django火锅(三)理解Vue

新创建的Vue项目里面很多文件,对于新手,老老实实做一下了解。 1.框架逻辑 框架的逻辑都是相通的,花点时间理一下就清晰了。 2.文件目录及文件 创建好的vue项目下,主要的文件和文件夹要先认识一下,并与框架逻辑对应起…

第五章 RabbitMQ之快速入门Spring AMQP开发

目录 一、AMQP与Sping AMQP 1.1. SpringAMQP常用类 1.2. SpringAMQP常用注解 1.3. 实际电商项目示例代码 二、 案例代码演示 2.1. 案例需求 2.2. 实现代码 2.2.1. 创建SpringBoot工程 2.2.2. 父工程pom依赖 2.2.3. 生产者pom依赖 2.2.3. 生产者配置文件 2.2.4. 生…