GiraffeDet Gives YOLOv8 a Big Score Boost (with Runnable Source Code)

2024/12/28 1:00:10

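Modifying YOLOv8: adding the GiraffeDet model to improve small-object detection
Search for 晓理紫 on WeChat (VX), follow the account, and reply with "yolov8-GiraffeDet" to get the code.
[晓理紫]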

1 The GiraffeDet model

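GiraffeDet is a novel heavy-neck paradigm: a giraffe-like network for efficient object detection. It pairs an extremely lightweight backbone with a very deep and large neck module, which encourages dense information exchange across different spatial scales and different levels of latent semantics at the same time. This design lets the detector treat high-level semantic information and low-level spatial information with equal priority even in the early stages of the network, which makes it more effective for detection. Numerical evaluations on several popular object detection benchmarks show that GiraffeDet consistently outperforms previous SOTA models under a range of resource constraints. Network source code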
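To make "dense information exchange across spatial scales" concrete, here is a minimal NumPy sketch of one cross-scale fusion step in which every pyramid level is concatenated with its resized neighbouring levels. It is only an illustration under assumed choices (nearest-neighbour upsampling, 2x2 max-pool downsampling, plain channel concatenation); the helper names `upsample2x`, `downsample2x` and `fuse_neighbours` are hypothetical, and this is not GiraffeDet's GFPN implementation, which uses learned convolutions and more connections.

```python
import numpy as np


def upsample2x(x):
    """Nearest-neighbour 2x upsampling of a (C, H, W) feature map."""
    return x.repeat(2, axis=1).repeat(2, axis=2)


def downsample2x(x):
    """2x2 max pooling with stride 2 of a (C, H, W) feature map (H, W even)."""
    c, h, w = x.shape
    return x.reshape(c, h // 2, 2, w // 2, 2).max(axis=(2, 4))


def fuse_neighbours(feats):
    """Concatenate every pyramid level with its resized neighbouring levels.

    feats: (C_i, H_i, W_i) arrays ordered fine to coarse, each level half the
    spatial size of the previous one (e.g. strides 8, 16, 32).
    """
    fused = []
    for i, x in enumerate(feats):
        parts = [x]
        if i > 0:                    # finer neighbour, shrunk to this size
            parts.append(downsample2x(feats[i - 1]))
        if i < len(feats) - 1:       # coarser neighbour, enlarged to this size
            parts.append(upsample2x(feats[i + 1]))
        fused.append(np.concatenate(parts, axis=0))  # channel-wise concat
    return fused


p3 = np.random.rand(8, 16, 16)
p4 = np.random.rand(16, 8, 8)
p5 = np.random.rand(32, 4, 4)
print([f.shape for f in fuse_neighbours([p3, p4, p5])])
# [(24, 16, 16), (56, 8, 8), (48, 4, 4)]
```

The middle level receives information from both a finer and a coarser scale in a single step; a heavy neck repeats steps like this many times.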

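2 Introducing GiraffeDet into YOLOv8

To improve YOLOv8's detection of small objects, the GiraffeDet network can be introduced into YOLOv8; in the author's experience this gives good results on most datasets. The steps are as follows.

2.1 Add the GiraffeDet model

Create module_GiraffeDet.py in ultralytics/nn/modules/ and put the following code in it: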

```python
import numpy as np  # used by RepConv._fuse_bn_tensor
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ('RepConv', 'Swish', 'ConvBNAct', 'BasicBlock_3x3_Reverse', 'SPP', 'CSPStage')


def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
    '''Basic cell for rep-style block, including conv and bn'''
    result = nn.Sequential()
    result.add_module(
        'conv',
        nn.Conv2d(in_channels=in_channels,
                  out_channels=out_channels,
                  kernel_size=kernel_size,
                  stride=stride,
                  padding=padding,
                  groups=groups,
                  bias=False))
    result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
    return result


class RepConv(nn.Module):
    '''RepConv is a basic rep-style block, including training and deploy status
    Code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
    '''
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 dilation=1,
                 groups=1,
                 padding_mode='zeros',
                 deploy=False,
                 act='relu',
                 norm=None):
        super(RepConv, self).__init__()
        self.deploy = deploy
        self.groups = groups
        self.in_channels = in_channels
        self.out_channels = out_channels

        assert kernel_size == 3
        assert padding == 1

        padding_11 = padding - kernel_size // 2

        if isinstance(act, str):
            self.nonlinearity = get_activation(act)
        else:
            self.nonlinearity = act

        if deploy:
            # single fused 3x3 conv used at inference time
            self.rbr_reparam = nn.Conv2d(in_channels=in_channels,
                                         out_channels=out_channels,
                                         kernel_size=kernel_size,
                                         stride=stride,
                                         padding=padding,
                                         dilation=dilation,
                                         groups=groups,
                                         bias=True,
                                         padding_mode=padding_mode)
        else:
            # training-time branches: 3x3 conv+BN and 1x1 conv+BN (no identity branch)
            self.rbr_identity = None
            self.rbr_dense = conv_bn(in_channels=in_channels,
                                     out_channels=out_channels,
                                     kernel_size=kernel_size,
                                     stride=stride,
                                     padding=padding,
                                     groups=groups)
            self.rbr_1x1 = conv_bn(in_channels=in_channels,
                                   out_channels=out_channels,
                                   kernel_size=1,
                                   stride=stride,
                                   padding=padding_11,
                                   groups=groups)

    def forward(self, inputs):
        '''Forward process'''
        if hasattr(self, 'rbr_reparam'):
            return self.nonlinearity(self.rbr_reparam(inputs))

        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)

        return self.nonlinearity(
            self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)

    def get_equivalent_kernel_bias(self):
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(
            kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        if kernel1x1 is None:
            return 0
        else:
            # zero-pad the 1x1 kernel so it sits at the centre of a 3x3 kernel
            return F.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch):
        if branch is None:
            return 0, 0
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, 'id_tensor'):
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros((self.in_channels, input_dim, 3, 3),
                                        dtype=np.float32)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(
                    branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        # fold BN into the conv: W' = W * gamma / std, b' = beta - mean * gamma / std
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def switch_to_deploy(self):
        if hasattr(self, 'rbr_reparam'):
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        self.rbr_reparam = nn.Conv2d(
            in_channels=self.rbr_dense.conv.in_channels,
            out_channels=self.rbr_dense.conv.out_channels,
            kernel_size=self.rbr_dense.conv.kernel_size,
            stride=self.rbr_dense.conv.stride,
            padding=self.rbr_dense.conv.padding,
            dilation=self.rbr_dense.conv.dilation,
            groups=self.rbr_dense.conv.groups,
            bias=True)
        self.rbr_reparam.weight.data = kernel
        self.rbr_reparam.bias.data = bias
        for para in self.parameters():
            para.detach_()
        self.__delattr__('rbr_dense')
        self.__delattr__('rbr_1x1')
        if hasattr(self, 'rbr_identity'):
            self.__delattr__('rbr_identity')
        if hasattr(self, 'id_tensor'):
            self.__delattr__('id_tensor')
        self.deploy = True


class Swish(nn.Module):
    def __init__(self, inplace=True):
        super(Swish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        if self.inplace:
            x.mul_(torch.sigmoid(x))
            return x
        else:
            return x * torch.sigmoid(x)


def get_activation(name='silu', inplace=True):
    if name is None:
        return nn.Identity()

    if isinstance(name, str):
        if name == 'silu':
            module = nn.SiLU(inplace=inplace)
        elif name == 'relu':
            module = nn.ReLU(inplace=inplace)
        elif name == 'lrelu':
            module = nn.LeakyReLU(0.1, inplace=inplace)
        elif name == 'swish':
            module = Swish(inplace=inplace)
        elif name == 'hardsigmoid':
            module = nn.Hardsigmoid(inplace=inplace)
        elif name == 'identity':
            module = nn.Identity()
        else:
            raise AttributeError('Unsupported act type: {}'.format(name))
        return module

    elif isinstance(name, nn.Module):
        return name

    else:
        raise AttributeError('Unsupported act type: {}'.format(name))


def get_norm(name, out_channels, inplace=True):
    if name == 'bn':
        module = nn.BatchNorm2d(out_channels)
    else:
        raise NotImplementedError
    return module


class ConvBNAct(nn.Module):
    """A Conv2d -> Batchnorm -> silu/leaky relu block"""
    def __init__(
        self,
        in_channels,
        out_channels,
        ksize,
        stride=1,
        groups=1,
        bias=False,
        act='silu',
        norm='bn',
        reparam=False,
    ):
        super().__init__()
        # same padding
        pad = (ksize - 1) // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=ksize,
            stride=stride,
            padding=pad,
            groups=groups,
            bias=bias,
        )
        if norm is not None:
            self.bn = get_norm(norm, out_channels, inplace=True)
        if act is not None:
            self.act = get_activation(act, inplace=True)
        self.with_norm = norm is not None
        self.with_act = act is not None

    def forward(self, x):
        x = self.conv(x)
        if self.with_norm:
            x = self.bn(x)
        if self.with_act:
            x = self.act(x)
        return x

    def fuseforward(self, x):
        return self.act(self.conv(x))


class BasicBlock_3x3_Reverse(nn.Module):
    def __init__(self,
                 ch_in,
                 ch_hidden_ratio,
                 ch_out,
                 act='relu',
                 shortcut=True):
        super(BasicBlock_3x3_Reverse, self).__init__()
        assert ch_in == ch_out
        ch_hidden = int(ch_in * ch_hidden_ratio)
        self.conv1 = ConvBNAct(ch_hidden, ch_out, 3, stride=1, act=act)
        self.conv2 = RepConv(ch_in, ch_hidden, 3, stride=1, act=act)
        self.shortcut = shortcut

    def forward(self, x):
        y = self.conv2(x)  # RepConv first ("reverse" order)
        y = self.conv1(y)
        if self.shortcut:
            return x + y
        else:
            return y


class SPP(nn.Module):
    def __init__(
        self,
        ch_in,
        ch_out,
        k,
        pool_size,
        act='swish',
    ):
        super(SPP, self).__init__()
        self.pool = []
        for i, size in enumerate(pool_size):
            # stride 1 with padding size // 2 keeps the spatial size unchanged
            pool = nn.MaxPool2d(kernel_size=size,
                                stride=1,
                                padding=size // 2,
                                ceil_mode=False)
            self.add_module('pool{}'.format(i), pool)
            self.pool.append(pool)
        self.conv = ConvBNAct(ch_in, ch_out, k, act=act)

    def forward(self, x):
        outs = [x]
        for pool in self.pool:
            outs.append(pool(x))
        y = torch.cat(outs, dim=1)
        y = self.conv(y)
        return y


class CSPStage(nn.Module):
    def __init__(self,
                 ch_in,
                 ch_out,
                 n=1,
                 block_fn='BasicBlock_3x3_Reverse',
                 ch_hidden_ratio=1.0,
                 act='silu',
                 spp=False):
        super(CSPStage, self).__init__()

        split_ratio = 2
        ch_first = int(ch_out // split_ratio)
        ch_mid = int(ch_out - ch_first)
        self.conv1 = ConvBNAct(ch_in, ch_first, 1, act=act)
        self.conv2 = ConvBNAct(ch_in, ch_mid, 1, act=act)
        self.convs = nn.Sequential()

        next_ch_in = ch_mid
        for i in range(n):
            if block_fn == 'BasicBlock_3x3_Reverse':
                self.convs.add_module(
                    str(i),
                    BasicBlock_3x3_Reverse(next_ch_in,
                                           ch_hidden_ratio,
                                           ch_mid,
                                           act=act,
                                           shortcut=True))
            else:
                raise NotImplementedError
            if i == (n - 1) // 2 and spp:
                self.convs.add_module(
                    'spp', SPP(ch_mid * 4, ch_mid, 1, [5, 9, 13], act=act))
            next_ch_in = ch_mid
        self.conv3 = ConvBNAct(ch_mid * n + ch_first, ch_out, 1, act=act)

    def forward(self, x):
        y1 = self.conv1(x)
        y2 = self.conv2(x)

        mid_out = [y1]
        for conv in self.convs:
            y2 = conv(y2)
            # with spp=True the SPP output is appended too, so conv3's input
            # channel count (ch_mid * n + ch_first) assumes spp=False
            mid_out.append(y2)
        y = torch.cat(mid_out, dim=1)
        y = self.conv3(y)
        return y
```
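RepConv trains with parallel 3x3 and 1x1 conv+BN branches and, after `switch_to_deploy()`, collapses them into a single 3x3 conv. The NumPy sketch below checks that arithmetic on random data using the same formulas as `_fuse_bn_tensor` and `_pad_1x1_to_3x3_tensor`. It assumes stride 1, groups 1 and BN in eval mode, and leaves out the activation (it is applied after the sum in both forms, so it does not affect the equivalence); `conv2d`, `batchnorm_eval` and `fuse_bn` are hypothetical helpers written only for this check.

```python
import numpy as np


def conv2d(x, w, b=None, pad=0):
    """Naive stride-1 convolution (cross-correlation, as PyTorch computes it).

    x: (C_in, H, W), w: (C_out, C_in, k, k), b: (C_out,) or None.
    """
    c_out, _, k, _ = w.shape
    xp = np.pad(x, ((0, 0), (pad, pad), (pad, pad)))  # zero padding
    h_out = xp.shape[1] - k + 1
    w_out = xp.shape[2] - k + 1
    out = np.zeros((c_out, h_out, w_out))
    for i in range(h_out):
        for j in range(w_out):
            patch = xp[:, i:i + k, j:j + k]  # (C_in, k, k)
            out[:, i, j] = np.tensordot(w, patch, axes=([1, 2, 3], [0, 1, 2]))
    if b is not None:
        out += b[:, None, None]
    return out


def batchnorm_eval(y, mean, var, gamma, beta, eps=1e-5):
    """BatchNorm2d in eval mode (running statistics), applied per channel."""
    def col(v):
        return v[:, None, None]
    return (y - col(mean)) / np.sqrt(col(var) + eps) * col(gamma) + col(beta)


def fuse_bn(w, mean, var, gamma, beta, eps=1e-5):
    """Same arithmetic as RepConv._fuse_bn_tensor."""
    std = np.sqrt(var + eps)
    t = (gamma / std).reshape(-1, 1, 1, 1)
    return w * t, beta - mean * gamma / std


rng = np.random.default_rng(0)
c_in, c_out = 3, 4
x = rng.standard_normal((c_in, 6, 6))
w3 = rng.standard_normal((c_out, c_in, 3, 3))
w1 = rng.standard_normal((c_out, c_in, 1, 1))


def random_bn():
    # (running_mean, running_var, gamma, beta); the variance must be positive
    return (rng.standard_normal(c_out), rng.random(c_out) + 0.5,
            rng.standard_normal(c_out), rng.standard_normal(c_out))


bn3, bn1 = random_bn(), random_bn()

# Training-time RepConv (no identity branch): BN(conv3x3(x)) + BN(conv1x1(x))
y_multi = (batchnorm_eval(conv2d(x, w3, pad=1), *bn3)
           + batchnorm_eval(conv2d(x, w1, pad=0), *bn1))

# Deploy-time RepConv: fold each BN, zero-pad the 1x1 kernel to 3x3, add
k3, b3 = fuse_bn(w3, *bn3)
k1, b1 = fuse_bn(w1, *bn1)
k = k3 + np.pad(k1, ((0, 0), (0, 0), (1, 1), (1, 1)))  # like _pad_1x1_to_3x3_tensor
y_single = conv2d(x, k, b3 + b1, pad=1)

print(np.allclose(y_multi, y_single))  # True
```

Because the fused model is a single conv, deploy-time inference avoids the extra branch and the separate BN layers at no cost in accuracy.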

2.2 Modify ultralytics/nn/modules/__init__.py

Import CSPStage from the new file and append 'CSPStage' to `__all__`:

```python
from .module_GiraffeDet import CSPStage

__all__ = ('Conv', 'Conv2', 'LightConv', 'RepConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus',
           'GhostConv', 'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'TransformerLayer',
           'TransformerBlock', 'MLPBlock', 'LayerNorm2d', 'DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3',
           'C2f', 'C3x', 'C3TR', 'C3Ghost', 'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'Detect',
           'Segment', 'Pose', 'Classify', 'TransformerEncoderLayer', 'RepC3', 'RTDETRDecoder', 'AIFI',
           'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP',
           'CSPStage')
```

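2.3 Register the module in tasks.py (ultralytics/nn/tasks.py)

Add CSPStage to the import from ultralytics.nn.modules at the top of tasks.py. The list below is an example; keep whatever other names your version already imports there:

```python
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x,
                                    Classify, Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d,
                                    Ensemble, Focus, GhostBottleneck, GhostConv, Segment, CSPStage)
```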

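In the parse_model function of tasks.py (around line 666; the exact line depends on your ultralytics version), change

```python
        n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
        if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
                 BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3):
```

to

```python
        n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
        if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
                 BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3,
                 CSPStage):
```

Append CSPStage to the tuple instead of replacing RepC3; removing RepC3 would break any model config that uses it. With this change parse_model passes [c1, c2] (input channels and width-scaled output channels) to CSPStage. The repeat count from the YAML is not forwarded as CSPStage's n argument, though: if the depth-scaled count is greater than 1 (for example at the l and x scales, whose depth multiple is 1.00), parse_model stacks that many CSPStage copies, each built for c1 input channels, which fails unless c1 equals c2. To pass the count as n instead, also add CSPStage to the module tuple in the nested `args.insert(2, n)` check a few lines further down.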
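The repeat handling described above can be modelled in a few lines of plain Python. This is a simplified sketch of parse_model's logic as I understand it, not the ultralytics code: `resolve_repeats` and `chains_cleanly` are hypothetical helpers, c2 is assumed to be already width-scaled, and the channel numbers in the usage lines are those of head layer 12 at scale x (1280 in, 640 out).

```python
def resolve_repeats(c1, c2, n, depth, passes_n):
    """Simplified model of parse_model for a module in the channel tuple.

    c1: input channels, c2: output channels (already width-scaled),
    n: repeats from the YAML, depth: depth multiple of the chosen scale,
    passes_n: True if the module is also in the nested args.insert(2, n) check.
    Returns (args, copies): constructor args and how many copies get stacked.
    """
    n = max(round(n * depth), 1) if n > 1 else n  # depth gain
    args = [c1, c2]
    if passes_n:
        args.insert(2, n)  # repeats become the module's own n argument
        n = 1
    return args, n


def chains_cleanly(args, copies):
    """Stacked copies all expect args[0] input channels but emit args[1]."""
    return copies == 1 or args[0] == args[1]


# Head layer 12 at scale x (depth 1.00): CSPStage 1280 -> 640, 3 repeats
print(resolve_repeats(1280, 640, 3, 1.00, passes_n=False))  # ([1280, 640], 3)
print(resolve_repeats(1280, 640, 3, 1.00, passes_n=True))   # ([1280, 640, 3], 1)
```

At scale n (depth 0.33) the count rounds down to 1, so both variants build a single CSPStage and the difference only shows up at larger scales.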

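2.4 Create yolov8-GFPN.yaml

Create ultralytics/cfg/models/v8/yolov8-GFPN.yaml (the path used by the training command in section 3.2) with the following content; the neck's C2f blocks are replaced by CSPStage: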

```yaml
# Ultralytics YOLO 🚀, GPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 4  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, CSPStage, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, CSPStage, [256]]  # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, CSPStage, [512]]  # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, CSPStage, [1024]]  # 21 (P5/32-large)

  - [[15, 18, 21], 1, Detect, [nc]]  # Detect(P3, P4, P5)
```
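As a sanity check on the YAML, the pure-Python sketch below traces the (input, output) channels of the four CSPStage rows for a given scale, plus the internal split CSPStage performs. It re-derives parse_model's width scaling by hand, assuming widths are rounded up to a multiple of 8 as in ultralytics' `make_divisible`, and assumes the backbone layers listed above; `head_cspstage_io` and `cspstage_split` are hypothetical helper names.

```python
import math


def make_divisible(x, divisor=8):
    """Round x up to a multiple of divisor (the rounding applied to widths)."""
    return math.ceil(x / divisor) * divisor


def head_cspstage_io(width, max_channels):
    """(input, output) channels of the CSPStage rows 12, 15, 18, 21."""
    def c(ch):  # width-scale a channel count taken from the YAML
        return make_divisible(min(ch, max_channels) * width, 8)

    p3, p4, p5 = c(256), c(512), c(1024)  # backbone layers 4, 6, 9
    l12 = (p5 + p4, c(512))               # layers 10-11: upsample(9) ++ layer 6
    l15 = (l12[1] + p3, c(256))           # layers 13-14: upsample(12) ++ layer 4
    l18 = (c(256) + l12[1], c(512))       # layers 16-17: Conv(15, stride 2) ++ layer 12
    l21 = (c(512) + p5, c(1024))          # layers 19-20: Conv(18, stride 2) ++ layer 9
    return {12: l12, 15: l15, 18: l18, 21: l21}


def cspstage_split(ch_in, ch_out, n=1):
    """Channel bookkeeping inside CSPStage (spp=False), mirroring its __init__."""
    ch_first = ch_out // 2      # conv1 branch, passed straight to the concat
    ch_mid = ch_out - ch_first  # conv2 branch, refined by n blocks
    return {
        "conv1": (ch_in, ch_first),
        "conv2": (ch_in, ch_mid),
        "conv3": (ch_mid * n + ch_first, ch_out),  # y1 plus every block output
    }


io_n = head_cspstage_io(width=0.25, max_channels=1024)  # scale n
print(io_n)  # {12: (384, 128), 15: (192, 64), 18: (192, 128), 21: (384, 256)}
print(cspstage_split(*io_n[12]))
```

Every CSPStage row has more input than output channels, which is why stacking several copies of it (rather than passing n) cannot work.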

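3 Training

3.1 Environment setup

Create a virtual environment, then rebuild and install ultralytics from the modified source tree:

```bash
pip3 install -r requirements.txt
python3 setup.py install
```

For ultralytics releases that no longer ship setup.py and requirements.txt, `pip install -e .` run from the repository root installs the package instead, and as an editable install it picks up later source changes without reinstalling.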

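3.2 Start training

```bash
yolo task=detect mode=train model=./ultralytics/cfg/models/v8/yolov8-GFPN.yaml pretrained=yolov8x.pt data=./ultralytics/cfg/datasets/data.yaml batch=36 epochs=1000 imgsz=640 workers=16 device=0 nbs=4
```

Because the config file name contains no scale letter, Ultralytics builds the first scale listed under scales (n) and prints a warning; pretrained=yolov8x.pt then transfers only the weights whose shapes match.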
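The command sets nbs=4 (nominal batch size) well below batch=36; the default nbs is 64. Based on my reading of the Ultralytics trainer (verify against your installed version), nbs controls gradient accumulation and rescales weight decay roughly as in this sketch. `nbs_effect` is a hypothetical helper, and 0.0005 is just an example weight-decay value.

```python
def nbs_effect(batch, nbs, weight_decay):
    """Sketch of how nbs is used by the trainer, as I understand it.

    Returns (accumulate, scaled_weight_decay).
    """
    accumulate = max(round(nbs / batch), 1)  # batches per optimizer step
    scaled_wd = weight_decay * batch * accumulate / nbs
    return accumulate, scaled_wd


print(nbs_effect(batch=36, nbs=4, weight_decay=0.0005))
```

With batch=36 and nbs=4 there is no gradient accumulation and the effective weight decay is 36 / 4 = 9 times the configured value, so nbs is worth checking before copying the command unchanged.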

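4 How to get the code

Search for 晓理紫 on WeChat (VX), follow the account, and reply with "yolov8-GiraffeDet" to get the code.

{晓理紫} loves sharing and also really needs your support. If you liked this post, leave a trace (a like or a comment)!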


http://www.ppmy.cn/news/1073786.html
