开源地址:
PaddleDetection/configs/mot at release/2.3 · PaddlePaddle/PaddleDetection · GitHub
百度飞桨(PaddlePaddle)集成了多种多目标跟踪算法,地址:
PaddleDetection/configs/mot at release/2.3 · PaddlePaddle/PaddleDetection · GitHub
deepsort:
jde
fairmot:
本人测试结果如下,后续可能继续跟踪跟进。
本机代码:运行ok:
PaddleDetection-release-2.3
环境,py37
测试入口类:
tools/infer_mot.py
测试结果:有漏检,
奇怪的地方:
如果读取的是视频文件,先用ffmpeg转为图片,然后排序,读取图片列表,
直接读取图片就可以吧?
cap = cv2.VideoCapture(self.video_file)
电脑没有安装ffmpeg,所以把程序改了一下,直接读取文件夹的图片:
def _load_video_images(self):
    """Build the per-frame record list for ``self.video_file``.

    Frames are not extracted from the video here. They are globbed as
    ``*.jpg`` files from the directory whose path equals the video path
    with its extension removed (``1.mp4`` -> ``1/``), so the frames must
    already exist on disk.

    Side effects:
        Sets ``self.frame_rate`` (only when it was ``-1``),
        ``self.video_frames`` and ``self.video_length``, and fills
        ``self._imid2path`` with ``frame index -> image path``.

    Returns:
        list[dict]: one record per frame with ``im_id`` (a one-element
        numpy array holding the frame index), ``im_file`` (image path) and,
        when ``self.keep_ori_im`` is truthy, ``keep_ori_im`` set to ``1``.

    Raises:
        AssertionError: if a globbed path is not a regular file, or if no
            frame image is found at all.
    """
    if self.frame_rate == -1:
        # frame_rate is not set for the video: query it via cv2.VideoCapture
        cap = cv2.VideoCapture(self.video_file)
        try:
            self.frame_rate = int(cap.get(cv2.CAP_PROP_FPS))
        finally:
            # always free the capture handle, even if the query raises
            cap.release()

    # splitext drops only the trailing extension; the previous
    # str.replace('.ext', '') removed every occurrence inside the path.
    output_path = os.path.splitext(self.video_file)[0]

    # NOTE: the ffmpeg-based video2frames(video_file, output_path, frame_rate)
    # step is intentionally skipped so that ffmpeg is not required; the
    # frames are read straight from output_path instead.
    self.video_frames = natsorted(
        glob.glob(os.path.join(output_path, '*.jpg')))

    self.video_length = len(self.video_frames)
    logger.info('Length of the video: {:d} frames.'.format(
        self.video_length))

    records = []
    for ct, image in enumerate(self.video_frames):
        assert image != '' and os.path.isfile(image), \
            "Image {} not found".format(image)
        # a positive sample_num caps how many frames are loaded
        if self.sample_num > 0 and ct >= self.sample_num:
            break
        rec = {'im_id': np.array([ct]), 'im_file': image}
        if self.keep_ori_im:
            rec.update({'keep_ori_im': 1})
        self._imid2path[ct] = image
        records.append(rec)

    assert len(records) > 0, "No image file found"
    return records
改后入口类:
# Entry script: run multi-object tracking inference with PaddleDetection on a
# video file or an image directory.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import sys
import warnings

# add python path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)

warnings.filterwarnings('ignore')

import paddle
from paddle.distributed import ParallelEnv
from ppdet.core.workspace import load_config, merge_config
from ppdet.engine import Tracker
from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.cli import ArgsParser
from ppdet.utils.logger import setup_logger

logger = setup_logger('train')

_TRUE_STRINGS = ('true', 't', 'yes', 'y', '1')
_FALSE_STRINGS = ('false', 'f', 'no', 'n', '0', '')


def _str2bool(value):
    """Convert a command-line token to a real ``bool``.

    ``type=bool`` (or no type at all) makes every non-empty string truthy,
    so ``--scaled False`` used to evaluate to ``True``.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in _TRUE_STRINGS:
        return True
    if token in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(
        'Boolean value expected, got {!r}.'.format(value))


def parse_args():
    """Parse the command-line options of the tracking script.

    Returns:
        The namespace produced by ``ArgsParser.parse_args()``.
    """
    parser = ArgsParser()
    # NOTE(review): ppdet's ArgsParser may already register -c/--config; if
    # argparse reports a conflicting option string, switch this to
    # parser.set_defaults(config=...) -- confirm against ppdet/utils/cli.py.
    parser.add_argument(
        '--config',
        type=str,
        default="../configs/mot/fairmot/fairmot_dla34_30e_576x320.yml",
        help='Path of the configuration file to use.')
    parser.add_argument(
        '--video_file',
        type=str,
        default="1.mp4",
        help='Video name for tracking.')
    parser.add_argument(
        '--frame_rate',
        type=int,
        default=-1,
        help='Video frame rate for tracking.')
    parser.add_argument(
        "--image_dir",
        type=str,
        default=None,
        help="Directory for images to perform inference on.")
    parser.add_argument(
        "--det_results_dir",
        type=str,
        default='',
        help="Directory name for detection results.")
    parser.add_argument(
        '--output_dir',
        type=str,
        default='output',
        help='Directory name for output tracking results.')
    parser.add_argument(
        '--save_images',
        type=_str2bool,
        default=False,
        help='Save tracking results (image).')
    parser.add_argument(
        '--save_videos',
        type=_str2bool,
        default=False,
        help='Save tracking results (video).')
    parser.add_argument(
        '--show_image',
        type=_str2bool,
        default=True,
        help='Show tracking results (image).')
    parser.add_argument(
        '--scaled',
        type=_str2bool,
        default=False,
        help="Whether coords after detector outputs are scaled, False in JDE YOLOv3 "
        "True in general detector.")
    parser.add_argument(
        "--draw_threshold",
        type=float,
        default=0.5,
        help="Threshold to reserve the result for visualization.")
    args = parser.parse_args()
    return args


def run(FLAGS, cfg):
    """Build the tracker, load its weights and run tracking inference.

    Args:
        FLAGS: parsed command-line options (see ``parse_args``).
        cfg: configuration object returned by ``load_config``.
    """
    # build Tracker
    tracker = Tracker(cfg, mode='test')

    # load weights
    if cfg.architecture in ['DeepSORT']:
        # the config uses the literal string 'None' to mean "no detector weights"
        if cfg.det_weights != 'None':
            tracker.load_weights_sde(cfg.det_weights, cfg.reid_weights)
        else:
            tracker.load_weights_sde(None, cfg.reid_weights)
    else:
        tracker.load_weights_jde(cfg.weights)

    # inference
    tracker.mot_predict(
        video_file=FLAGS.video_file,
        frame_rate=FLAGS.frame_rate,
        image_dir=FLAGS.image_dir,
        data_type=cfg.metric.lower(),
        model_type=cfg.architecture,
        output_dir=FLAGS.output_dir,
        save_images=FLAGS.save_images,
        save_videos=FLAGS.save_videos,
        show_image=FLAGS.show_image,
        scaled=FLAGS.scaled,
        det_results_dir=FLAGS.det_results_dir,
        draw_threshold=FLAGS.draw_threshold)


def main():
    """Load and validate the config, select the device, then run tracking."""
    FLAGS = parse_args()
    cfg = load_config(FLAGS.config)
    merge_config(FLAGS.opt)
    check_config(cfg)
    check_gpu(cfg.use_gpu)
    check_version()

    place = 'gpu:{}'.format(ParallelEnv().dev_id) if cfg.use_gpu else 'cpu'
    paddle.set_device(place)
    run(FLAGS, cfg)


if __name__ == '__main__':
    main()