DETR目标检测源码流程图：main.py的执行

news/2024/11/17 23:24:30/

DETR目标检测源码流程图：main.py的执行

官网预测脚本

补充官网提供的预测部分的代码信息。

import matplotlib.pyplot as plt
import requests
import torch
import torchvision.transforms as T
from PIL import Image
from torch import nn
from torchvision.models import resnet50
# This script only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)


class DETRdemo(nn.Module):
    """Demo DETR implementation.

    Demo implementation of DETR in minimal number of lines, with the
    following differences wrt DETR in the paper:

    * learned positional encoding (instead of sine)
    * positional encoding is passed at input (instead of attention)
    * fc bbox predictor (instead of MLP)

    The model achieves ~40 AP on COCO val5k and runs at ~28 FPS on Tesla V100.
    Only batch size 1 supported.

    Args:
        num_classes: number of object classes; the classification head emits
            ``num_classes + 1`` logits per query.
        hidden_dim: transformer model width, also the output channel count of
            the 1x1 conversion convolution.
        nheads: number of attention heads passed to ``nn.Transformer``.
        num_encoder_layers: number of transformer encoder layers.
        num_decoder_layers: number of transformer decoder layers.
    """

    def __init__(self, num_classes, hidden_dim=256, nheads=8,
                 num_encoder_layers=6, num_decoder_layers=6):
        super().__init__()

        # create ResNet-50 backbone; its classification layer is not used
        self.backbone = resnet50()
        del self.backbone.fc

        # create conversion layer (1x1 conv from 2048 backbone channels)
        self.conv = nn.Conv2d(2048, hidden_dim, 1)

        # create a default PyTorch transformer
        self.transformer = nn.Transformer(
            hidden_dim, nheads, num_encoder_layers, num_decoder_layers)

        # prediction heads, one extra class for predicting non-empty slots
        # note that in baseline DETR linear_bbox layer is 3-layer MLP
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
        self.linear_bbox = nn.Linear(hidden_dim, 4)

        # output positional encodings (object queries)
        self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))

        # spatial positional encodings
        # note that in baseline DETR we use sine positional encodings
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

    def forward(self, inputs):
        """Run the detector on ``inputs`` and return raw predictions.

        Returns:
            dict with ``'pred_logits'`` (output of the class head) and
            ``'pred_boxes'`` (output of the box head passed through a
            sigmoid, i.e. values in [0, 1]).
        """
        # propagate inputs through ResNet-50 up to avg-pool layer
        x = self.backbone.conv1(inputs)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)

        # convert from 2048 to 256 feature planes for the transformer
        h = self.conv(x)

        # construct positional encodings
        H, W = h.shape[-2:]
        # Column and row embeddings are tiled to an (H, W, hidden_dim // 2)
        # grid each, concatenated on the channel axis, then reordered to
        # (H * W, 1, hidden_dim) so they line up with the flattened features.
        pos = torch.cat([
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)

        # propagate through the transformer
        h = self.transformer(pos + 0.1 * h.flatten(2).permute(2, 0, 1),
                             self.query_pos.unsqueeze(1)).transpose(0, 1)

        # finally project transformer outputs to class labels and bounding boxes
        return {'pred_logits': self.linear_class(h),
                'pred_boxes': self.linear_bbox(h).sigmoid()}


detr = DETRdemo(num_classes=91)
# Download the pretrained demo checkpoint onto the CPU (check_hash=True asks
# torch.hub to verify the download), load it into the model and switch the
# model to evaluation mode.
state_dict = torch.hub.load_state_dict_from_url(
    url='https://dl.fbaipublicfiles.com/detr/detr_demo-da2a99e9.pth',
    map_location='cpu',
    check_hash=True,
)
detr.load_state_dict(state_dict)
detr.eval()

# COCO classes
CLASSES = [
    # 0-9
    "N/A", "person", "bicycle", "car", "motorcycle",
    "airplane", "bus", "train", "truck", "boat",
    # 10-19
    "traffic light", "fire hydrant", "N/A", "stop sign", "parking meter",
    "bench", "bird", "cat", "dog", "horse",
    # 20-29
    "sheep", "cow", "elephant", "bear", "zebra",
    "giraffe", "N/A", "backpack", "umbrella", "N/A",
    # 30-39
    "N/A", "handbag", "tie", "suitcase", "frisbee",
    "skis", "snowboard", "sports ball", "kite", "baseball bat",
    # 40-49
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
    "N/A", "wine glass", "cup", "fork", "knife",
    # 50-59
    "spoon", "bowl", "banana", "apple", "sandwich",
    "orange", "broccoli", "carrot", "hot dog", "pizza",
    # 60-69
    "donut", "cake", "chair", "couch", "potted plant",
    "bed", "N/A", "dining table", "N/A", "N/A",
    # 70-79
    "toilet", "N/A", "tv", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "microwave", "oven",
    # 80-89
    "toaster", "sink", "refrigerator", "N/A", "book",
    "clock", "vase", "scissors", "teddy bear", "hair drier",
    # 90
    "toothbrush",
]

# colors for visualization
# RGB triples; the plotting code cycles through them, one per drawn box.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]

# standard PyTorch mean-std input image normalization
# Preprocessing pipeline: resize, convert to tensor, then normalize each
# channel with the given mean / std.
transform = T.Compose([
    T.Resize(size=800),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to corners.

    Args:
        x: tensor whose dimension 1 holds the four box values.

    Returns:
        Tensor of the same layout holding (x_min, y_min, x_max, y_max).
    """
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)


def rescale_bboxes(out_bbox, size):
    """Scale relative (cx, cy, w, h) boxes to absolute corner coordinates.

    Args:
        out_bbox: boxes in (center_x, center_y, width, height) form, expressed
            as fractions of the image size.
        size: ``(width, height)`` of the image.

    Returns:
        Boxes as (x_min, y_min, x_max, y_max) in the units of ``size``.
    """
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    # Build the scale on the same device as the boxes so the multiplication
    # also works when the model runs on an accelerator.
    scale = torch.tensor([img_w, img_h, img_w, img_h],
                         dtype=torch.float32, device=b.device)
    return b * scale


def detect(im, model, transform, threshold=0.7):
    """Run ``model`` on one image and keep the confident detections.

    Args:
        im: input image; must expose ``size`` as ``(width, height)``.
        model: callable returning a dict with ``'pred_logits'`` and
            ``'pred_boxes'``.
        transform: callable turning ``im`` into an image tensor.
        threshold: minimum class probability for a prediction to be kept.

    Returns:
        Tuple ``(probabilities, boxes)`` for the kept predictions, boxes being
        (x_min, y_min, x_max, y_max) scaled to ``im.size``.
    """
    # mean-std normalize the input image (batch-size: 1)
    img = transform(im).unsqueeze(0)

    # demo model only support by default images with aspect ratio between 0.5 and 2
    # if you want to use images with an aspect ratio outside this range
    # rescale your image so that the maximum size is at most 1333 for best results
    assert img.shape[-2] <= 1600 and img.shape[-1] <= 1600, 'demo model only supports images up to 1600 pixels on each side'

    # propagate through the model
    outputs = model(img)

    # keep only predictions above the confidence threshold; the last logit is
    # dropped before taking the per-query maximum
    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > threshold

    # convert boxes from [0; 1] to image scales
    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
    return probas[keep], bboxes_scaled


url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# url = "./test.jpg"
# im = Image.open(url)
# Download the sample image. The timeout keeps a stalled connection from
# hanging the script, and raise_for_status() reports an HTTP failure directly
# instead of letting PIL fail on an error page.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
im = Image.open(response.raw)

scores, boxes = detect(im, detr, transform)


def plot_results(pil_img, prob, boxes):
    """Show ``pil_img`` with one labelled rectangle per detection.

    Args:
        pil_img: image to display.
        prob: per-detection class probabilities; the arg-max entry selects
            the label looked up in ``CLASSES``.
        boxes: tensor of (x_min, y_min, x_max, y_max) boxes in pixel units.
    """
    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()
    # COLORS is repeated so that zip() does not run out of colors early.
    for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), COLORS * 100):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=c, linewidth=3))
        cl = p.argmax()
        text = f'{CLASSES[cl]}: {p[cl]:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('off')
    plt.show()


plot_results(im, scores, boxes)

在这里插入图片描述

核心流程图

  1. 整体执行流程概述
  2. 模型构建过程
  3. 前向传播与损失函数

如需代码注释可联系作者；为保持简洁，这里不再提供代码注释，只关注断点调试得到的流程图信息。

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述


http://www.ppmy.cn/news/1531571.html

相关文章

力扣题解2207

大家好，欢迎来到无限大的频道。 今日继续给大家带来力扣题解。 题目描述（中等）： 字符串中最多数目的子序列 给你一个下标从 0 开始的字符串 text 和另一个下标从 0 开始且长度为 2 的字符串 pattern ，两者都只包…

【计算机视觉】YoloV8-训练与测试教程

✨ Blog’s 主页: 白乐天_ξ( ✿＞◡❛) 🌈 个人Motto：他强任他强，清风拂山冈！ 💫 欢迎来到我的学习笔记！ 制作数据集 Labelme 数据集 数据集选用自己标注的，可参考以下：…

HTML·第三章课后练习题

采用表格布局完成“CASIO计算器”外观设计，其中表格的每一个单元格均需要设计带边框 <!DOCTYPE html> <html lang="zh"><head><meta charset="UTF-8"><meta name="viewport" content="width=device-width…

【R语言】fs 工具功能速查

文件操作 文件操作函数作用file_copy() dir_copy() link_copy()复制文件、目录、链接file_create() dir_create() link_create()创建文件、目录、链接file_delete() dir_delete() link_delete()删除文件、目录、链接file_access() file_exists() dir_exists() link_exists()文…

【Web】御网杯信息安全大赛2024 wp(全)

目录 input_data admin flask 如此多的FLAG 一夜醒来之全国CTF水平提升1000倍😋 input_data 访问./.svn后随便翻一翻拿到flag admin dirsearch扫出来 访问./error看出来是java框架 测出来是/admin;/路由打Spring View Manipulation(Java)的SSTI https:/…

大型语言模型(Large Language Models)的介绍

背景 大型语言模型（Large Language Models，简称LLMs）是一类先进的人工智能模型，它们通过深度学习技术，特别是神经网络，来理解和生成自然语言。这些模型在自然语言处理（NLP）领域中扮…

0基础学前端 day5

JavaScript 前端学习指南 JavaScript是当今Web开发的核心语言之一。作为前端开发的基石，掌握JavaScript有助于开发者构建动态、交互丰富的网页应用程序。本文将详细介绍JavaScript的基本语法、DOM和BOM的使用、接口请求、最新的ES6特性，以及一些重要的概…

3-1.Android Fragment 之创建 Fragment

Fragment Fragment 可以视为 Activity 的一个片段，它具有自己的生命周期和接收事件的能力，它有以下特点 Fragment 依赖于 Activity，不能独立存在，Fragment 的生命周期受 Activity 的生命周期影响 Fragment 将 Activity 的 UI 和…