pytorch retinanet目标检测

news/2025/3/14 16:48:17/

本文使用的是pytorch1.7.1,使用Pytorch提供的预训练模型,使用这个预训练模型,我们可以检测COCO数据集中超过80种物体。

RetinaNet的输入格式

输入图像的格式为[C, H, W],即(channels, height, and width),我们也需要提供一个batch size。batch size指一次处理多少张图像。所以输入图像格式为[N, C, H, W]。同时,图像的像素值要在0-1之间。

RetinaNet的输出格式

它输出一个列表包括一个字典,其包含结果张量。格式为List[Dict[Tensor]]。这个Dict包括以下keys:

boxes (FloatTensor[N, 4]):被预测的boxes是[x1, y1, x2, y2]格式

labels (Int64Tensor[N]):每个图片的预测标签

scores:(Tensor[N]):每个预测的得分

coco_names.py

COCO_INSTANCE_CATEGORY_NAMES = ['__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus','train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign','parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow','elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A','handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball','kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket','bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl','banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza','donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table','N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone','microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book','clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

detect_utils.py

为了得到输出,我们需要一个简单的函数predict(),它接收4个输入参数,image、model、device、detection_threshold

import torchvision.transforms as transforms
import cv2
import numpy as np
import torch
from coco_names import COCO_INSTANCE_CATEGORY_NAMES as coco_names#不同类别的框不同的颜色 为91个类生成随机的RGBtuple
COLORS = np.random.uniform(0, 255, size=(len(coco_names), 3))
# define the torchvision image transforms 把图片变成张量
transform = transforms.Compose([transforms.ToTensor(),
])def predict(image, model, device, detection_threshold):# transform the image to tensorimage = transform(image).to(device)image = image.unsqueeze(0) # add a batch dimensionwith torch.no_grad():outputs = model(image) # get the predictions on the image# get all the scoresscores = list(outputs[0]['scores'].detach().cpu().numpy())# index of those scores which are above a certain thresholdthresholded_preds_inidices = [scores.index(i) for i in scores if i > detection_threshold]# get all the predicted bounding boxesbboxes = outputs[0]['boxes'].detach().cpu().numpy()# get boxes above the threshold scoreboxes = bboxes[np.array(scores) >= detection_threshold].astype(np.int32)# get all the predicited class nameslabels = outputs[0]['labels'].cpu().numpy()pred_classes = [coco_names[labels[i]] for i in thresholded_preds_inidices]return boxes, pred_classesdef draw_boxes(boxes, classes, image):for i, box in enumerate(boxes):color = COLORS[coco_names.index(classes[i])]cv2.rectangle(image,(int(box[0]), int(box[1])),(int(box[2]), int(box[3])),color, 2)cv2.putText(image, classes[i], (int(box[0]), int(box[1]-5)),cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2, lineType=cv2.LINE_AA)return image


def predict(image, model, device, detection_threshold):

首先把图像传入RetinaNet模型中,并得到输出

with torch.no_grad():

outputs = model(image) 

得到输出字典中所有的得分

scores = list(outputs[0]['scores'].detach().cpu().numpy())

提取大于threshold的对应的索引

thresholded_preds_inidices = [scores.index(i) for i in scores if i > detection_threshold]

提取输出字典中的所有候选框

bboxes = outputs[0]['boxes'].detach().cpu().numpy()

过滤出大于threshold的候选框

boxes = bboxes[np.array(scores) >= detection_threshold].astype(np.int32)

def draw_boxes(boxes, classes, image):

boxes是前面过滤提取出来的候选框

classes是分类

image是在上面画框和标上分类名

detect_images.py

import torchvision
import torch
import argparse
import cv2
import detect_utils
import numpy as np
from PIL import Image# construct the argument parser
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input', help='path to input image/video')
parser.add_argument('-m', '--min-size', dest='min_size', default=800, help='minimum input size for the RetinaNet network')
parser.add_argument('-t', '--threshold', default=0.6, type=float,help='minimum confidence score for detection')
args = vars(parser.parse_args())
print('USING:')
print(f"Minimum image size: {args['min_size']}")
print(f"Confidence threshold: {args['threshold']}")# download or load the model from disk
model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=True, min_size=args['min_size'])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# load the model onto the computation device
model.eval().to(device)image = Image.open(args['input']).convert('RGB')
# a NumPy copy for OpenCV functions
image_array = np.array(image)
# convert to OpenCV BGR color format
image_array = cv2.cvtColor(image_array, cv2.COLOR_RGB2BGR)# get the bounding boxes and class labels
boxes, classes = detect_utils.predict(image, model, device, args['threshold'])
# get the final image
result = detect_utils.draw_boxes(boxes, classes, image_array)cv2.imshow('Image', result)
cv2.waitKey(0)
save_name = f"{args['input'].split('/')[-1].split('.')[0]}_{args['min_size']}_t{int(args['threshold']*100)}"
cv2.imwrite(f"outputs/{save_name}.jpg", result)

the function is not implemented
Rebuild the library with Windows, GTK+ 2.x or Cocoa support. 

卸载重新下

 

https://download.pytorch.org/whl/torch_stable.html


http://www.ppmy.cn/news/928056.html

相关文章

【论文】RetinaNet

论文地址:https://arxiv.org/pdf/1708.02002.pdfhttps://arxiv.org/pdf/1708.02002.pdf 题目:密集目标检测的焦点损失 图 1. 我们提出了一种新的损失,我们称之为 Focal Loss,它在标准交叉熵标准中添加了一个因子 (1 - p t ) 的γ…

Retina屏而被图处理

屏幕上一张清晰的图片 肉眼在屏幕上看到图片的清晰度由三个因素决定,一是图片像素本身是否精细,二是屏幕分辨率,三是屏幕大小。 我们来逐步分析它们之间的关系: 屏幕分辨率 屏幕分辨率也就是设备分辨率,设备像素&#…

photoShop支持retina显示屏

没有实验性功能的看过来,可以强制ps的所有UI跟随系统缩放,操作系统级别的,无论PS有没有内置这个功能,唯一不好的是,编辑区图片享受不到4K的好处,算是4K显示器的迂回解决办法吧: 1.打开注册表Win…

Android:生成XML文件

[Student{clazzj2307,name张三丰,sex男,age25},Student{clazzj2307,name周芷若,sex女,age23}] List<Student> students new ArrayList<>(); // 创建xml文件 public void createXML(View view){ //判断要生成的数据是否存在 if(students ! null…

vue,vuex,通过store维护很多个页面要用到的公共变量

将获取订单数量的方法放在 Vuex 的 actions 中是比较好的做法。这样可以确保数据的获取和维护逻辑集中在一个地方&#xff0c;并且可以在多个组件中共享这个数据。 在 Vuex 的 actions 中&#xff0c;你可以定义一个异步的方法来获取订单数量&#xff0c;并将获取到的数量保存…

android手机助手---界面篇(1)

以前没写过blog&#xff0c;这次要单独完成一个android手机助手&#xff0c;试试写一下记录这个项目的点点滴滴&#xff0c;写的不好了大家轻喷。 我做东西喜欢从界面开始然后一点点填充功能。以下是第一天的界面截图(界面元素来源与网络&#xff0c;如果有懂设计的大哥欢迎联…

android手机助手---小部分功能篇(2)

界面还没有认真设计&#xff0c;只添加了几个小功能&#xff0c;截图&#xff0c;手机型号&#xff0c;MAC地址&#xff0c;操作系统版本。

android手机电视助手是如何控制电视的呢?

以下是我和一同事的聊天记录&#xff1a; XXX 上午 9:19:26 使用了两种方式&#xff0c;一个是手机端执行“adb shell input keyevent 键值” 命令操作&#xff0c;另一种方式是电视端静默安装遥控器插件 XXX 上午 9:20:26 还有第三种方法&#xff0c;通过httpserver 呵呵…