目录
摘要
Abstract
Mask R-CNN
网络架构
Backbone
RPN
Proposal Layer
ROIAlign
bbox检测
Mask分割
损失计算
实验复现
总结
摘要
Mask R-CNN是在Faster R-CNN的基础上进行改进的目标检测和实例分割网络。Faster R-CNN主要用于目标检测,输出对象的边界框和类别标签,而Mask R-CNN在Faster R-CNN的基础上增加了像素级分割的能力,能够输出对象的像素级掩码。Mask R-CNN使用了ROI Align层,解决了Faster R-CNN在边界像素对齐方面的问题,从而提高了检测和分割的精度。ROI Align通过双线性插值来避免量化操作,更精确地从特征图中提取对应RoI的富含空间信息的特征,保持空间位置信息,解决了Faster R-CNN中使用的RoI Pooling方法的定位不准确问题 。Mask R-CNN在Faster R-CNN的架构基础上增加了一个并行的掩膜预测分支,在每个RoI上,使用FCN来预测对象的掩膜,使得网络能够更细致地学习物体的空间特征。Mask R-CNN在PASCAL VOC和MS COCO等多个重要的数据集上达到了当时的最佳分割和检测精度。
Abstract
Mask R-CNN is an improved object detection and instance segmentation network based on Faster R-CNN. Faster R-CNN is primarily used for object detection, outputting the bounding boxes and class labels of objects, while Mask R-CNN adds the capability of pixel-level segmentation on the basis of Faster R-CNN, enabling the output of pixel-level masks for objects. Mask R-CNN employs the ROI Align layer, which addresses the issue of boundary pixel alignment in Faster R-CNN, thereby enhancing the precision of detection and segmentation. ROI Align uses bilinear interpolation to avoid quantization operations, extracting features from the feature map that correspond to the RoI with rich spatial information more accurately, preserving spatial location information, and resolving the inaccurate localization issue of the RoI Pooling method used in Faster R-CNN. Mask R-CNN adds a parallel mask prediction branch to the architecture of Faster R-CNN, using an FCN to predict the masks of objects on each RoI, allowing the network to learn the spatial features of objects in more detail. Mask R-CNN has achieved the best segmentation and detection accuracy at the time on several important datasets, including PASCAL VOC and MS COCO.
Mask R-CNN
论文地址:[1703.06870v3] Mask R-CNN
项目地址:Mask R-CNN
Mask R-CNN是一种在有效检测目标的同时输出高质量的实例分割的网络模型,是对Faster R-CNN的扩展,在bbox检测的同时并行地增加一个预测分割掩码的分支。Mask R-CNN就是将物体检测和语义分割结合起来,从而达到了实例分割的效果,该模型效果图如下所示:
在我们学习Mask R-CNN之前,我们需要先对Faster R-CNN有一定的了解,大家可以通过我之前的博客了解。
网络架构
Mask R-CNN网络模型,如下图所示:
Backbone
该模型采用了ResNet101+FPN作为骨干网络进行图像特征提取,选用ResNet提取特征我们已再熟悉不过了,为了增强图像的语义特征,更好地预测不同大小的物体,额外引入了FPN模块。FPN示意图如下图(d)所示:
图(d)中金字塔底部为浅层特征图,金字塔顶部为深层特征图。浅层特征图感受野小,适合检测小目标;深层的特征图感受野大,适合检测大目标。FPN通过融合不同尺度的特征图,使得模型能够同时处理不同大小的目标。
FPN网络结构如下所示:
该网络主要由自底向上的特征提取路径和自顶向下的特征融合路径组成。自底向上的路径是ResNet的正向传播过程,用于提取不同层次的特征图。自顶向下的路径通过上采样和横向连接的方式,将高层特征图的语义信息与低层特征图的空间信息进行融合。
RPN
主要是在骨干网络提取的特征图像中选取候选区域,详细可看Faster R-CNN中的介绍。
Proposal Layer
将RPN选取的候选框作为输入,利用rpn_bbox对选取的anchors进行修正,得到修正后的RoI。然后,舍弃掉修正后边框超过图片大小的anchor,再根据RPN网络,获取score靠前的前6000个RoI。最后,利用非极大抑制的方法获得最终需要进行预测和分割的区域。
ROIAlign
ROIAlign的提出是为了解决Faster R-CNN中RoI Pooling的区域不匹配的问题。
- RoI Pooling
RoI Pooling是Faster R-CNN中必不可少的一步,因为其会产生长度固定的特征向量,有了长度固定的特征向量才能进行Softmax计算分类损失。该方法区域不匹配问题是由于RoI Pooling过程中的取整操作造成的。
例如:输入一张 800×800 的图片,经过一个有5次降采样的卷机网络,得到大小为 25×25 的特征图像。
- 第一次区域不匹配
输入图像的RoI区域大小为 600×500 ,经过网络之后对应的区域为 18.75 × 15.625 ,ROI Pooling采用向下取整的方式,得到RoI区域的特征图像为 18 × 15 。
- 第二次区域不匹配
然后,RoI Pooling将上一步中的特征图像分块,假如需要一个 7 × 7 块,每个块大小为 ,同样进行向下取整,导致每块大小为 2×2 ,即整个RoI区域的特征图像的尺寸为缩小为 14×14 。
上述两次不匹配导致特征图像在横向和纵向上分别产生了4.75和1.625的误差,对于Faster R-CNN进行目标检测而言,几个像素的偏差在视觉上可能微乎其微。但是,对于Mask R-CNN增加了实例分割而言就会严重影响精确度。
- ROIAlign
RoIAlign没有取整操作,可全程使用浮点数。
(1)计算RoI区域的边长,边长不取整;
(2)将RoI区域均匀分成 k × k 个块,每个块的大小不取整;
(3)每个块的值为其最邻近的特征图像的四个值通过双线性插值得到;
假设白框中的交点为特征图像上的点,蓝框为RoI特征图像。将蓝框分为了 7x7 的块,若要计算每个块的值,则需要借助以下公式:
其中u、v分别为某块中心粉点与、的横向距离u,以及与、的纵向距离v。
(4)使用Max Pooling或者Average Pooling得到长度固定的特征向量。
使用RoIAlign的对于准确度的提升还是很明显的,如下图所示:
bbox检测
将RoIAlign输出的 7x7x256 的特征图像拉伸至 1x1x1024 的特征向量,然后分别进行分类和框预测即可。与Faster R-CNN类似,如下图灰色区域所示:
Mask分割
如上图下半部分所示,Mask分支使用传统的FCN图像分割方法,最后生成 28×28×80 的预测掩码结果。
最后得到的结果是软掩码,经过Sigmoid后的(0,1)浮点数。
损失计算
Mask R-CNN在Faster R-CNN的基础上添加了一个用于语义分割的Mask损失函数。
在进行掩码预测时,FCN的分割和预测是同时进行的,即需要预测每个像素属于哪一类。而Mask R-CNN将分类和语义分割任务进行了解耦,即每个类单独的预测一个位置掩码,这种解耦提升了语义分割的效果,如下图所示:
实验复现
本次实验特征提取网络采用预训练的ResNet50,Mask R-CNN以Batch Size=8、学习率为0.08,在COCO2017数据集上训练一轮。
由于资源有限只训练了一轮,由于COCO数据集比较大,最后得到的检测和分割效果还能接受。
数据处理代码如下:
import os
import jsonimport torch
from PIL import Image
import torch.utils.data as data
from pycocotools.coco import COCO
from train_utils import coco_remove_images_without_annotations, convert_coco_poly_maskclass CocoDetection(data.Dataset):"""`MS Coco Detection <https://cocodataset.org/>`_ Dataset.Args:root (string): Root directory where images are downloaded to.dataset (string): train or val.transforms (callable, optional): A function/transform that takes input sample and its target as entryand returns a transformed version."""def __init__(self, root, dataset="train", transforms=None, years="2017"):super(CocoDetection, self).__init__()assert dataset in ["train", "val"], 'dataset must be in ["train", "val"]'anno_file = f"instances_{dataset}{years}.json"assert os.path.exists(root), "file '{}' does not exist.".format(root)self.img_root = os.path.join(root, f"{dataset}{years}")assert os.path.exists(self.img_root), "path '{}' does not exist.".format(self.img_root)self.anno_path = os.path.join(root, "annotations", anno_file)assert os.path.exists(self.anno_path), "file '{}' does not exist.".format(self.anno_path)self.mode = datasetself.transforms = transformsself.coco = COCO(self.anno_path)# 获取coco数据索引与类别名称的关系# 注意在object80中的索引并不是连续的,虽然只有80个类别,但索引还是按照stuff91来排序的data_classes = dict([(v["id"], v["name"]) for k, v in self.coco.cats.items()])max_index = max(data_classes.keys()) # 90# 将缺失的类别名称设置成N/Acoco_classes = {}for k in range(1, max_index + 1):if k in data_classes:coco_classes[k] = data_classes[k]else:coco_classes[k] = "N/A"if dataset == "train":json_str = json.dumps(coco_classes, indent=4)with open("coco91_indices.json", "w") as f:f.write(json_str)self.coco_classes = coco_classesids = list(sorted(self.coco.imgs.keys()))if dataset == "train":# 移除没有目标,或者目标面积非常小的数据valid_ids = coco_remove_images_without_annotations(self.coco, ids)self.ids = valid_idselse:self.ids = idsdef parse_targets(self,img_id: int,coco_targets: list,w: int = None,h: int = None):assert w > 0assert h > 0# 只筛选出单个对象的情况anno = [obj for obj in coco_targets if obj['iscrowd'] == 0]boxes = [obj["bbox"] for obj in anno]# guard against no boxes via resizingboxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)# [xmin, ymin, w, h] -> [xmin, ymin, xmax, ymax]boxes[:, 2:] += boxes[:, :2]boxes[:, 0::2].clamp_(min=0, max=w)boxes[:, 1::2].clamp_(min=0, max=h)classes = [obj["category_id"] for obj in anno]classes = torch.tensor(classes, dtype=torch.int64)area = torch.tensor([obj["area"] for obj in anno])iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])segmentations = [obj["segmentation"] for obj in anno]masks = convert_coco_poly_mask(segmentations, h, w)# 筛选出合法的目标,即x_max>x_min且y_max>y_minkeep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])boxes = boxes[keep]classes = classes[keep]masks = masks[keep]area = area[keep]iscrowd = iscrowd[keep]target = {}target["boxes"] = boxestarget["labels"] = classestarget["masks"] = maskstarget["image_id"] = torch.tensor([img_id])# for conversion to coco apitarget["area"] = areatarget["iscrowd"] = iscrowdreturn targetdef __getitem__(self, index):"""Args:index (int): IndexReturns:tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``."""coco = self.cocoimg_id = self.ids[index]ann_ids = coco.getAnnIds(imgIds=img_id)coco_target = coco.loadAnns(ann_ids)path = coco.loadImgs(img_id)[0]['file_name']img = Image.open(os.path.join(self.img_root, path)).convert('RGB')w, h = img.sizetarget = self.parse_targets(img_id, coco_target, w, h)if self.transforms is not None:img, target = self.transforms(img, target)return img, targetdef __len__(self):return len(self.ids)def get_height_and_width(self, index):coco = self.cocoimg_id = self.ids[index]img_info = coco.loadImgs(img_id)[0]w = img_info["width"]h = img_info["height"]return h, w@staticmethoddef collate_fn(batch):return tuple(zip(*batch))if __name__ == '__main__':train = CocoDetection("/root/autodl-tmp/COCO2017", dataset="train")print(len(train))t = train[0]
模型训练代码如下:
import os
import datetimeimport torch
from torchvision.ops.misc import FrozenBatchNorm2dimport transforms
from network_files import MaskRCNN
from backbone import resnet50_fpn_backbone
from my_dataset_coco import CocoDetection
from my_dataset_voc import VOCInstances
from train_utils import train_eval_utils as utils
from train_utils import GroupedBatchSampler, create_aspect_ratio_groupsdef create_model(num_classes, load_pretrain_weights=True):# 如果GPU显存很小,batch_size不能设置很大,建议将norm_layer设置成FrozenBatchNorm2d(默认是nn.BatchNorm2d)# FrozenBatchNorm2d的功能与BatchNorm2d类似,但参数无法更新# trainable_layers包括['layer4', 'layer3', 'layer2', 'layer1', 'conv1'], 5代表全部训练# backbone = resnet50_fpn_backbone(norm_layer=FrozenBatchNorm2d,# trainable_layers=3)# resnet50 imagenet weights url: https://download.pytorch.org/models/resnet50-0676ba61.pthbackbone = resnet50_fpn_backbone(pretrain_path="./weight/resnet50.pth", trainable_layers=3)model = MaskRCNN(backbone, num_classes=num_classes)if load_pretrain_weights:# coco weights url: "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth"weights_dict = torch.load("./weight/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth", map_location="cpu")for k in list(weights_dict.keys()):if ("box_predictor" in k) or ("mask_fcn_logits" in k):del weights_dict[k]print(model.load_state_dict(weights_dict, strict=False))return modeldef main(args):device = torch.device(args.device if torch.cuda.is_available() else "cpu")print("Using {} device training.".format(device.type))# 用来保存coco_info的文件now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")det_results_file = f"det_results{now}.txt"seg_results_file = f"seg_results{now}.txt"data_transform = {"train": transforms.Compose([transforms.ToTensor(),transforms.RandomHorizontalFlip(0.5)]),"val": transforms.Compose([transforms.ToTensor()])}data_root = args.data_path# load train data set# coco2017 -> annotations -> instances_train2017.jsontrain_dataset = CocoDetection(data_root, "train", data_transform["train"])# VOCdevkit -> VOC2012 -> ImageSets -> Main -> train.txt# train_dataset = VOCInstances(data_root, year="2012", txt_name="train.txt", transforms=data_transform["train"])train_sampler = None# 是否按图片相似高宽比采样图片组成batch# 使用的话能够减小训练时所需GPU显存,默认使用if args.aspect_ratio_group_factor >= 0:train_sampler = torch.utils.data.RandomSampler(train_dataset)# 统计所有图像高宽比例在bins区间中的位置索引group_ids = create_aspect_ratio_groups(train_dataset, k=args.aspect_ratio_group_factor)# 每个batch图片从同一高宽比例区间中取train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)# 注意这里的collate_fn是自定义的,因为读取的数据包括image和targets,不能直接使用默认的方法合成batchbatch_size = args.batch_sizenw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workersprint('Using %g dataloader workers' % nw)if train_sampler:# 如果按照图片高宽比采样图片,dataloader中需要使用batch_samplertrain_data_loader = torch.utils.data.DataLoader(train_dataset,batch_sampler=train_batch_sampler,pin_memory=True,num_workers=nw,collate_fn=train_dataset.collate_fn)else:train_data_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,pin_memory=True,num_workers=nw,collate_fn=train_dataset.collate_fn)# load validation data set# coco2017 -> annotations -> instances_val2017.jsonval_dataset = CocoDetection(data_root, "val", data_transform["val"])# VOCdevkit -> VOC2012 -> ImageSets -> Main -> val.txt# val_dataset = VOCInstances(data_root, year="2012", txt_name="val.txt", transforms=data_transform["val"])val_data_loader = torch.utils.data.DataLoader(val_dataset,batch_size=1,shuffle=False,pin_memory=True,num_workers=nw,collate_fn=train_dataset.collate_fn)# create model num_classes equal background + classesmodel = create_model(num_classes=args.num_classes + 1, load_pretrain_weights=args.pretrain)model.to(device)train_loss = []learning_rate = []val_map = []# define optimizerparams = [p for p in model.parameters() if p.requires_grad]optimizer = torch.optim.SGD(params, lr=args.lr,momentum=args.momentum,weight_decay=args.weight_decay)scaler = torch.cuda.amp.GradScaler() if args.amp else None# learning rate schedulerlr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=args.lr_steps,gamma=args.lr_gamma)# 如果传入resume参数,即上次训练的权重地址,则接着上次的参数训练if args.resume:# If map_location is missing, torch.load will first load the module to CPU# and then copy each parameter to where it was saved,# which would result in all processes on the same machine using the same set of devices.checkpoint = torch.load(args.resume, map_location='cpu') # 读取之前保存的权重文件(包括优化器以及学习率策略)model.load_state_dict(checkpoint['model'])optimizer.load_state_dict(checkpoint['optimizer'])lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])args.start_epoch = checkpoint['epoch'] + 1if args.amp and "scaler" in checkpoint:scaler.load_state_dict(checkpoint["scaler"])for epoch in range(args.start_epoch, args.epochs):# train for one epoch, printing every 50 iterationsmean_loss, lr = utils.train_one_epoch(model, optimizer, train_data_loader,device, epoch, print_freq=50,warmup=True, scaler=scaler)train_loss.append(mean_loss.item())learning_rate.append(lr)# update the learning ratelr_scheduler.step()# evaluate on the test datasetdet_info, seg_info = utils.evaluate(model, val_data_loader, device=device)# write detection into txtwith open(det_results_file, "a") as f:# 写入的数据包括coco指标还有loss和learning rateresult_info = [f"{i:.4f}" for i in det_info + [mean_loss.item()]] + [f"{lr:.6f}"]txt = "epoch:{} {}".format(epoch, ' '.join(result_info))f.write(txt + "\n")# write seg into txtwith open(seg_results_file, "a") as f:# 写入的数据包括coco指标还有loss和learning rateresult_info = [f"{i:.4f}" for i in seg_info + [mean_loss.item()]] + [f"{lr:.6f}"]txt = "epoch:{} {}".format(epoch, ' '.join(result_info))f.write(txt + "\n")val_map.append(det_info[1]) # pascal mAP# save weightssave_files = {'model': model.state_dict(),'optimizer': optimizer.state_dict(),'lr_scheduler': lr_scheduler.state_dict(),'epoch': epoch}if args.amp:save_files["scaler"] = scaler.state_dict()torch.save(save_files, "./save_weights/model_{}.pth".format(epoch))# plot loss and lr curveif len(train_loss) != 0 and len(learning_rate) != 0:from plot_curve import plot_loss_and_lrplot_loss_and_lr(train_loss, learning_rate)# plot mAP curveif len(val_map) != 0:from plot_curve import plot_mapplot_map(val_map)if __name__ == "__main__":import argparseparser = argparse.ArgumentParser(description=__doc__)# 训练设备类型parser.add_argument('--device', default='cuda:0', help='device')# 训练数据集的根目录parser.add_argument('--data-path', default='/root/autodl-tmp/COCO2017', help='dataset')# 检测目标类别数(不包含背景)parser.add_argument('--num-classes', default=90, type=int, help='num_classes')# 文件保存地址parser.add_argument('--output-dir', default='./save_weights', help='path where to save')# 若需要接着上次训练,则指定上次训练保存权重文件地址parser.add_argument('--resume', default='', type=str, help='resume from checkpoint')# 指定接着从哪个epoch数开始训练parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')# 训练的总epoch数parser.add_argument('--epochs', default=3, type=int, metavar='N',help='number of total epochs to run')# 学习率parser.add_argument('--lr', default=0.004, type=float,help='initial learning rate, 0.02 is the default value for training ''on 8 gpus and 2 images_per_gpu')# SGD的momentum参数parser.add_argument('--momentum', default=0.9, type=float, metavar='M',help='momentum')# SGD的weight_decay参数parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,metavar='W', help='weight decay (default: 1e-4)',dest='weight_decay')# 针对torch.optim.lr_scheduler.MultiStepLR的参数parser.add_argument('--lr-steps', default=[16, 22], nargs='+', type=int,help='decrease lr every step-size epochs')# 针对torch.optim.lr_scheduler.MultiStepLR的参数parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')# 训练的batch size(如果内存/GPU显存充裕,建议设置更大)parser.add_argument('--batch_size', default=2, type=int, metavar='N',help='batch size when training.')parser.add_argument('--aspect-ratio-group-factor', default=3, type=int)parser.add_argument("--pretrain", type=bool, default=True, help="load COCO pretrain weights.")# 是否使用混合精度训练(需要GPU支持混合精度)parser.add_argument("--amp", default=False, help="Use torch.cuda.amp for mixed precision training")args = parser.parse_args()print(args)# 检查保存权重文件夹是否存在,不存在则创建if not os.path.exists(args.output_dir):os.makedirs(args.output_dir)main(args)
资源有限,所以只训练了一轮,部分类别的准确度欠佳,训练结果评估如下:
性能评估代码如下:
"""
该脚本用于调用训练好的模型权重去计算验证集/测试集的COCO指标
以及每个类别的mAP(IoU=0.5)
"""import os
import jsonimport torch
from tqdm import tqdm
import numpy as npimport transforms
from backbone import resnet50_fpn_backbone
from network_files import MaskRCNN
from my_dataset_coco import CocoDetection
from my_dataset_voc import VOCInstances
from train_utils import EvalCOCOMetricdef summarize(self, catId=None):"""Compute and display summary metrics for evaluation results.Note this functin can *only* be applied on the default parameter setting"""def _summarize(ap=1, iouThr=None, areaRng='all', maxDets=100):p = self.paramsiStr = ' {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}'titleStr = 'Average Precision' if ap == 1 else 'Average Recall'typeStr = '(AP)' if ap == 1 else '(AR)'iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \if iouThr is None else '{:0.2f}'.format(iouThr)aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]if ap == 1:# dimension of precision: [TxRxKxAxM]s = self.eval['precision']# IoUif iouThr is not None:t = np.where(iouThr == p.iouThrs)[0]s = s[t]if isinstance(catId, int):s = s[:, :, catId, aind, mind]else:s = s[:, :, :, aind, mind]else:# dimension of recall: [TxKxAxM]s = self.eval['recall']if iouThr is not None:t = np.where(iouThr == p.iouThrs)[0]s = s[t]if isinstance(catId, int):s = s[:, catId, aind, mind]else:s = s[:, :, aind, mind]if len(s[s > -1]) == 0:mean_s = -1else:mean_s = np.mean(s[s > -1])print_string = iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s)return mean_s, print_stringstats, print_list = [0] * 12, [""] * 12stats[0], print_list[0] = _summarize(1)stats[1], print_list[1] = _summarize(1, iouThr=.5, maxDets=self.params.maxDets[2])stats[2], print_list[2] = _summarize(1, iouThr=.75, maxDets=self.params.maxDets[2])stats[3], print_list[3] = _summarize(1, areaRng='small', maxDets=self.params.maxDets[2])stats[4], print_list[4] = _summarize(1, areaRng='medium', maxDets=self.params.maxDets[2])stats[5], print_list[5] = _summarize(1, areaRng='large', maxDets=self.params.maxDets[2])stats[6], print_list[6] = _summarize(0, maxDets=self.params.maxDets[0])stats[7], print_list[7] = _summarize(0, maxDets=self.params.maxDets[1])stats[8], print_list[8] = _summarize(0, maxDets=self.params.maxDets[2])stats[9], print_list[9] = _summarize(0, areaRng='small', maxDets=self.params.maxDets[2])stats[10], print_list[10] = _summarize(0, areaRng='medium', maxDets=self.params.maxDets[2])stats[11], print_list[11] = _summarize(0, areaRng='large', maxDets=self.params.maxDets[2])print_info = "\n".join(print_list)if not self.eval:raise Exception('Please run accumulate() first')return stats, print_infodef save_info(coco_evaluator,category_index: dict,save_name: str = "record_mAP.txt"):iou_type = coco_evaluator.params.iouTypeprint(f"IoU metric: {iou_type}")# calculate COCO info for all classescoco_stats, print_coco = summarize(coco_evaluator)# calculate voc info for every classes(IoU=0.5)classes = [v for v in category_index.values() if v != "N/A"]voc_map_info_list = []for i in range(len(classes)):stats, _ = summarize(coco_evaluator, catId=i)voc_map_info_list.append(" {:15}: {}".format(classes[i], stats[1]))print_voc = "\n".join(voc_map_info_list)print(print_voc)# 将验证结果保存至txt文件中with open(save_name, "w") as f:record_lines = ["COCO results:",print_coco,"","mAP(IoU=0.5) for each category:",print_voc]f.write("\n".join(record_lines))def main(parser_data):device = torch.device(parser_data.device if torch.cuda.is_available() else "cpu")print("Using {} device training.".format(device.type))data_transform = {"val": transforms.Compose([transforms.ToTensor()])}# read class_indictlabel_json_path = parser_data.label_json_pathassert os.path.exists(label_json_path), "json file {} dose not exist.".format(label_json_path)with open(label_json_path, 'r') as f:category_index = json.load(f)data_root = parser_data.data_path# 注意这里的collate_fn是自定义的,因为读取的数据包括image和targets,不能直接使用默认的方法合成batchbatch_size = parser_data.batch_sizenw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workersprint('Using %g dataloader workers' % nw)# load validation data setval_dataset = CocoDetection(data_root, "val", data_transform["val"])# VOCdevkit -> VOC2012 -> ImageSets -> Main -> val.txt# val_dataset = VOCInstances(data_root, year="2012", txt_name="val.txt", transforms=data_transform["val"])val_dataset_loader = torch.utils.data.DataLoader(val_dataset,batch_size=batch_size,shuffle=False,pin_memory=True,num_workers=nw,collate_fn=val_dataset.collate_fn)# create modelbackbone = resnet50_fpn_backbone()model = MaskRCNN(backbone, num_classes=args.num_classes + 1)# 载入你自己训练好的模型权重weights_path = parser_data.weights_pathassert os.path.exists(weights_path), "not found {} file.".format(weights_path)model.load_state_dict(torch.load(weights_path, map_location='cpu')['model'])# print(model)model.to(device)# evaluate on the val datasetcpu_device = torch.device("cpu")det_metric = EvalCOCOMetric(val_dataset.coco, "bbox", "det_results.json")seg_metric = EvalCOCOMetric(val_dataset.coco, "segm", "seg_results.json")model.eval()with torch.no_grad():for image, targets in tqdm(val_dataset_loader, desc="validation..."):# 将图片传入指定设备deviceimage = list(img.to(device) for img in image)# inferenceoutputs = model(image)outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]det_metric.update(targets, outputs)seg_metric.update(targets, outputs)det_metric.synchronize_results()seg_metric.synchronize_results()det_metric.evaluate()seg_metric.evaluate()save_info(det_metric.coco_evaluator, category_index, "det_record_mAP.txt")save_info(seg_metric.coco_evaluator, category_index, "seg_record_mAP.txt")if __name__ == "__main__":import argparseparser = argparse.ArgumentParser(description=__doc__)# 使用设备类型parser.add_argument('--device', default='cuda', help='device')# 检测目标类别数(不包含背景)parser.add_argument('--num-classes', type=int, default=90, help='number of classes')# 数据集的根目录parser.add_argument('--data-path', default='/root/autodl-tmp/COCO2017', help='dataset root')# 训练好的权重文件parser.add_argument('--weights-path', default='./save_weights/model_0.pth', type=str, help='training weights')# batch size(set to 1, don't change)parser.add_argument('--batch-size', default=1, type=int, metavar='N',help='batch size when validation.')# 类别索引和类别名称对应关系parser.add_argument('--label-json-path', type=str, default="coco91_indices.json")args = parser.parse_args()main(args)
结果预测
输入图像1:
效果展示1:
输入图像2:
效果展示2:
预测代码如下:
import os
import time
import jsonimport numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
from torchvision import transformsfrom network_files import MaskRCNN
from backbone import resnet50_fpn_backbone
from draw_box_utils import draw_objsdef create_model(num_classes, box_thresh=0.5):backbone = resnet50_fpn_backbone()model = MaskRCNN(backbone,num_classes=num_classes,rpn_score_thresh=box_thresh,box_score_thresh=box_thresh)return modeldef time_synchronized():torch.cuda.synchronize() if torch.cuda.is_available() else Nonereturn time.time()def main():num_classes = 90 # 不包含背景box_thresh = 0.5weights_path = "./save_weights/model_0.pth"img_path = "./street.png"label_json_path = './coco91_indices.json'# get devicesdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("using {} device.".format(device))# create modelmodel = create_model(num_classes=num_classes + 1, box_thresh=box_thresh)# load train weightsassert os.path.exists(weights_path), "{} file dose not exist.".format(weights_path)weights_dict = torch.load(weights_path, map_location='cpu')weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dictmodel.load_state_dict(weights_dict)model.to(device)# read class_indictassert os.path.exists(label_json_path), "json file {} dose not exist.".format(label_json_path)with open(label_json_path, 'r') as json_file:category_index = json.load(json_file)# load imageassert os.path.exists(img_path), f"{img_path} does not exits."original_img = Image.open(img_path).convert('RGB')# from pil image to tensor, do not normalize imagedata_transform = transforms.Compose([transforms.ToTensor()])img = data_transform(original_img)# expand batch dimensionimg = torch.unsqueeze(img, dim=0)model.eval() # 进入验证模式with torch.no_grad():# initimg_height, img_width = img.shape[-2:]init_img = torch.zeros((1, 3, img_height, img_width), device=device)model(init_img)t_start = time_synchronized()predictions = model(img.to(device))[0]t_end = time_synchronized()print("inference+NMS time: {}".format(t_end - t_start))predict_boxes = predictions["boxes"].to("cpu").numpy()predict_classes = predictions["labels"].to("cpu").numpy()predict_scores = predictions["scores"].to("cpu").numpy()predict_mask = predictions["masks"].to("cpu").numpy()predict_mask = np.squeeze(predict_mask, axis=1) # [batch, 1, h, w] -> [batch, h, w]if len(predict_boxes) == 0:print("没有检测到任何目标!")returnplot_img = draw_objs(original_img,boxes=predict_boxes,classes=predict_classes,scores=predict_scores,masks=predict_mask,category_index=category_index,line_thickness=3,font='arial.ttf',font_size=20)plt.imshow(plot_img)plt.show()# 保存预测的图片结果plot_img.save("test_result.jpg")if __name__ == '__main__':main()
总结
Mask R-CNN通过引入RoIAlign层和全卷积网络分枝,不仅提高了分割精度,还实现了像素级的掩码输出,极大地推动了目标检测技术的发展。这一突破性工作不仅在COCO等数据集上取得了最佳性能,而且对后续研究产生了深远影响,激发了包括注意力机制、多模态信息融合和小样本学习在内的多种优化策略的研究。尽管Mask R-CNN在速度和参数量方面仍存在挑战,但其未来的优化方向,如:结合强化学习、模型轻量化等,有望进一步提升模型性能,降低计算成本,使其在实时应用和资源受限的设备上更具实用性。