YOLO11改进|注意力机制篇|引入局部注意力HaloAttention

server/2024/10/11 3:25:11/

在这里插入图片描述

目录

    • 一、【HaloAttention】注意力机制
      • 1.1【HaloAttention】注意力介绍
      • 1.2【HaloAttention】核心代码
    • 二、添加【HaloAttention】注意力机制
      • 2.1STEP1
      • 2.2STEP2
      • 2.3STEP3
      • 2.4STEP4
    • 三、yaml文件与运行
      • 3.1yaml文件
      • 3.2运行成功截图

一、【HaloAttention】注意力机制

1.1【HaloAttention】注意力介绍

在这里插入图片描述

下图是【HaloAttention】的结构图,让我们简单分析一下运行过程和优势

处理过程

  • 图像分块:

  • 输入图像大小为 4×4×𝑐,其中 𝑐
    是通道数。该图像首先被分割为多个小块(如图所示被分为 4 个 2×2×𝑐的小块),每个块称为一个“block”。

  • Haloing 操作:

  • 在图像分块后,使用 haloing 操作扩展每个小块的边界。图中显示的是一个 halo 值为 1 的情况,即每个小块在其原有区域上扩展了 1 个像素的边界,形成了带有额外边界信息的邻域窗口。这一操作目的是为了在计算注意力时捕获块与块之间的上下文信息。

  • 邻域窗口计算:

  • Haloing 之后,每个小块拥有邻近区域的信息,即在扩展后的邻域窗口中包含了来自周围小块的部分信息。图中显示了每个小块及其周围邻域的窗口(如红色小块与其邻域的相关部分)。

  • 查询与注意力机制:

  • 在邻域窗口中应用 注意力机制。以每个小块作为查询(Query),与其扩展后的邻域窗口进行注意力计算,从中提取重要的上下文特征。注意力机制的引入使得每个小块不仅能够学习到自身的特征,还能从周围的块中获取相关的上下文信息,从而增强特征表达。

  • 输出:

  • 通过注意力机制的加权输出每个小块的结果,形成新的特征图。输出的特征图大小仍然是分块前的大小,但每个块内的特征已经经过上下文增强和融合。
    优势

  • 降低计算复杂度:

  • 通过将图像分割成小块并只在局部区域内应用注意力机制,减少了全局自注意力带来的高计算开销。这种方法可以大幅度降低计算复杂度,特别适合处理高分辨率图像或大规模数据集。

  • 局部上下文捕获:

  • Haloing 操作的引入使得每个块在计算注意力时能够感知到其邻域的上下文信息,克服了仅依赖自身区域的局限性。因此,它能够更好地捕捉局部细节和相关性,特别是在需要高精度定位的任务中(如图像分割或检测任务)。

  • 有效的特征增强:

  • 通过分块后的注意力机制,模型可以集中计算各个小块的注意力权重,并在局部范围内提升特征表达能力。这样可以避免全局注意力在大图像上计算时引入的冗余信息,同时仍能保证特征的有效整合。

  • 灵活性强:

  • 该方法可广泛应用于图像分类、目标检测、语义分割等任务中,并且可以根据实际需求调整分块大小和 halo 值,灵活适应不同的计算资源和任务要求。在这里插入图片描述

1.2【HaloAttention】核心代码

import torch
from torch import nn, einsum
import torch.nn.functional as Ffrom einops import rearrange, repeatdef to(x):return {"device": x.device, "dtype": x.dtype}def pair(x):return (x, x) if not isinstance(x, tuple) else xdef expand_dim(t, dim, k):t = t.unsqueeze(dim=dim)expand_shape = [-1] * len(t.shape)expand_shape[dim] = kreturn t.expand(*expand_shape)def rel_to_abs(x):b, l, m = x.shaper = (m + 1) // 2col_pad = torch.zeros((b, l, 1), **to(x))x = torch.cat((x, col_pad), dim=2)flat_x = rearrange(x, "b l c -> b (l c)")flat_pad = torch.zeros((b, m - l), **to(x))flat_x_padded = torch.cat((flat_x, flat_pad), dim=1)final_x = flat_x_padded.reshape(b, l + 1, m)final_x = final_x[:, :l, -r:]return final_xdef relative_logits_1d(q, rel_k):b, h, w, _ = q.shaper = (rel_k.shape[0] + 1) // 2logits = einsum("b x y d, r d -> b x y r", q, rel_k)logits = rearrange(logits, "b x y r -> (b x) y r")logits = rel_to_abs(logits)logits = logits.reshape(b, h, w, r)logits = expand_dim(logits, dim=2, k=r)return logitsclass RelPosEmb(nn.Module):def __init__(self, block_size, rel_size, dim_head):super().__init__()height = width = rel_sizescale = dim_head**-0.5self.block_size = block_sizeself.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale)self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale)def forward(self, q):block = self.block_sizeq = rearrange(q, "b (x y) c -> b x y c", x=block)rel_logits_w = relative_logits_1d(q, self.rel_width)rel_logits_w = rearrange(rel_logits_w, "b x i y j-> b (x y) (i j)")q = rearrange(q, "b x y d -> b y x d")rel_logits_h = relative_logits_1d(q, self.rel_height)rel_logits_h = rearrange(rel_logits_h, "b x i y j -> b (y x) (j i)")return rel_logits_w + rel_logits_hclass HaloAttention(nn.Module):def __init__(self, dim, block_size, halo_size, dim_head=64, heads=8):super().__init__()assert halo_size > 0, "halo size must be greater than 0"self.dim = dimself.heads = headsself.scale = dim_head**-0.5self.block_size = block_sizeself.halo_size = halo_sizeinner_dim = dim_head * headsself.rel_pos_emb = RelPosEmb(block_size=block_size,rel_size=block_size + (halo_size * 2),dim_head=dim_head,)self.to_q = nn.Linear(dim, inner_dim, bias=False)self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)self.to_out = nn.Linear(inner_dim, dim)def forward(self, x):b, c, h, w, block, halo, heads, device = (*x.shape,self.block_size,self.halo_size,self.heads,x.device,)assert (h % block == 0 and w % block == 0), "fmap dimensions must be divisible by the block size"assert (c == self.dim), f"channels for input ({c}) does not equal to the correct dimension ({self.dim})"# get block neighborhoods, and prepare a halo-ed version (blocks with padding) for deriving key valuesq_inp = rearrange(x, "b c (h p1) (w p2) -> (b h w) (p1 p2) c", p1=block, p2=block)kv_inp = F.unfold(x, kernel_size=block + halo * 2, stride=block, padding=halo)kv_inp = rearrange(kv_inp, "b (c j) i -> (b i) j c", c=c)# derive queries, keys, valuesq = self.to_q(q_inp)k, v = self.to_kv(kv_inp).chunk(2, dim=-1)# split headsq, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=heads), (q, k, v))# scaleq *= self.scale# attentionsim = einsum("b i d, b j d -> b i j", q, k)# add relative positional biassim += self.rel_pos_emb(q)# mask out padding (in the paper, they claim to not need masks, but what about padding?)mask = torch.ones(1, 1, h, w, device=device)mask = F.unfold(mask, kernel_size=block + (halo * 2), stride=block, padding=halo)mask = repeat(mask, "() j i -> (b i h) () j", b=b, h=heads)mask = mask.bool()max_neg_value = -torch.finfo(sim.dtype).maxsim.masked_fill_(mask, max_neg_value)# attentionattn = sim.softmax(dim=-1)# aggregateout = einsum("b i j, b j d -> b i d", attn, v)# merge and combine headsout = rearrange(out, "(b h) n d -> b n (h d)", h=heads)out = self.to_out(out)# merge blocks back to original feature mapout = rearrange(out,"(b h w) (p1 p2) c -> b c (h p1) (w p2)",b=b,h=(h // block),w=(w // block),p1=block,p2=block,)return outif __name__ == "__main__":input = torch.rand(3, 32, 64, 64).cuda()model = HaloAttention(dim=32,block_size=2,halo_size=1,).cuda()output = model(input)print(input.size(), output.size())

二、添加【HaloAttention】注意力机制

2.1STEP1

首先找到ultralytics/nn文件路径下新建一个Add-module的python文件包【这里注意一定是python文件包,新建后会自动生成_init_.py】,如果已经跟着我的教程建立过一次了可以省略此步骤,随后新建一个HaloAttention.py文件并将上文中提到的注意力机制的代码全部粘贴到此文件中,如下图所示在这里插入图片描述

2.2STEP2

在STEP1中新建的_init_.py文件中导入增加改进模块的代码包如下图所示在这里插入图片描述

2.3STEP3

找到ultralytics/nn文件夹中的task.py文件,在其中按照下图添加在这里插入图片描述

2.4STEP4

定位到ultralytics/nn文件夹中的task.py文件中的def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)函数添加如图代码,【如果不好定位可以直接ctrl+f搜索定位】

在这里插入图片描述

三、yaml文件与运行

3.1yaml文件

以下是添加【HaloAttention】注意力机制在Backbone中的yaml文件,大家可以注释自行调节,效果以自己的数据集结果为准

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'# [depth, width, max_channels]n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPss: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPsm: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPsl: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPsx: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs# YOLO11n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2- [-1, 1, Conv, [128,3,2]] # 1-P2/4- [-1, 2, C3k2, [256, False, 0.25]]- [-1, 1, Conv, [256,3,2]] # 3-P3/8- [-1, 2, C3k2, [512, False, 0.25]]- [-1, 1, Conv, [512,3,2]] # 5-P4/16- [-1, 2, C3k2, [512, True]]- [-1, 1, Conv, [1024,3,2]] # 7-P5/32- [-1, 2, C3k2, [1024, True]]- [-1, 1, HaloAttention, [2, 1]]- [-1, 1, SPPF, [1024, 5]] # 9- [-1, 2, C2PSA, [1024]] # 10# YOLO11n head
head:- [-1, 1, nn.Upsample, [None, 2, "nearest"]]- [[-1, 6], 1, Concat, [1]] # cat backbone P4- [-1, 2, C3k2, [512, False]] # 13- [-1, 1, nn.Upsample, [None, 2, "nearest"]]- [[-1, 4], 1, Concat, [1]] # cat backbone P3- [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 14], 1, Concat, [1]] # cat head P4- [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 11], 1, Concat, [1]] # cat head P5- [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)- [[17, 20, 23], 1, Detect, [nc]] # Detect(P3, P4, P5)

以上添加位置仅供参考,具体添加位置以及模块效果以自己的数据集结果为准

3.2运行成功截图

在这里插入图片描述

OK 以上就是添加【HaloAttention】注意力机制的全部过程了,后续将持续更新尽情期待

在这里插入图片描述


http://www.ppmy.cn/server/129920.html

相关文章

专利:创新的盾牌与创新者的利器

在当今快速发展的科技时代,创新已成为推动社会进步和经济增长的关键因素。专利,作为保护创新成果的重要法律工具,对于激励创新、保护发明人权益以及促进技术进步具有至关重要的作用。 专利的定义与类型 专利是政府授予发明人对其发明创造在…

web网页项目--用户登录,注册页面代码

index.html <!DOCTYPE html> <html lang"zxx"><head><title>xxx注册</title><!-- Meta tags --><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0&q…

论文翻译 | Fairness-guided Few-shot Prompting for LargeLanguage Models

摘要 大型语言模型已经显示出令人惊讶的执行上下文学习的能力&#xff0c;也就是说&#xff0c;这些模型可以通过对由几个输入输出示例构建的提示进行条件反射&#xff0c;直接应用于解决大量下游任务。然而&#xff0c;先前的研究表明&#xff0c;由于训练示例、示例顺序和提示…

需求9——通过一个小需求来体会service层的作用

昨天在完成了睿哥的需求验收之后&#xff0c;暂时没有其他任务&#xff0c;因此今天可能会比较有空闲时间。趁着这个机会&#xff0c;我打算把之前完成的一些需求进行总结&#xff0c;方便以后复习和参考。 在8月份的时候&#xff0c;我负责了一个需求&#xff0c;该需求的具体…

国内的无人机行业的现状和前景分析

近年来&#xff0c;随着科技的飞速发展&#xff0c;无人机&#xff08;Unmanned Aerial Vehicle, UAV&#xff09;作为战略性新兴产业的重要组成部分&#xff0c;在全球范围内迅速崛起。无人机利用无线电遥控设备和自备的程序控制装置操纵&#xff0c;实现不载人飞行&#xff0…

SQL进阶技巧:统计各时段观看直播的人数

目录 0 需求描述 1 数据准备 2 问题分析 3 小结 如果觉得本文对你有帮助&#xff0c;那么不妨也可以选择去看看我的博客专栏 &#xff0c;部分内容如下&#xff1a; 数字化建设通关指南 专栏 原价99&#xff0c;现在活动价39.9&#xff0c;十一国庆后将上升至59.9&#…

国外火出圈儿的PM御用AI编程工具Bolt.new效果干不过国产的CodeFlying?号称全新定义全栈开发流程?

不知道大家最近有没有发现国外的很多AI都在挤破脑袋想去提升大模型的编程能力&#xff0c; 离我们最近的是上周Openai 发布的全新模型GPT-4o-Canvas&#xff0c; 拥有超强的代码编写能力。 另外还有LlamaCoder、Cursor、Claude artifacts、Replit... 光是今年一年就推出了好…

Hive优化操作(二)

Hive 数据倾斜优化 在使用 Hive 进行大数据处理时&#xff0c;数据倾斜是一个常见的问题。本文将详细介绍数据倾斜的概念、表现、常见场景及其解决方案。 1. 什么是数据倾斜&#xff1f; 数据倾斜是指由于数据分布不均匀&#xff0c;导致大量数据集中到某个节点或任务中&…