如何在YoloV8中添加注意力机制(两种方式)

ops/2024/9/19 0:53:40/ 标签: YOLO, opencv, 目标检测

文章目录

    • 概要
    • 添加注意力机制流程
    • #添加方式一:将注意力机制添加到额外的一层
      • 添加方式二:将注意力机制添加到其中一层,不引入额外的层

概要

提示:这里可以添加技术概要

例如:

openAI 的 GPT 大模型的发展历程。

添加注意力机制流程

#添加方式一:将注意力机制添加到额外的一层

首先找一份注意力机制的代码,比如:ParNetAttention


import numpy as np
import torch
from torch import nn
from torch.nn import initclass ParNetAttention(nn.Module):def __init__(self, channel=512):super().__init__()self.sse = nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Conv2d(channel, channel, kernel_size=1),nn.Sigmoid())self.conv1x1 = nn.Sequential(nn.Conv2d(channel, channel, kernel_size=1),nn.BatchNorm2d(channel))self.conv3x3 = nn.Sequential(nn.Conv2d(channel, channel, kernel_size=3, padding=1),nn.BatchNorm2d(channel))self.silu = nn.SiLU()def forward(self, x):b, c, _, _ = x.size()x1 = self.conv1x1(x)x2 = self.conv3x3(x)x3 = self.sse(x) * xy = self.silu(x1 + x2 + x3)return y

在ultralytics\nn\modules\下新建一份attention.py文件,将注意力机制代码放进去。
注意力机制
打开ultralytics\nn\tasks.py文件首先引入刚才新建的注意力机制代码:

from ultralytics.nn.modules.attention import ParNetAttention

如果注意力机制代码是需要输入通道数的,那么在parse_model方法中加上这行代码:

# 有通道数的注意力机制放在这elif m is (Zoom_cat,SSFF,ParNetAttention):c2 = ch[f]args = [c2, *args]

如果注意力机制代码是不需要输入通道数的,可以不加这个。
最后更改yaml文件,将这个注意力机制加在你想加的地方
比如我加在SPPF层后边:

backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4- [-1, 3, C2f_attention, [128, True]]- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8- [-1, 6, C2f_attention, [256, True]]- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16- [-1, 6, C2f_attention, [512, True]]- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32- [-1, 3, C2f_attention, [1024, True]]- [-1, 1, SPPF, [1024, 5]] # 9- [-1, 1, ParNetAttention,[]] # 10

加完之后这层注意力机制就是第10层,在Head中>=10层的全部+1
比如:

  • [[15, 18, 21], 1, WorldDetect, [nc, 512, True]] # Detect(P3, P4, P5)
  • 就要变成
  • [[16, 19, 22], 1, WorldDetect, [nc, 512, True]] # Detect(P3, P4, P5)

添加方式二:将注意力机制添加到其中一层,不引入额外的层

比如想将注意力机制加到c2f中,打开ultralytics\nn\modules\block.py
首先将注意力机制代码导入到block.py中,复制一份c2f代码:

class C2f(nn.Module):"""Faster Implementation of CSP Bottleneck with 2 convolutions."""def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):"""Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups,expansion."""super().__init__()self.c = int(c2 * e)  # hidden channelsself.cv1 = Conv(c1, 2 * self.c, 1, 1)self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))def forward(self, x):"""Forward pass through C2f layer."""y = list(self.cv1(x).chunk(2, 1))y.extend(m(y[-1]) for m in self.m)return self.cv2(torch.cat(y, 1))def forward_split(self, x):"""Forward pass using split() instead of chunk()."""y = list(self.cv1(x).split((self.c, self.c), 1))y.extend(m(y[-1]) for m in self.m)return self.cv2(torch.cat(y, 1))

重命名为:

class C2f_attention(nn.Module):"""Faster Implementation of CSP Bottleneck with 2 convolutions."""def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):"""Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups,expansion."""super().__init__()self.c = int(c2 * e)  # hidden channelsself.cv1 = Conv(c1, 2 * self.c, 1, 1)self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))self.attention_AdditiveBlock = AdditiveBlock(c2)def forward(self, x):"""Forward pass through C2f layer."""y = list(self.cv1(x).chunk(2, 1))y.extend(m(y[-1]) for m in self.m)return self.attention_AdditiveBlock(self.cv2(torch.cat(y, 1)))def forward_split(self, x):"""Forward pass using split() instead of chunk()."""y = list(self.cv1(x).split((self.c, self.c), 1))y.extend(m(y[-1]) for m in self.m)return self.cv2(torch.cat(y, 1))

然后需要在__init___函数中声明注意力机制函数,比如:self.attention_AdditiveBlock = AdditiveBlock(c2) 这里注意力机制如果需要参数的话就写上一层的输出通道数,作为本层的输入通道数,这里我想将注意力机制添加在cv2层后边,那么我的参数就是cv2等的输出通道数也就是c2,
在forward中对哪一层使用注意力机制就可以放在哪一层,比如

return self.attention_AdditiveBlock(self.cv2(torch.cat(y, 1)))

注意力机制就添加完成了,最后需要把新的C2f_attention注册一下,
步骤如下:
首先在block.py中引入C2f_attention

__all__ = ("DFL",   "HGBlock",    "HGStem",    "SPP","SPPF",    "C1",    "C2",    "C3","C2f",    "C2fAttn",    "ImagePoolingAttn","ContrastiveHead",    "BNContrastiveHead","C3x",    "C3TR",    "C3Ghost","GhostBottleneck",    "Bottleneck","BottleneckCSP",    "Proto",    "RepC3",    "ResNetLayer","RepNCSPELAN4",    "ELAN1",    "ADown",    "AConv","SPPELAN",    "CBFuse",    "CBLinear",    "RepVGGDW","CIB",    "C2fCIB",    "Attention",    "PSA",    "SCDown",# --------------------------------添加注意力机制"C2f_attention",
)

然后再__init__.py中添加

from .block import (C1,    C2,    C3,    C3TR,    CIB,    DFL,ELAN1,    PSA,    SPP,    SPPELAN,    SPPF,AConv,    ADown,    Attention,    BNContrastiveHead,Bottleneck,    BottleneckCSP,    C2f,C2fAttn,    C2fCIB,    C3Ghost,    C3x,CBFuse,    CBLinear,    ContrastiveHead,GhostBottleneck,    HGBlock,HGStem,    ImagePoolingAttn,Proto,    RepC3,    RepNCSPELAN4,RepVGGDW,    ResNetLayer,    SCDown,# ---------------------添加注意力机制-------------C2f_attention,
)

以及__all__ = ("C2f_attention")都要添加。
最后在task.py中引入C2f_attention:
三个位置:
(1)from ultralytics.nn.modules import (C2f_attention)
(2)if m in { Classify, Conv, ConvTranspose, GhostConv,.........C2f_attention
(3)if m in {BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3, C2fCIB,C2f_attention}: args.insert(2, n) # number of repeats n = 1

最后在yaml文件中将C2F层替换为C2f_attention即可。

backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4- [-1, 3, C2f_attention, [128, True]]- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8- [-1, 6, C2f_attention, [256, True]]- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16- [-1, 6, C2f_attention, [512, True]]- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32- [-1, 3, C2f_attention, [1024, True]]- [-1, 1, SPPF, [1024, 5]] # 9

http://www.ppmy.cn/ops/109946.html

相关文章

【FFMPEG】FFplay音视频同步分析(下)

audio_decode_frame函数分析 首先说明一下,audio_decode_frame() 函数跟解码毫无关系,真正的解码函数是 decoder_decode_frame 。 audio_decode_frame() 函数的主要作用是从 FrameQueue 队列里面读取 AVFrame ,然后把 is->audio_buf 指向…

Debian 12 中为 root 用户修改最大打开文件数进程数的限制

在 Debian 12 中,管理和配置打开文件的限制涉及到系统级别和用户级别的设置。以下是详细的步骤来修改和管理“打开文件”限制: 1. 查看当前的限制 首先,了解当前的限制配置: 系统级别: cat /proc/sys/fs/file-max这…

可测试,可维护,可移植:上位机软件分层设计的重要性

互联网中,软件工程师岗位会分前端工程师,后端工程师。这是由于互联网软件规模庞大,从业人员众多。前后端分别根据各自需求发展不一样的技术栈。那么上位机软件呢?它规模小,通常一个人就能开发一个项目。它还有必要分前…

移动订货小程序哪个好 批发订货系统源码哪个好

订货小程序就是依托微信小程序的订货系统,微信小程序订货系统相较于其他终端的订货方式,能够更快进入商城,对经销商而言更为方便。今天,我们一起盘点三个主流的移动订货小程序,看看哪个移动订货小程序好。 第一、核货宝…

unocss 一直热更新打印[vite] hot updated: /__uno.css

报错信息 "unocss 一直热更新打印 [vite] hot updated: /__uno.css" 表示你的项目正在使用 unocss 这个库,并且它正在不断地进行热更新。vite 是一个现代化的前端构建工具,这条信息实际上是 vite 在通知你有关于 __uno.css 文件的热更新发生了…

【2025】基于python的网上商城比价系统、智能商城比价系统、电商比价系统、智能商城比价系统(源码+文档+解答)

博主介绍: ✌我是阿龙,一名专注于Java技术领域的程序员,全网拥有10W粉丝。作为CSDN特邀作者、博客专家、新星计划导师,我在计算机毕业设计开发方面积累了丰富的经验。同时,我也是掘金、华为云、阿里云、InfoQ等平台…

discuz论坛3.4 截图粘贴图片发帖后显示不正常问题

处理方法 source\function 路径下修改function_discuzcode.php function bbcodeurl($url, $tags) 函数 if(!in_array(strtolower(substr($url, 0, 6)), array(http:/, https:, ftp://, rtsp:/, mms://,data:i) 这一句里增加 data:i 即可 function bbcodeurl($url,…

vue2 组件通信

props emits props:用于接收父组件传递给子组件的数据。可以定义期望从父组件接收的数据结构和类型。‘子组件不可更改该数据’emits:用于定义组件可以向父组件发出的事件。这允许父组件监听子组件的事件并作出响应。(比如数据更新) props检查属性 属性名类型描述默认值typ…

9.12日常记录

1.extern关键字 1)诞生动机:在一个C语言项目中,需要再多个文件中使用同一全局变量或是函数,那么就需要在这些文件中再声明一遍 2)用于声明在其他地方定义的一个变量或是函数,在当前位置只是声明,告诉编译器…

Windows10 如何设置电脑ip

1、首先打开控制面板 或者使用WinR 输入control 找到网络和Internet 点击网络和共享中心 点击更改适配器设置 找到你要需要设置的网络,右键 如果你的网口特别多,不确定是哪一个,拔插一下看看哪个以太网的标志是断开状态就可以了 点击属性…

【HarmonyOS】云开发-云数据库(二)

背景 书接上回,实现了云侧和端侧的云数据库创建、更新、修改等操作。这篇文章实现调用云函数对云数据库进行增删改查。 CloudProgram 项目配置 新建函数 在cloudfunctions目录下点击右键,选择新建Cloud Function,输入query-student-functi…

零宽字符应用场景及前端解决方案

零宽字符(Zero Width Characters)是一类在文本中不可见但具有特定功能的特殊字符。称为零宽字符,也叫幽灵字符。它们在显示时不占据任何空间,但在文本处理和显示中发挥着重要作用。这些字符主要包括零宽度空格、零宽度非连接符、零…

LLM - 理解 多模态大语言模型 (MLLM) 的预训练与相关技术 (三)

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://spike.blog.csdn.net/article/details/142063880 免责声明:本文来源于个人知识与公开资料,仅用于学术交流,欢迎讨论,不支持转载。 完备(F…

C# WPF上位机与西门子PLC通信实现实例解析

1. 使用第三方库(如S7.Net或Sharp7) 代码示例: // 使用S7.Net库与PLC建立连接 var plc new S7.Net.Plc(CpuType.S71500, "192.168.1.10", 0, 1); plc.Open();// 读取PLC中的DB块 byte[] buffer new byte[256]; plc.Read("DB…

OpenGL3.3_C++_Windows(37)

调试: 视觉错误与CPU调试不同,在GLSL代码中也不能设置断点,出现错误的时候寻找错误的源头可能会非常困难。 glGetError() GLenum glGetError();返回整形数字,查询错误标记,但是当一个错误标记…

币安/欧易合约对冲APP系统开发

币安合约对冲系统开发是一个复杂且专业的过程,它涉及到多个方面的技术和策略。以下是对币安合约对冲系统开发的一个概述: 一、系统概述 币安合约对冲系统是一种利用币安交易所提供的合约交易功能,通过同时建立多头和空头仓位来减少或消除市…

iotdbtool助力时序数据库IoTDB高效运维

iotdbtool 项目简介 iotdbtool 是一个使用 Go 语言编写的命令行工具,基于 Kubernetes 环境,提供了 IoTDB 数据的备份功能。它可以从 Kubernetes 集群中的 IoTDB Pod 中提取数据,并将其上传到阿里云 OSS 存储桶中。 iotdbtool 支持 iotDB 单…

SiC,GaN驱动优选驱动方案SiLM5350系列SiLM5350MDDCM-DG 带米勒钳位Clamp保护功能 单通道隔离栅极驱动器

SiLM5350MDDCM-DG是一款适用于IGBT、MOSFET的单通道 隔离门极驱动器,具有10A拉电流和10A灌电流驱动能 力。提供内部钳位功能,可单独控制 上升时间和下降时间。 在 SOP8 封 装 中 具 有 3000VRMS 隔 离 耐 压 ( 符 合 UL1577)。 与…

跨平台集成:在 AI、微服务和 Azure 云之间实现无缝工作流

跨平台集成在现代 IT 架构中的重要性 随着数字化转型的不断加速,对集成各种技术平台的需求也在快速增长。在当今的数字世界中,组织在复杂的环境中执行运营,其中多种技术需要无缝协作。环境的复杂性可能取决于业务的性质和组织提供的服务。具…

从JVM角度深入理解String

从JVM角度深入理解String String是Java中的一个类,是一个引用类型,用于表示字符串。它是不可变的(immutable),即一旦创建,其值就不能被修改。任何对String对象的修改操作都会创建一个新的String对象&#x…