Yolov10(yolov8代码里兼容版本)推理代码解析,抛去nms,大道至简

embedded/2024/9/23 6:39:14/

一、模型的输出头

下载官方的yolov8代码库https://github.com/ultralytics/ultralytics
打开ultralytics/nn/modules/head.py,主要需要看一下模型的输出头是如何做训练和预测推理。
在这里插入图片描述
v10检测头继承与常规的检测头Detect,初始化里重构了一下分类的输出头self.cv3,多加了一些卷积层。并将end2end这个参数置为True
在这里插入图片描述
再来看Detect检测头里如何兼容v10检测的

由于end2end是True.
在这里插入图片描述
所以走forward_end2end()

    def forward_end2end(self, x):"""Performs forward pass of the v10Detect module.Args:x (tensor): Input tensor.Returns:(dict, tensor): If not in training mode, returns a dictionary containing the outputs of both one2many and one2one detections.If in training mode, returns a dictionary containing the outputs of one2many and one2one detections separately."""x_detach = [xi.detach() for xi in x]one2one = [torch.cat((self.one2one_cv2[i](x_detach[i]), self.one2one_cv3[i](x_detach[i])), 1) for i in range(self.nl)]for i in range(self.nl):x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)if self.training:  # Training pathreturn {"one2many": x, "one2one": one2one}y = self._inference(one2one)y = self.postprocess(y.permute(0, 2, 1), self.max_det, self.nc)return y if self.export else (y, {"one2many": x, "one2one": one2one})

将网络端到端的3个输出头拼接得到one2one
在这里插入图片描述
one2one为1对1训练输出头
x为1对多训练输出头
如果self.training为True,
即你在训练的时候返回一个字典 {“one2many”: x, “one2one”: one2one},用于e2e训练。
在这里插入图片描述
如果是评估或者预测图片,先走推理self._inference再做后处理self.postprocess
在这里插入图片描述
推理只需要获取1对1输出头的结果即可,对box进行编码self.decode_bboxes。

    def _inference(self, x):"""Decode predicted bounding boxes and class probabilities based on multiple-level feature maps."""# Inference pathshape = x[0].shape  # BCHWx_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)if self.dynamic or self.shape != shape:self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))self.shape = shapeif self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:  # avoid TF FlexSplitV opsbox = x_cat[:, : self.reg_max * 4]cls = x_cat[:, self.reg_max * 4 :]else:box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)if self.export and self.format in {"tflite", "edgetpu"}:# Precompute normalization factor to increase numerical stability# See https://github.com/ultralytics/ultralytics/issues/7371grid_h = shape[2]grid_w = shape[3]grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)norm = self.strides / (self.stride[0] * grid_size)dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])else:dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.stridesreturn torch.cat((dbox, cls.sigmoid()), 1)

return的输出是dbox(4)+cls(14),batch=1,为(1,18,8400)
在这里插入图片描述

二、后处理

在这里插入图片描述

    @staticmethoddef postprocess(preds: torch.Tensor, max_det: int, nc: int = 80):"""Post-processes YOLO model predictions.Args:preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc) with last dimensionformat [x, y, w, h, class_probs].max_det (int): Maximum detections per image.nc (int, optional): Number of classes. Default: 80.Returns:(torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6) and lastdimension format [x, y, w, h, max_class_prob, class_index]."""batch_size, anchors, predictions = preds.shape  # i.e. shape(16,8400,84)boxes, scores = preds.split([4, nc], dim=-1)index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1)boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4))scores = scores.gather(dim=1, index=index.repeat(1, 1, nc))scores, index = scores.flatten(1).topk(max_det)i = torch.arange(batch_size)[..., None]  # batch indicesreturn torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1)

1、首先获取到批次,预测锚框,和预测box+cls

batch_size, anchors, predictions = preds.shape

在这里插入图片描述
2、单独获取预测box和每个box的所有分类得分

 boxes, scores = preds.split([4, nc], dim=-1)

在这里插入图片描述
3、获取分类得分值最大的前300个框

index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1)

topk()会获取到scores的value和对应的索引indices,这里我们只需要获取到索引即可.topk(min(max_det, anchors))[1]
在这里插入图片描述
在这里插入图片描述
4、根据索引挑出这300个框

boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4))

在这里插入图片描述
5、根据索引挑出300个分类得分,即300个boxes对应的14类别的分类得分.

scores = scores.gather(dim=1, index=index.repeat(1, 1, nc))

在这里插入图片描述
6、挑出所有分类的得分中的前300个框。

scores, index = scores.flatten(1).topk(max_det)

在这里插入图片描述
7、获取batch的索引

i = torch.arange(batch_size)[..., None]  # batch indices

在这里插入图片描述
8、返回300个得分值最高的框

 torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1)

index // nc:得分值前300的索引所属的box,某一个框可能有2个分类得分都在前300,则这个框会出现2次。
scores[…, None]:框的分数
(index % nc)[…, None].float():框的类别

在这里插入图片描述

三、预测输出

ultralytic/models/yolo/detect/predict.py

    def postprocess(self, preds, img, orig_imgs):"""Post-processes predictions and returns a list of Results objects."""preds = ops.non_max_suppression(preds,self.args.conf,self.args.iou,agnostic=self.args.agnostic_nms,max_det=self.args.max_det,classes=self.args.classes,)if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a listorig_imgs = ops.convert_torch2numpy_batch(orig_imgs)results = []for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))return results

v8源码里v10的predict的后处理也走了nms函数但是并没有做nms处理
在这里插入图片描述
nms函数中,如果判断预测结果是属于v10的end2end的模型预测结果,则直接从300个候选框中输出大于置信度conf_thres的框作为最终的输出结果。

四、总结

1、端到端的模型抛弃了复杂的后处理过程,不再需要转模型的时候对齐精度,直拿直用,必定是未来研究的重点趋势。
2、跟同事讨论,她在她的大数据集上测试v10的结果甚至优于v8的训练结果。这可能得出一个结论,当你的数据集足够大且足够干净的情况下,v10的结果反而会更好,当然这需要各位再多多测试了。
3、后续博主也准备把分割关键点旋转框等都修改成v10的端到端模式,敬请期待。


http://www.ppmy.cn/embedded/103018.html

相关文章

【Qt的TS文件转换器】利用Python实现自动化TS文件转换

TS 文件转换器 在开发多语言Qt应用时,管理和更新翻译文件是一项繁琐但必要的任务。这个工具旨在自动化Qt Linguist TS文件的转换过程,支持不同语言之间的转换,特别关注中文变体和其他语言。 目录 🌎背景⭐特性🔒前提条…

鸿蒙Next 单元测试框架——hypium

一 框架概述 单元测试框架(hypium)是HarmonyOS上的测试框架,提供测试用例编写、执行、结果显示能力,用于测试系统或应用接口。 表1 单元测试框架功能特性 二 安装使用 目前hypium以npm包的形式发布, 因此需要在Deveco Studio 工程级package.json内配…

Python爬虫(一文通)

Python爬虫(基本篇) 一:静态页面爬取 Requests库的使用 1)基本概念安装基本代码格式 应用领域:适合处理**静态页面数据和简单的 HTTP 请求响应**。 Requests库的讲解 含义:requests 库是 Python 中一个…

Encoding.UTF8是.NET 中用于处理UTF-8编码的标准编码类

Encoding.UTF8 是 .NET 中用于处理 UTF-8 编码的标准编码类。UTF-8 是一种可变长度的字符编码方案,它可以表示所有 Unicode 字符,并且与 ASCII 兼容。Encoding.UTF8 是 System.Text.Encoding 类的一个静态属性,提供了对 UTF-8 编码和解码的支…

-[meetingbot4ios.AppDelegate window]: unrecognized selector sent to instance

这个错误的困扰了我半天,具体错误如下: *** Terminating app due to uncaught exception NSInvalidArgumentException, reason: -[meetingbot4ios.AppDelegate window]: unrecognized selector sent to instance 0x60000370c0c0 *** First throw call …

设计模式-结构型模式-组合模式

1.组合模式的定义 将对象组合成树形结构以表示整个部分的层次结构,组合模式可以让用户统一对待单个对象和对象的组合;其更像是一种数据结构和算法的抽象,其中数据可以表示成树这种数据结构,业务需求可以通过在树上的递归遍历算法来…

【问题解决】Jenkins的Pipeline无法正常后台启动Jar包

文章目录 问题描述排查Jenkins日志启动流水线观察Jar包启动情况初步推测问题问题原因:Jenkins进程管理机制问题解决:改写启动Jar包命令参考文章 问题描述 执行Jenkins的Pipeline,执行结果显示为成功,但是Java程序没有成功启动 排…