Object Tracking with RT-DETR + Sort

news/2024/10/21 6:31:57/

In the previous post I used YOLOv8 + Sort for object tracking. In this post I combine RT-DETR with the Sort algorithm to do the same.

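As before, I run inference with a model exported to ONNX. Because Sort is a tracking-by-detection method, all we need from the model is its detection results. The code is as follows: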

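import onnxruntime as ort

sess = ort.InferenceSession("detr.onnx", None)
# im_data is the preprocessed image tensor and size the target size tensor;
# both are built in the full script further down.
output = sess.run(
    # output_names=None,
    output_names=['labels', 'boxes', 'scores'],
    input_feed={'images': im_data.data.numpy(), "orig_target_sizes": size.data.numpy()},
)
cls, outbox, score = output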
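The snippet above takes `im_data` and `size` as given. A minimal sketch of how the two inputs could be built with NumPy alone, assuming the frame is a BGR `uint8` array (as OpenCV returns it) that has already been resized to the model's input size. The helper name `preprocess` is hypothetical, and the width-first order of `orig_target_sizes` is an assumption that only matters for non-square sizes:

```python
import numpy as np

def preprocess(frame_bgr):
    """Turn an HxWx3 BGR uint8 frame into the two model inputs (hypothetical helper)."""
    rgb = frame_bgr[..., ::-1]                                   # BGR -> RGB
    chw = rgb.astype(np.float32).transpose(2, 0, 1) / 255.0      # HWC -> CHW, scale to [0, 1]
    images = np.ascontiguousarray(chw[None])                     # add batch dim -> (1, 3, H, W)
    h, w = frame_bgr.shape[:2]
    # Assumption: sizes are given as [width, height]; the boxes are scaled to this size.
    orig_target_sizes = np.array([[w, h]], dtype=np.int64)
    return images, orig_target_sizes

# Synthetic frame: pure blue in BGR order.
frame = np.zeros((640, 640, 3), dtype=np.uint8)
frame[..., 0] = 255
images, sizes = preprocess(frame)
print(images.shape, images.dtype, sizes.tolist())
```

These two arrays can be passed directly as the values of `input_feed`, which avoids the round trip through PIL and torch.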

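The model returns three arrays: the predicted class labels, the detection boxes, and the confidence scores. Note that the rest of the code treats each box as [x1, y1, x2, y2] corner coordinates, which is the format Sort expects, rather than xywh.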


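Note that DETR-style detectors are end-to-end and do not need post-processing such as NMS. The results still have to be filtered, though, and a simple confidence threshold is enough: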

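outbox = np.squeeze(outbox)
boxindex = np.where(score > thrh)  # np.where returns the indices of the elements that satisfy the condition
outbox = outbox[boxindex[1]]       # score has shape (1, N), so boxindex[1] holds the indices of the kept boxes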
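A small sketch of the same filtering step on synthetic arrays, which also appends the score column so the result has the `[x1, y1, x2, y2, score]` layout described in the docstring of `Sort.update`. The function name `filter_detections` and all numbers are made up for illustration:

```python
import numpy as np

def filter_detections(labels, boxes, scores, thr=0.6):
    """Keep detections with score > thr; return an (N, 5) array [x1, y1, x2, y2, score]."""
    labels, boxes, scores = labels[0], boxes[0], scores[0]   # drop the batch dimension
    keep = scores > thr                                      # boolean mask over the queries
    dets = np.concatenate([boxes[keep], scores[keep, None]], axis=1)
    return dets, labels[keep]

# Synthetic outputs for one image with three queries.
labels = np.array([[0, 2, 0]])
boxes = np.array([[[10., 20., 50., 80.],
                   [5., 5., 15., 15.],
                   [100., 100., 200., 220.]]])
scores = np.array([[0.9, 0.3, 0.7]])

dets, kept_labels = filter_detections(labels, boxes, scores)
empty, _ = filter_detections(labels, boxes, scores, thr=0.95)
print(dets)
print(empty.shape)
```

When nothing passes the threshold the result is an empty array with five columns, which is what the tracker expects for a frame without detections.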


From here on, the procedure is the same as in the YOLO + Sort approach.

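As for Sort itself, one idea is enough to keep in mind: Sort is the part that does the tracking, and each tracker inside it is a Kalman filter. In every frame the filter predicts where each tracked box should be, and the predictions are then matched against the new detections by IoU.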
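To make the Kalman filter part concrete, here is a minimal NumPy sketch of a constant-velocity filter over the state `[cx, cy, s, r, vx, vy, vs]` (box centre, area, aspect ratio, and their velocities). It is an illustration under simplified assumptions, not the filterpy implementation used in the full script; the noise values and the synthetic measurements are made up:

```python
import numpy as np

# State: [cx, cy, s, r, vx, vy, vs]; measurement: [cx, cy, s, r].
F = np.eye(7)
F[0, 4] = F[1, 5] = F[2, 6] = 1.0       # each frame: position += velocity
H = np.eye(4, 7)                        # only the first four components are observed

def predict(x, P, Q):
    """Propagate the state and its covariance one frame forward."""
    x = F @ x
    P = F @ P @ F.T + Q
    return x, P

def update(x, P, z, R):
    """Correct the prediction with a measurement z."""
    y = z - H @ x                       # innovation
    S = H @ P @ H.T + R                 # innovation covariance
    K = P @ H.T @ np.linalg.inv(S)      # Kalman gain
    x = x + K @ y
    P = (np.eye(7) - K @ H) @ P
    return x, P

x = np.array([100., 100., 400., 1., 0., 0., 0.])   # box centred at (100, 100), velocity unknown
P = np.eye(7) * 10.0
P[4:, 4:] *= 1000.0                                # high uncertainty on the initial velocities
Q = np.eye(7) * 0.01                               # assumed process noise
R = np.eye(4)                                      # assumed measurement noise

# Feed ten measurements of a box moving 5 px per frame to the right.
for k in range(1, 11):
    x, P = predict(x, P, Q)
    x, P = update(x, P, np.array([100. + 5 * k, 100., 400., 1.]), R)

x_next, _ = predict(x, P, Q)
print(float(x[4]), float(x_next[0]))               # estimated x-velocity and predicted next centre
```

After a few frames the filter has learned the velocity, so the predicted box for the next frame lands close to where the object will actually be, which is what makes IoU matching between predictions and detections work.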


The complete RT-DETR + Sort code is as follows:

import cv2
import imageio
import numpy as np
from filterpy.kalman import KalmanFilter
from scipy.optimize import linear_sum_assignment


def linear_assignment(cost_matrix):
    x, y = linear_sum_assignment(cost_matrix)
    return np.array(list(zip(x, y)))


def iou_batch(bb_test, bb_gt):
    """
    From SORT: Computes IOU between two bboxes in the form [x1,y1,x2,y2]
    """
    bb_gt = np.expand_dims(bb_gt, 0)
    bb_test = np.expand_dims(bb_test, 1)

    xx1 = np.maximum(bb_test[..., 0], bb_gt[..., 0])
    yy1 = np.maximum(bb_test[..., 1], bb_gt[..., 1])
    xx2 = np.minimum(bb_test[..., 2], bb_gt[..., 2])
    yy2 = np.minimum(bb_test[..., 3], bb_gt[..., 3])
    w = np.maximum(0., xx2 - xx1)
    h = np.maximum(0., yy2 - yy1)
    wh = w * h
    o = wh / ((bb_test[..., 2] - bb_test[..., 0]) * (bb_test[..., 3] - bb_test[..., 1])
              + (bb_gt[..., 2] - bb_gt[..., 0]) * (bb_gt[..., 3] - bb_gt[..., 1]) - wh)
    return o


def convert_bbox_to_z(bbox):
    """
    Takes a bounding box in the form [x1,y1,x2,y2] and returns z in the form
    [x,y,s,r] where x,y is the centre of the box and s is the scale/area and r is
    the aspect ratio
    """
    w = bbox[2] - bbox[0]
    h = bbox[3] - bbox[1]
    x = bbox[0] + w / 2.
    y = bbox[1] + h / 2.
    s = w * h  # scale is just area
    r = w / float(h)
    return np.array([x, y, s, r]).reshape((4, 1))


def convert_x_to_bbox(x):
    """
    Takes a bounding box in the centre form [x,y,s,r] and returns it in the form
    [x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right
    """
    w = np.sqrt(x[2] * x[3])
    h = x[2] / w
    return np.array([x[0] - w / 2., x[1] - h / 2., x[0] + w / 2., x[1] + h / 2.]).reshape((1, 4))


class KalmanBoxTracker(object):
    """
    This class represents the internal state of individual tracked objects observed as bbox.
    """
    count = 0

    def __init__(self, bbox):
        """
        Initialises a tracker using initial bounding box.
        """
        # define constant velocity model
        self.kf = KalmanFilter(dim_x=7, dim_z=4)
        self.kf.F = np.array([[1, 0, 0, 0, 1, 0, 0],
                              [0, 1, 0, 0, 0, 1, 0],
                              [0, 0, 1, 0, 0, 0, 1],
                              [0, 0, 0, 1, 0, 0, 0],
                              [0, 0, 0, 0, 1, 0, 0],
                              [0, 0, 0, 0, 0, 1, 0],
                              [0, 0, 0, 0, 0, 0, 1]])
        self.kf.H = np.array([[1, 0, 0, 0, 0, 0, 0],
                              [0, 1, 0, 0, 0, 0, 0],
                              [0, 0, 1, 0, 0, 0, 0],
                              [0, 0, 0, 1, 0, 0, 0]])

        self.kf.R[2:, 2:] *= 10.
        self.kf.P[4:, 4:] *= 1000.  # give high uncertainty to the unobservable initial velocities
        self.kf.P *= 10.
        self.kf.Q[-1, -1] *= 0.01
        self.kf.Q[4:, 4:] *= 0.01

        self.kf.x[:4] = convert_bbox_to_z(bbox)
        self.time_since_update = 0
        self.id = KalmanBoxTracker.count
        KalmanBoxTracker.count += 1
        self.history = []
        self.hits = 0
        self.hit_streak = 0
        self.age = 0

    def update(self, bbox):
        """
        Updates the state vector with observed bbox.
        """
        self.time_since_update = 0
        self.history = []
        self.hits += 1
        self.hit_streak += 1
        self.kf.update(convert_bbox_to_z(bbox))

    def predict(self):
        """
        Advances the state vector and returns the predicted bounding box estimate.
        """
        if (self.kf.x[6] + self.kf.x[2]) <= 0:
            self.kf.x[6] *= 0.0
        self.kf.predict()
        self.age += 1
        if self.time_since_update > 0:
            self.hit_streak = 0
        self.time_since_update += 1
        self.history.append(convert_x_to_bbox(self.kf.x))  # record the predicted box in the history
        return self.history[-1]  # the last history entry is the current predicted position

    def get_state(self):
        """
        Returns the current bounding box estimate.
        """
        return convert_x_to_bbox(self.kf.x)


def associate_detections_to_tracks(detections, trackers, iou_threshold=0.3):
    """
    Assigns detections to tracked object (both represented as bounding boxes)
    Returns 3 lists of matches, unmatched_detections and unmatched_trackers
    """
    if len(trackers) == 0:
        return np.empty((0, 2), dtype=int), np.arange(len(detections)), np.empty((0, 5), dtype=int)

    # compute the pairwise IoU, then match with linear_assignment where needed
    iou_matrix = iou_batch(detections, trackers)

    if min(iou_matrix.shape) > 0:
        a = (iou_matrix > iou_threshold).astype(np.int32)
        if a.sum(1).max() == 1 and a.sum(0).max() == 1:
            matched_indices = np.stack(np.where(a), axis=1)
        else:
            matched_indices = linear_assignment(-iou_matrix)
    else:
        matched_indices = np.empty(shape=(0, 2))

    # record the unmatched detections and tracks
    unmatched_detections = []
    for d, det in enumerate(detections):
        if d not in matched_indices[:, 0]:
            unmatched_detections.append(d)
    unmatched_trackers = []
    for t, trk in enumerate(trackers):
        if t not in matched_indices[:, 1]:
            unmatched_trackers.append(t)

    # filter out matches with low IoU
    matches = []
    for m in matched_indices:
        if iou_matrix[m[0], m[1]] < iou_threshold:
            unmatched_detections.append(m[0])
            unmatched_trackers.append(m[1])
        else:
            matches.append(m.reshape(1, 2))
    if len(matches) == 0:
        matches = np.empty((0, 2), dtype=int)
    else:
        matches = np.concatenate(matches, axis=0)

    return matches, np.array(unmatched_detections), np.array(unmatched_trackers)


class Sort(object):
    def __init__(self, max_age=1, min_hits=3, iou_threshold=0.3):
        """
        Sets key parameters for SORT
        """
        # Maximum number of consecutive predictions: how many frames a tracker in
        # self.trackers may be advanced by the Kalman filter without a matching detection.
        self.max_age = max_age
        # Minimum number of updates: how many times a tracker in self.trackers must be
        # matched to a detection (calling KalmanBoxTracker.update) before it is reported.
        # It is not set to 0 because a target detected for the first time is only added
        # to the tracker list and not displayed. Keep this value small; it is typically 1,
        # i.e. the target was detected in two consecutive frames.
        self.min_hits = min_hits
        self.iou_threshold = iou_threshold  # IoU threshold
        self.trackers = []  # list of active trackers
        self.frame_count = 0  # number of frames processed so far

    def update(self, dets=np.empty((0, 5))):
        """
        Params:
          dets - a numpy array of detections in the format [[x1,y1,x2,y2,score],[x1,y1,x2,y2,score],...]
        Requires: this method must be called once for each frame even with empty detections
        (use np.empty((0, 5)) for frames without detections).
        Returns a similar array, where the last column is the object ID.

        NOTE: The number of objects returned may differ from the number of detections provided.
        """
        self.frame_count += 1
        # get predicted locations from existing trackers.
        trks = np.zeros((len(self.trackers), 5))  # empty on the first frame
        to_del = []
        ret = []
        for t, trk in enumerate(trks):  # nothing to iterate over on the first frame
            pos = self.trackers[t].predict()[0]
            trk[:] = [pos[0], pos[1], pos[2], pos[3], 0]
            if np.any(np.isnan(pos)):
                to_del.append(t)
        # np.ma.masked_invalid masks invalid values (NaN or inf);
        # np.ma.compress_rows drops every row of a 2-D array that contains a masked value.
        trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
        # Drop the trackers whose prediction was invalid so that the indices stay aligned
        # with trks (as in the reference SORT implementation).
        for t in reversed(to_del):
            self.trackers.pop(t)
        # associate detections with tracks; on the first frame there are only unmatched_dets
        matched, unmatched_dets, unmatched_trks = associate_detections_to_tracks(dets, trks, self.iou_threshold)

        # update matched trackers with assigned detections
        # (this feeds the Kalman filter; nothing happens here on the first frame)
        for m in matched:
            self.trackers[m[1]].update(dets[m[0], :])

        # create and initialise new trackers for unmatched detections
        # (on the first frame every detection gets its own new tracker)
        for i in unmatched_dets:
            trk = KalmanBoxTracker(dets[i, :])
            self.trackers.append(trk)
        i = len(self.trackers)
        # Walk the trackers back to front. Return only those that were updated in the current
        # frame and whose hit streak is at least self.min_hits (unless tracking has just started);
        # delete a tracker once it has gone unmatched for more than self.max_age frames.
        for trk in reversed(self.trackers):
            d = trk.get_state()[0]
            # hit_streak: ignore the first few frames of a new target
            if (trk.time_since_update < 1) and (trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits):
                ret.append(np.concatenate((d, [trk.id + 1])).reshape(1, -1))  # +1 as MOT benchmark requires positive
            i -= 1
            if trk.time_since_update > self.max_age:
                self.trackers.pop(i)
        if len(ret) > 0:
            return np.concatenate(ret)
        return np.empty((0, 5))


def scale_boxes(input_shape, boxes, shape):
    # Rescale boxes (xyxy) from input_shape to shape
    gain = min(input_shape[0] / shape[0], input_shape[1] / shape[1])  # gain = old / new
    pad = (input_shape[1] - shape[1] * gain) / 2, (input_shape[0] - shape[0] * gain) / 2  # wh padding

    boxes[..., [0, 2]] = boxes[..., [0, 2]] - pad[0]  # x padding
    boxes[..., [1, 3]] = boxes[..., [1, 3]] - pad[1]  # y padding
    boxes[..., :4] = boxes[..., :4] / gain
    boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1])  # x1, x2
    boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0])  # y1, y2
    return boxes


if __name__ == '__main__':
    import onnxruntime as ort
    import torch
    from PIL import Image
    from torchvision.transforms import ToTensor

    mot_tracker = Sort(max_age=1, min_hits=3, iou_threshold=0.3)  # create instance of the SORT tracker
    colours = np.random.rand(32, 3) * 255
    sess = ort.InferenceSession("detr.onnx", None)
    input_path = "video.mp4"
    cap = cv2.VideoCapture(input_path)
    size = torch.tensor([[640, 640]])
    images = []
    thrh = 0.6
    while True:
        ok, image_ori = cap.read()
        if not ok or image_ori is None:
            break
        h, w, c = image_ori.shape  # OpenCV images are (height, width, channels)
        image = cv2.resize(image_ori, (640, 640))
        # OpenCV frames are BGR; convert to RGB before handing the frame to PIL
        img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        im_data = ToTensor()(img)[None]
        output = sess.run(
            # output_names=None,
            output_names=['labels', 'boxes', 'scores'],
            input_feed={'images': im_data.data.numpy(), "orig_target_sizes": size.data.numpy()},
        )
        cls, outbox, score = output
        # outbox = scale_boxes((h, w), outbox, (640, 640))
        outbox = np.squeeze(outbox)
        # outbox = outbox[np.lexsort(outbox[:, ::-1].T)]
        boxindex = np.where(score > thrh)
        outbox = outbox[boxindex[1]]
        # append the score column so the detections match the [x1,y1,x2,y2,score] format of Sort.update
        dets = np.hstack([outbox, score[boxindex][:, None]])
        trackers = mot_tracker.update(dets)
        for d in trackers:
            x1, y1, x2, y2, track_id = d.astype(np.int32).tolist()
            colour = tuple(colours[track_id % 32].tolist())
            cv2.rectangle(image, (x1, y1), (x2, y2), colour, 1)
            cv2.putText(image, str(track_id), (x1, y1), 3, 1, (255, 0, 0))
        images.append(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))  # imageio expects RGB frames
        cv2.waitKey(50)
    imageio.mimsave('output.gif', images, fps=30)
    cap.release()
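To see what `associate_detections_to_tracks` does in the simple case, here is a NumPy-only sketch with made-up boxes. It covers just the shortcut branch, where every detection and every track has at most one candidate above the IoU threshold, so the matches can be read straight from the thresholded matrix without calling `linear_sum_assignment`. The function name `iou_matrix` is chosen for this example only:

```python
import numpy as np

def iou_matrix(dets, trks):
    """Pairwise IoU between (N, 4) and (M, 4) boxes in [x1, y1, x2, y2] form."""
    d = dets[:, None, :]   # (N, 1, 4)
    t = trks[None, :, :]   # (1, M, 4)
    w = np.maximum(0.0, np.minimum(d[..., 2], t[..., 2]) - np.maximum(d[..., 0], t[..., 0]))
    h = np.maximum(0.0, np.minimum(d[..., 3], t[..., 3]) - np.maximum(d[..., 1], t[..., 1]))
    inter = w * h
    area_d = (d[..., 2] - d[..., 0]) * (d[..., 3] - d[..., 1])
    area_t = (t[..., 2] - t[..., 0]) * (t[..., 3] - t[..., 1])
    return inter / (area_d + area_t - inter)

# Three detections and two predicted track boxes (synthetic values).
dets = np.array([[0., 0., 10., 10.], [50., 50., 60., 60.], [200., 200., 210., 210.]])
trks = np.array([[51., 50., 61., 60.], [0., 1., 10., 11.]])

iou = iou_matrix(dets, trks)
above = (iou > 0.3).astype(np.int32)

# Shortcut branch: no row or column has more than one candidate above the threshold.
unambiguous = above.sum(1).max() == 1 and above.sum(0).max() == 1
matches = np.stack(np.where(above), axis=1)          # each row is [det_idx, trk_idx]
unmatched_dets = [i for i in range(len(dets)) if i not in matches[:, 0]]
print(unambiguous, matches.tolist(), unmatched_dets)
```

The unmatched detection is the one for which `Sort.update` would create a new `KalmanBoxTracker`. When the thresholded matrix is ambiguous, the full code falls back to the Hungarian assignment on the negated IoU matrix.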

