YOLO物体检测-系列教程6:YOLOV3源码解读4之 YOLO层

news/2024/11/20 15:33:08/

🎈🎈🎈YOLO 系列教程 总目录

上篇内容:

YOLOV3项目实战1之 整体介绍与数据处理

YOLOV3提出论文:《Yolov3: An incremental improvement》

6、yolo层

6.1 yolo层

class YOLOLayer(nn.Module):"""Detection layer"""def __init__(self, anchors, num_classes, img_dim=416):def compute_grid_offsets(self, grid_size, cuda=True):def forward(self, x, targets=None, img_dim=None):

6.2 构造函数

    def __init__(self, anchors, num_classes, img_dim=416):super(YOLOLayer, self).__init__()self.anchors = anchorsself.num_anchors = len(anchors)self.num_classes = num_classesself.ignore_thres = 0.5self.mse_loss = nn.MSELoss()self.bce_loss = nn.BCELoss()self.obj_scale = 1self.noobj_scale = 100self.metrics = {}self.img_dim = img_dimself.grid_size = 0  # grid size

6.3 偏移量计算

    def compute_grid_offsets(self, grid_size, cuda=True):self.grid_size = grid_sizeg = self.grid_sizeFloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensorself.stride = self.img_dim / self.grid_size# Calculate offsets for each gridself.grid_x = torch.arange(g).repeat(g, 1).view([1, 1, g, g]).type(FloatTensor)self.grid_y = torch.arange(g).repeat(g, 1).t().view([1, 1, g, g]).type(FloatTensor)self.scaled_anchors = FloatTensor([(a_w / self.stride, a_h / self.stride) for a_w, a_h in self.anchors])self.anchor_w = self.scaled_anchors[:, 0:1].view((1, self.num_anchors, 1, 1))self.anchor_h = self.scaled_anchors[:, 1:2].view((1, self.num_anchors, 1, 1))

6.4 前向传播

    def forward(self, x, targets=None, img_dim=None):# Tensors for cuda supportprint (x.shape)FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensorLongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensorByteTensor = torch.cuda.ByteTensor if x.is_cuda else torch.ByteTensorself.img_dim = img_dimnum_samples = x.size(0)grid_size = x.size(2)prediction = (x.view(num_samples, self.num_anchors, self.num_classes + 5, grid_size, grid_size).permute(0, 1, 3, 4, 2).contiguous())print (prediction.shape)# Get outputsx = torch.sigmoid(prediction[..., 0])  # Center xy = torch.sigmoid(prediction[..., 1])  # Center yw = prediction[..., 2]  # Widthh = prediction[..., 3]  # Heightpred_conf = torch.sigmoid(prediction[..., 4])  # Confpred_cls = torch.sigmoid(prediction[..., 5:])  # Cls pred.# If grid size does not match current we compute new offsetsif grid_size != self.grid_size:self.compute_grid_offsets(grid_size, cuda=x.is_cuda) #相对位置得到对应的绝对位置比如之前的位置是0.5,0.5变为 11.5,11.5这样的# Add offset and scale with anchors #特征图中的实际位置pred_boxes = FloatTensor(prediction[..., :4].shape)pred_boxes[..., 0] = x.data + self.grid_xpred_boxes[..., 1] = y.data + self.grid_ypred_boxes[..., 2] = torch.exp(w.data) * self.anchor_wpred_boxes[..., 3] = torch.exp(h.data) * self.anchor_houtput = torch.cat( (pred_boxes.view(num_samples, -1, 4) * self.stride, #还原到原始图中pred_conf.view(num_samples, -1, 1),pred_cls.view(num_samples, -1, self.num_classes),),-1,)if targets is None:return output, 0else:iou_scores, class_mask, obj_mask, noobj_mask, tx, ty, tw, th, tcls, tconf = build_targets(pred_boxes=pred_boxes,pred_cls=pred_cls,target=targets,anchors=self.scaled_anchors,ignore_thres=self.ignore_thres,)# iou_scores:真实值与最匹配的anchor的IOU得分值 class_mask:分类正确的索引  obj_mask:目标框所在位置的最好anchor置为1 noobj_mask obj_mask那里置0,还有计算的iou大于阈值的也置0,其他都为1 tx, ty, tw, th, 对应的对于该大小的特征图的xywh目标值也就是我们需要拟合的值 tconf 目标置信度# Loss : Mask outputs to ignore non-existing objects (except with conf. loss)loss_x = self.mse_loss(x[obj_mask], tx[obj_mask]) # 只计算有目标的loss_y = self.mse_loss(y[obj_mask], ty[obj_mask])loss_w = self.mse_loss(w[obj_mask], tw[obj_mask])loss_h = self.mse_loss(h[obj_mask], th[obj_mask])loss_conf_obj = self.bce_loss(pred_conf[obj_mask], tconf[obj_mask]) loss_conf_noobj = self.bce_loss(pred_conf[noobj_mask], tconf[noobj_mask])loss_conf = self.obj_scale * loss_conf_obj + self.noobj_scale * loss_conf_noobj #有物体越接近1越好 没物体的越接近0越好loss_cls = self.bce_loss(pred_cls[obj_mask], tcls[obj_mask]) #分类损失total_loss = loss_x + loss_y + loss_w + loss_h + loss_conf + loss_cls #总损失# Metricscls_acc = 100 * class_mask[obj_mask].mean()conf_obj = pred_conf[obj_mask].mean()conf_noobj = pred_conf[noobj_mask].mean()conf50 = (pred_conf > 0.5).float()iou50 = (iou_scores > 0.5).float()iou75 = (iou_scores > 0.75).float()detected_mask = conf50 * class_mask * tconfprecision = torch.sum(iou50 * detected_mask) / (conf50.sum() + 1e-16)recall50 = torch.sum(iou50 * detected_mask) / (obj_mask.sum() + 1e-16)recall75 = torch.sum(iou75 * detected_mask) / (obj_mask.sum() + 1e-16)self.metrics = {"loss": to_cpu(total_loss).item(),"x": to_cpu(loss_x).item(),"y": to_cpu(loss_y).item(),"w": to_cpu(loss_w).item(),"h": to_cpu(loss_h).item(),"conf": to_cpu(loss_conf).item(),"cls": to_cpu(loss_cls).item(),"cls_acc": to_cpu(cls_acc).item(),"recall50": to_cpu(recall50).item(),"recall75": to_cpu(recall75).item(),"precision": to_cpu(precision).item(),"conf_obj": to_cpu(conf_obj).item(),"conf_noobj": to_cpu(conf_noobj).item(),"grid_size": grid_size,}return output, total_loss

http://www.ppmy.cn/news/1112790.html

相关文章

C++中嵌入汇编语言的方法

C语言和汇编语言相互调用 不同的语言就像一座孤岛,似乎毫不相干,但是所有的代码最终都要编译成机器指令,他们本质上也是一样的,最终都是变成指令给CPU下达命令。 1. C语言的链接过程 我们知道一个C语言源文件变成可执行文件&…

elasticsearch1

个人名片: 博主:酒徒ᝰ. 个人简介:沉醉在酒中,借着一股酒劲,去拼搏一个未来。 本篇励志:三人行,必有我师焉。 本项目基于B站黑马程序员Java《SpringCloud微服务技术栈》,SpringCloud…

Unity SteamVR 开发教程:用摇杆/触摸板控制人物持续移动(2.x 以上版本)

文章目录 📕教程说明📕场景搭建📕创建移动的动作📕移动脚本⭐移动⭐实时调整 CharacterController 的高度 📕取消手部和 CharacterController 的碰撞 持续移动是 VR 开发中的一个常用功能。一般是用户推动手柄摇杆&…

虹科分享 | 软件供应链攻击如何工作?如何评估软件供应链安全?

说到应用程序和软件,关键词是“更多”。在数字经济需求的推动下,从简化业务运营到创造创新的新收入机会,企业越来越依赖应用程序。云本地应用程序开发更是火上浇油。然而,情况是双向的:这些应用程序通常更复杂&#xf…

Kotlin Android中错误及异常处理最佳实践

Kotlin Android中错误及异常处理最佳实践 Kotlin在Android开发中的错误处理机制以及其优势 Kotlin具有强大的错误处理功能:Kotlin提供了强大的错误处理功能,使处理错误变得简洁而直接。这个特性帮助开发人员快速识别和解决错误,减少了调试代…

NVR添加rtsp流模拟GB28181视频通道

一、海康、大华监控摄像头和硬盘录像机接入GB28181平台配置 1、海康设备接入配置 通过web登录NVR管理系统,进入网络,高级配置界面,填入GB28181相关参数。 将对应项按刚才获取的配置信息填入即可,下面的视频通道的编码ID可以保持…

Aztec.nr:Aztec的隐私智能合约框架——用Noir扩展智能合约功能

1. 引言 前序博客有: Aztec的隐私抽象:在尊重EVM合约开发习惯的情况下实现智能合约隐私 Aztec.nr,为: 面向Aztec应用的,新的,强大的智能合约框架使得开发者可直观管理私有状态基于Noir构建,…

一键集成prometheus监控微服务接口平均响应时长

一、效果展示 二、环境准备 prometheus + grafana环境 参考博文:https://blog.csdn.net/luckywuxn/article/details/129475991 三、导入依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter