YOLOv9改进策略 | 损失函数篇 | 利用SlideLoss助力YOLOv9有效涨点(附代码 + 完整修改方式)

devtools/2024/9/25 7:33:41/

 一、本文介绍

本文给大家带来的是分类损失 SlideLoss损失函数,我们之前看那的那些IoU都是边界框回归损失,和本文的修改内容并不冲突,所以大家可以知道损失函数分为两种一种是分类损失另一种是边界框回归损失,上一篇文章里面我们总结了过去百分之九十的边界框回归损失的使用方法,本文我们就来介绍几种市面上流行的和最新的分类损失函数,同时在开始讲解之前推荐一下我的专栏,本专栏的内容支持(分类、检测、分割、追踪、关键点检测),专栏目前为限时折扣,欢迎大家订阅本专栏,本专栏每周更新3-5篇最新机制,更有包含我所有改进的文件和交流群提供给大家,本文支持的损失函数共有如下图片所示

欢迎大家订阅我的专栏一起学习YOLO 

 专栏地址:YOLOv9有效涨点专栏-持续复现各种顶会内容-有效涨点-全网改进最全的专栏 

目录

 一、本文介绍

二、原理介绍

三、核心代码

三、使用方式 

四 、本文总结


二、原理介绍

其中绝大多数损失在前面我们都讲过了本文主要讲一下SlidLoss的原理,SlideLoss的损失首先是由YOLO-FaceV2提出来的。

​​

官方论文地址: 官方论文地址点击即可跳转

官方代码地址: 官方代码地址点击即可跳转

​​


从摘要上我们可以看出SLideLoss的出现是通过权重函数来解决简单和困难样本之间的不平衡问题题,什么是简单样本和困难样本?

样本不平衡问题是一个常见的问题,尤其是在分类和目标检测任务中。它通常指的是训练数据集中不同类别的样本数量差异很大。对于人脸检测这样的任务来说,简单样本和困难样本之间的不平衡问题可以具体描述如下:

简单样本:

  • 容易被模型正确识别的样本。
  • 通常出现在数据集中的数量较多。
  • 特征明显,分类或检测边界清晰。
  • 在训练中,这些样本会给出较低的损失值,因为模型可以轻易地正确预测它们。

困难样本:

  • 模型难以正确识别的样本。
  • 在数据集中相对较少,但对模型性能的提升至关重要。
  • 可能由于多种原因变得难以识别,如遮挡、变形、模糊、光照变化、小尺寸或者与背景的低对比度。
  • 在训练中,这些样本会产生较高的损失值,因为模型很难对它们给出准确的预测。

解决样本不平衡的问题是提高模型泛化能力的关键。如果模型大部分只见过简单样本,它可能在实际应用中遇到困难样本时性能下降。因此采用各种策略来解决这个问题,例如重采样(对困难样本进行过采样或对简单样本进行欠采样)、修改损失函数(给困难样本更高的权重),或者是设计新的模型结构来专门关注困难样本。在YOLO-FaceV2中,作者通过Slide Loss这样的权重函数来让模型在训练过程中更关注那些困难样本(这也是本文的修改内容)


三、核心代码

使用方式看章节

import math
class SlideLoss(nn.Module):def __init__(self, loss_fcn):super(SlideLoss, self).__init__()self.loss_fcn = loss_fcnself.reduction = loss_fcn.reductionself.loss_fcn.reduction = 'none'  # required to apply SL to each elementdef forward(self, pred, true, auto_iou=0.5):loss = self.loss_fcn(pred, true)if auto_iou < 0.2:auto_iou = 0.2b1 = true <= auto_iou - 0.1a1 = 1.0b2 = (true > (auto_iou - 0.1)) & (true < auto_iou)a2 = math.exp(1.0 - auto_iou)b3 = true >= auto_ioua3 = torch.exp(-(true - 1.0))modulating_weight = a1 * b1 + a2 * b2 + a3 * b3loss *= modulating_weightif self.reduction == 'mean':return loss.mean()elif self.reduction == 'sum':return loss.sum()else:  # 'none'return loss


三、使用方式 

根据我下面的图片进行修改即可。


3.1 修改一

我们将上面的核心代码,我们找到如下的文件'utils/loss_tal_dual.py'文件,然后将我们上面的核心代码复制粘贴到文件的开头,注意是文件导入之后!


3.2 修改二 

同一个文件我门下拉,按照下面的图片进行修改即可!

我把代码给copy下来了大家可以复制替换可以! 这个函数修改看不到显示,但是可以debug大家看看执行到没有就行!

class ComputeLoss:# Compute lossesdef __init__(self, model, use_dfl=True):device = next(model.parameters()).device  # get model deviceh = model.hyp  # hyperparameters# Define criteriaBCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h["cls_pw"]], device=device), reduction='none')# Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3self.cp, self.cn = smooth_BCE(eps=h.get("label_smoothing", 0.0))  # positive, negative BCE targets# Focal lossg = h["fl_gamma"]  # focal loss gammaif g > 0:BCEcls = FocalLoss(BCEcls, g)BCEcls  = SlideLoss(BCEcls) # 添加这一行代码代表打开了SlideLossm = de_parallel(model).model[-1]  # Detect() moduleself.balance = {3: [4.0, 1.0, 0.4]}.get(m.nl, [4.0, 1.0, 0.25, 0.06, 0.02])  # P3-P7self.BCEcls = BCEclsself.hyp = hself.stride = m.stride  # model stridesself.nc = m.nc  # number of classesself.nl = m.nl  # number of layersself.no = m.noself.reg_max = m.reg_maxself.device = deviceself.assigner = TaskAlignedAssigner(topk=int(os.getenv('YOLOM', 10)),num_classes=self.nc,alpha=float(os.getenv('YOLOA', 0.5)),beta=float(os.getenv('YOLOB', 6.0)))self.assigner2 = TaskAlignedAssigner(topk=int(os.getenv('YOLOM', 10)),num_classes=self.nc,alpha=float(os.getenv('YOLOA', 0.5)),beta=float(os.getenv('YOLOB', 6.0)))self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=use_dfl).to(device)self.bbox_loss2 = BboxLoss(m.reg_max - 1, use_dfl=use_dfl).to(device)self.proj = torch.arange(m.reg_max).float().to(device)  # / 120.0self.use_dfl = use_dfldef preprocess(self, targets, batch_size, scale_tensor):if targets.shape[0] == 0:out = torch.zeros(batch_size, 0, 5, device=self.device)else:i = targets[:, 0]  # image index_, counts = i.unique(return_counts=True)out = torch.zeros(batch_size, counts.max(), 5, device=self.device)for j in range(batch_size):matches = i == jn = matches.sum()if n:out[j, :n] = targets[matches, 1:]out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))return outdef bbox_decode(self, anchor_points, pred_dist):if self.use_dfl:b, a, c = pred_dist.shape  # batch, anchors, channelspred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))# pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype))# pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)return dist2bbox(pred_dist, anchor_points, xywh=False)def __call__(self, p, targets, img=None, epoch=0):loss = torch.zeros(3, device=self.device)  # box, cls, dflfeats = p[1][0] if isinstance(p, tuple) else p[0]feats2 = p[1][1] if isinstance(p, tuple) else p[1]pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split((self.reg_max * 4, self.nc), 1)pred_scores = pred_scores.permute(0, 2, 1).contiguous()pred_distri = pred_distri.permute(0, 2, 1).contiguous()pred_distri2, pred_scores2 = torch.cat([xi.view(feats2[0].shape[0], self.no, -1) for xi in feats2], 2).split((self.reg_max * 4, self.nc), 1)pred_scores2 = pred_scores2.permute(0, 2, 1).contiguous()pred_distri2 = pred_distri2.permute(0, 2, 1).contiguous()dtype = pred_scores.dtypebatch_size, grid_size = pred_scores.shape[:2]imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)# targetstargets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxymask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)# pboxespred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)pred_bboxes2 = self.bbox_decode(anchor_points, pred_distri2)  # xyxy, (b, h*w, 4)target_labels, target_bboxes, target_scores, fg_mask = self.assigner(pred_scores.detach().sigmoid(),(pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),anchor_points * stride_tensor,gt_labels,gt_bboxes,mask_gt)target_labels2, target_bboxes2, target_scores2, fg_mask2 = self.assigner2(pred_scores2.detach().sigmoid(),(pred_bboxes2.detach() * stride_tensor).type(gt_bboxes.dtype),anchor_points * stride_tensor,gt_labels,gt_bboxes,mask_gt)target_bboxes /= stride_tensortarget_scores_sum = max(target_scores.sum(), 1)target_bboxes2 /= stride_tensortarget_scores_sum2 = max(target_scores2.sum(), 1)# cls loss# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL wayloss[1] = self.BCEcls(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCEloss[1] *= 0.25loss[1] += self.BCEcls(pred_scores2, target_scores2.to(dtype)).sum() / target_scores_sum2 # BCE# bbox lossif fg_mask.sum():loss[0], loss[2], iou = self.bbox_loss(pred_distri,pred_bboxes,anchor_points,target_bboxes,target_scores,target_scores_sum,fg_mask)loss[0] *= 0.25loss[2] *= 0.25if fg_mask2.sum():loss0_, loss2_, iou2 = self.bbox_loss2(pred_distri2,pred_bboxes2,anchor_points,target_bboxes2,target_scores2,target_scores_sum2,fg_mask2)loss[0] += loss0_loss[2] += loss2_loss[0] *= 7.5  # box gainloss[1] *= 0.5  # cls gainloss[2] *= 1.5  # dfl gainreturn loss.sum() * batch_size, loss.detach()  # loss(box, cls, dfl)


四 、本文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv9改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

希望大家阅读完以后可以给文章点点赞和评论支持一下这样购买专栏的人越多群内人越多大家交流的机会就更多了。  

 专栏地址:YOLOv9有效涨点专栏-持续复现各种顶会内容-有效涨点-全网改进最全的专栏 

​​


http://www.ppmy.cn/devtools/48512.html

相关文章

【春秋云镜】Faculty Evaluation System未授权任意文件上传漏洞(CVE-2023-33440)

因为该靶场没有Write up,索性自己搞一下&#xff0c;方便别人&#xff0c;快乐自己&#xff01; 漏洞概述&#xff1a; Sourcecodester Faculty Evaluation System v1.0 is vulnerable to arbitrary code execution via /eval/ajax.php?actionsave_user. 漏洞复现&#xff…

Web前端与REST API:深度解析与实战指南

Web前端与REST API&#xff1a;深度解析与实战指南 在Web开发领域&#xff0c;前端与后端之间的数据交互至关重要&#xff0c;而REST API作为连接两者的桥梁&#xff0c;扮演着不可或缺的角色。本文将从四个方面、五个方面、六个方面和七个方面&#xff0c;深入剖析Web前端与R…

2021年CSP-J-T3-网络连接(network)

2021年CSP-J-T3-网络连接&#xff08;network&#xff09; 题目&#xff1a; 示例2 输入 10 Server 192.168.1.1:80 Client 192.168.1.1:80 Client 192.168.1.1:8080 Server 192.168.1.1:80 Server 192.168.1.1:8080 Server 192.168.1.999:0 Client 192.168.1.1.8080 Client …

贪吃蛇双人模式设计(2)

敲上瘾-CSDN博客控制台程序设置_c语言控制程序窗口大小-CSDN博客贪吃蛇小游戏_贪吃蛇小游戏csdn-CSDN博客 一、功能实现&#xff1a; 玩家1使用↓ → ← ↑按键来操作蛇的方向&#xff0c;使用右Shift键加速&#xff0c;右Ctrl键减速玩家2使用W A S D按键来操作蛇的方向&am…

【代码随想录】【算法训练营】【第30天 1】 [322]重新安排行程 [51]N皇后

前言 思路及算法思维&#xff0c;指路 代码随想录。 题目来自 LeetCode。 day 30&#xff0c;周四&#xff0c;好难&#xff0c;会不了一点~ 题目详情 [322] 重新安排行程 题目描述 322 重新安排行程 解题思路 前提&#xff1a;…… 思路&#xff1a;回溯。 重点&…

Rust 实战丨通过实现 json! 掌握声明宏

在 Rust 编程语言中&#xff0c;宏是一种强大的工具&#xff0c;可以用于在编译时生成代码。json! 是一个在 Rust 中广泛使用的宏&#xff0c;它允许我们在 Rust 代码中方便地创建 JSON 数据。 声明宏&#xff08;declarative macros&#xff09;是 Rust 中的一种宏&#xff0…

web安全-前端层面

参考资料引荐 https://blog.csdn.net/hack0919/article/details/130929154 XSS 简介 跨站脚本攻击(Cross-Site Scripting, 简称XSS)当用户将恶意代码注入网页时&#xff0c;其他用户在浏览网页时就会受到影响攻击主要方向主要用于盗取cookie凭据&#xff0c;钓鱼攻击&#…

pdf压缩到指定大小的简单方法

压缩PDF文件是许多人在日常工作和学习中经常需要面对的问题。PDF文件因其跨平台、易阅读的特性而广受欢迎&#xff0c;但有时候文件体积过大&#xff0c;会给传输和存储带来不便。因此&#xff0c;学会如何有效地压缩PDF文件&#xff0c;就显得尤为重要。本文将详细介绍几种常见…