爆改YOLOv8 |YOLOv8融合SEAM注意力机制

devtools/2024/12/22 15:21:41/

1,本文介绍

SEAM(Spatially Enhanced Attention Module)是一个注意力网络模块,旨在解决面部遮挡导致的响应损失问题。通过使用深度可分离卷积和残差连接的组合,SEAM模块增强未遮挡面部的响应。深度可分离卷积在每个通道上独立操作,减少了参数量,但忽略了通道间的关系。为了解决这一问题,SEAM将不同深度卷积的输出通过1x1卷积结合,再通过两层全连接网络融合每个通道的信息,以提升通道间的联系,从而弥补遮挡造成的损失。

本文将讲解如何将SEAM融合进yolov8

话不多说,上代码!

2, 将SEAM融合进yolov8

2.1 步骤一

找到如下的目录'ultralytics/nn/modules',然后在这个目录下创建一个seam.py文件,文件名字可以根据你自己的习惯起,然后将seam的核心代码复制进去

import torch
import torch.nn as nn__all__ = ['SEAM', 'MultiSEAM']class Residual(nn.Module):def __init__(self, fn):super(Residual, self).__init__()self.fn = fndef forward(self, x):return self.fn(x) + xclass SEAM(nn.Module):def __init__(self, c1,  n=1, reduction=16):super(SEAM, self).__init__()c2 = c1self.DCovN = nn.Sequential(# nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1, groups=c1),# nn.GELU(),# nn.BatchNorm2d(c2),*[nn.Sequential(Residual(nn.Sequential(nn.Conv2d(in_channels=c2, out_channels=c2, kernel_size=3, stride=1, padding=1, groups=c2),nn.GELU(),nn.BatchNorm2d(c2))),nn.Conv2d(in_channels=c2, out_channels=c2, kernel_size=1, stride=1, padding=0, groups=1),nn.GELU(),nn.BatchNorm2d(c2)) for i in range(n)])self.avg_pool = torch.nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(c2, c2 // reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(c2 // reduction, c2, bias=False),nn.Sigmoid())self._initialize_weights()# self.initialize_layer(self.avg_pool)self.initialize_layer(self.fc)def forward(self, x):b, c, _, _ = x.size()y = self.DCovN(x)y = self.avg_pool(y).view(b, c)y = self.fc(y).view(b, c, 1, 1)y = torch.exp(y)return x * y.expand_as(x)def _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.xavier_uniform_(m.weight, gain=1)elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)def initialize_layer(self, layer):if isinstance(layer, (nn.Conv2d, nn.Linear)):torch.nn.init.normal_(layer.weight, mean=0., std=0.001)if layer.bias is not None:torch.nn.init.constant_(layer.bias, 0)def DcovN(c1, c2, depth, kernel_size=3, patch_size=3):dcovn = nn.Sequential(nn.Conv2d(c1, c2, kernel_size=patch_size, stride=patch_size),nn.SiLU(),nn.BatchNorm2d(c2),*[nn.Sequential(Residual(nn.Sequential(nn.Conv2d(in_channels=c2, out_channels=c2, kernel_size=kernel_size, stride=1, padding=1, groups=c2),nn.SiLU(),nn.BatchNorm2d(c2))),nn.Conv2d(in_channels=c2, out_channels=c2, kernel_size=1, stride=1, padding=0, groups=1),nn.SiLU(),nn.BatchNorm2d(c2)) for i in range(depth)])return dcovnclass MultiSEAM(nn.Module):def __init__(self, c1,  depth=1, kernel_size=3, patch_size=[3, 5, 7], reduction=16):super(MultiSEAM, self).__init__()c2 = c1self.DCovN0 = DcovN(c1, c2, depth, kernel_size=kernel_size, patch_size=patch_size[0])self.DCovN1 = DcovN(c1, c2, depth, kernel_size=kernel_size, patch_size=patch_size[1])self.DCovN2 = DcovN(c1, c2, depth, kernel_size=kernel_size, patch_size=patch_size[2])self.avg_pool = torch.nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(c2, c2 // reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(c2 // reduction, c2, bias=False),nn.Sigmoid())def forward(self, x):b, c, _, _ = x.size()y0 = self.DCovN0(x)y1 = self.DCovN1(x)y2 = self.DCovN2(x)y0 = self.avg_pool(y0).view(b, c)y1 = self.avg_pool(y1).view(b, c)y2 = self.avg_pool(y2).view(b, c)y4 = self.avg_pool(x).view(b, c)y = (y0 + y1 + y2 + y4) / 4y = self.fc(y).view(b, c, 1, 1)y = torch.exp(y)return x * y.expand_as(x)

2.2 步骤二

首先找到如下的目录'ultralytics/nn/modules',然后在这个目录下找到init文件,在init中添加如下代码.

from .seam import (MultiSEAM,SEAM
)

同时在init.py中的如下位置添加SEAM,MultiSEAM

2.3 步骤三

在task.py中导入SEAM

2.4 步骤四

在task.py中添加如下代码.

到此注册成功,复制后面的yaml文件直接运行即可

有两种yaml文件,可以自行选择

yaml文件一(seam)

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOP# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2- [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8- [-1, 6, C2f, [256, True]]- [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16- [-1, 6, C2f, [512, True]]- [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32- [-1, 3, C2f, [1024, True]]- [-1, 1, SPPF, [1024, 5]]  # 9# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 6], 1, Concat, [1]]  # cat backbone P4- [-1, 3, C2f, [512]]  # 12- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 4], 1, Concat, [1]]  # cat backbone P3- [-1, 3, C2f, [256]]  # 15 (P3/8-small)- [-1, 1, SEAM, []]  # 16- [-1, 1, Conv, [256, 3, 2]]- [[-1, 12], 1, Concat, [1]]  # cat head P4- [-1, 3, C2f, [512]]  # 19 (P4/16-medium)- [-1, 1, SEAM, []]  # 20- [-1, 1, Conv, [512, 3, 2]]- [[-1, 9], 1, Concat, [1]]  # cat head P5- [-1, 3, C2f, [1024]]  # 23 (P5/32-large)- [-1, 1, SEAM, []]  # 24- [[16, 20, 24], 1, Detect, [nc]]  # Detect(P3, P4, P5)

yaml文件二(multiseam)


# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOP# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2- [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8- [-1, 6, C2f, [256, True]]- [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16- [-1, 6, C2f, [512, True]]- [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32- [-1, 3, C2f, [1024, True]]- [-1, 1, SPPF, [1024, 5]]  # 9# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 6], 1, Concat, [1]]  # cat backbone P4- [-1, 3, C2f, [512]]  # 12- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 4], 1, Concat, [1]]  # cat backbone P3- [-1, 3, C2f, [256]]  # 15 (P3/8-small)- [-1, 1, MultiSEAM, []]  # 16- [-1, 1, Conv, [256, 3, 2]]- [[-1, 12], 1, Concat, [1]]  # cat head P4- [-1, 3, C2f, [512]]  # 19 (P4/16-medium)- [-1, 1, MultiSEAM, []]  # 20- [-1, 1, Conv, [512, 3, 2]]- [[-1, 9], 1, Concat, [1]]  # cat head P5- [-1, 3, C2f, [1024]]  # 23 (P5/32-large)- [-1, 1, MultiSEAM, []]  # 24- [[16, 20, 24], 1, Detect, [nc]]  # Detect(P3, P4, P5)

# 关于SEAM添加的位置还可以放在颈部,针对不同数据集位置不同,效果不同

不知不觉已经看完了哦,动动小手留个点赞吧--_--


http://www.ppmy.cn/devtools/100416.html

相关文章

【C++】初识C++模板与STL

C语法相关知识点可以通过点击以下链接进行学习一起加油!命名空间缺省参数与函数重载C相关特性类和对象-上篇类和对象-中篇类和对象-下篇日期类C/C内存管理 本章将简单分享C模板与STL相关知识,与之相关更多知识将留到下次更详细地来分享给大家 &#x1f3…

CSS的简单介绍

1.什么是CSS CSS(层叠样式表),用于控制页面的样式,简单地来说,CSS就是用来美化页面的一种语言。 2.基本语法规范 CSS的基本语法规范:CSS选择器{1或多条声明} 其中CSS选择器决定找谁(针对哪个元素进行修改),声明决定…

Debian Linux上安装Jumpserver

1.安装 Debian并配置 登录www.debian.io,下载网络版安装,安装很快,但完成后修改IP就遇到问题vi /etc/network/interfaces auto eth0 #设置开机自动连接网络 iface lo inet loopback allow-hotplug eth0 iface eth0 inet static #static表示使用固定I…

HTML 全解析:从基础到实战

一、简介 HTML(HyperText Markup Language)即超文本标记语言,是用于创建网页的标准标记语言。它通过各种标签来定义网页的结构和内容,使得浏览器能够正确地显示网页。HTML 文档由 HTML 元素组成,这些元素通过标签来表…

选择一家正规的应急指挥中心控制台厂家有多关键

在当今社会,随着自然灾害、突发事件及安全挑战的日益复杂多变,应急管理体系的构建显得尤为重要。而应急指挥中心作为应对各类紧急情况的神经中枢,其高效运作离不开先进、可靠的控制台设备支持。在此背景下,选择一家正规的应急指挥…

湖南贝特新能源科技:巧用草料二维码,实现设备管理数字化

在当今快速发展的制造业环境中,设备管理效率直接影响着企业的生产力与竞争力,我司也面临着设备管理流程中的诸多挑战。 如今购物买菜都是扫码支付的时代,我司也与时俱进, 导入了二维码系统进行公司质量、设备、安全、项目、改善等…

python如何快速生成一个密钥

在Python中,快速生成一个密钥通常依赖于内置的库或第三方库来生成一个安全的随机字符串。以下是一些常见的方法来生成密钥: 使用secrets模块 Python 3.6及以上版本引入了secrets模块,它用于生成适合管理密码、账户认证信息、安全令牌等敏感…

【JavaEE】SpringBoot 统一功能处理:拦截器、统一数据返回与异常处理的综合应用与源码解析

目录 SpringBoot 统⼀功能处理拦截器拦截器快速⼊⻔拦截器详解拦截路径拦截器执⾏流程 登录校验定义拦截器注册配置拦截器 DispatcherServlet 源码分析(了解)初始化(了解) DispatcherServlet的初始化1. HttpServletBean.init()2. FrameworkServlet.initServletBean() WebApplic…