[Beginner's Step-by-Step Tutorial] Outputting Precision/Recall/F1-Score Metrics from MMDetection 3 Training


I recently needed to run comparison experiments, and since MMDetection integrates a large number of algorithms, I trained with this framework and kept a record of the process. Questions and corrections are welcome in the comments.

Background: after completing the training run, I kept every epoch's checkpoint under work_dirs/, named epoch_1.pth through epoch_150.pth.

First, the formulas behind the metrics:

Precision, Recall, and the F1 score are standard metrics for evaluating a classifier's performance, and they are especially informative on imbalanced datasets.

### Precision
Precision is the fraction of samples predicted as positive that are actually positive:
\[ \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} \]

where:
- TP (True Positives): samples correctly predicted as positive.
- FP (False Positives): samples incorrectly predicted as positive.

### Recall
Recall is the fraction of actual positive samples that are correctly predicted as positive:
\[ \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} \]

where:
- FN (False Negatives): samples that are actually positive but were incorrectly predicted as negative.

### F1 Score
The F1 score is the harmonic mean of precision and recall, rewarding models that keep both high:
\[ F1 = 2 \cdot \frac{\text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}} \]

Because it is a harmonic mean, F1 is high only when precision and recall are both high: for a fixed sum of the two, F1 is maximized when they are equal, and if one of them is very small the F1 score stays low even when the other is large.
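As a quick worked example (illustrative numbers): with 6 true positives, 2 false positives, and 3 false negatives, the three formulas give:

```python
# Worked example of the three formulas above (illustrative numbers).
tp, fp, fn = 6, 2, 3
precision = tp / (tp + fp)                           # 6 / 8 = 0.750
recall = tp / (tp + fn)                              # 6 / 9 ≈ 0.667
f1 = 2 * precision * recall / (precision + recall)   # ≈ 0.706
print(f'P={precision:.3f}  R={recall:.3f}  F1={f1:.3f}')
```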

Part 1: Outputting the metrics for a single epoch's checkpoint

Run the following command:

```bash
python tools/test.py configs/ssd/ssdlite_mobilenetv2-scratch_8xb24-600e_coco.py work_dirs/epoch_xx.pth --out=result.pkl
```

Replace xx with the actual epoch number. This evaluates the chosen checkpoint: the COCO metrics are printed, and a result.pkl file is dumped. That file is what we use to derive precision/recall/F1.
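If you want to sanity-check the dump before computing metrics, the .pkl can be loaded back with mmengine. A minimal sketch (the exact keys depend on your MMDetection 3.x version, but the confusion-matrix script below relies on 'pred_instances' being present):

```python
# Sketch: inspect the dumped predictions before computing metrics.
from mmengine.fileio import load

results = load('result.pkl')   # a list with one entry per test image
print(len(results))
print(results[0].keys())       # typically includes 'pred_instances', ...
```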

Open tools/analysis_tools/confusion_matrix.py

and insert the following code into main(), after the confusion matrix has been computed (the snippet assumes local variables named confusion_matrix and save_dir; in the stock script the save directory is args.save_dir, so adapt the name accordingly):

```python
# Derive per-class metrics from the confusion matrix
# (the last row/column of the matrix corresponds to background).
TP = np.diag(confusion_matrix)
FP = np.sum(confusion_matrix, axis=0) - TP
FN = np.sum(confusion_matrix, axis=1) - TP
precision = TP / (TP + FP)
recall = TP / (TP + FN)
average_precision = np.mean(precision)
average_recall = np.mean(recall)
f1 = 2 * (precision * recall) / (precision + recall)
print('AP:', average_precision)
print('AR:', average_recall)
print('F1:', f1)
print('Precision:', precision)
print('Recall:', recall)
# Append the averaged metrics to PRF1.txt in the save directory.
output_file_path = os.path.join(save_dir, 'PRF1.txt')
with open(output_file_path, 'a') as output_file:
    output_file.write(
        f'{average_precision:.5f}   {average_recall:.5f}   {np.mean(f1):.5f}\n')
```

Then run:

```bash
python tools/analysis_tools/confusion_matrix.py configs/ssd/ssdlite_mobilenetv2-scratch_8xb24-600e_coco.py result.pkl results/ --score-thr 0.5
```

This produces the metrics for that single epoch's checkpoint.

Part 2: Outputting the metrics for every epoch's checkpoint

Modify tools/test.py, replacing its entire contents with the following code:

```python
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import warnings
from copy import deepcopy

from mmengine import ConfigDict
from mmengine.config import Config, DictAction
from mmengine.runner import Runner

from mmdet.engine.hooks.utils import trigger_visualization_hook
from mmdet.evaluation import DumpDetResults
from mmdet.registry import RUNNERS
from mmdet.utils import setup_cache_size_limit_of_dynamo


# TODO: support fuse_conv_bn and format_only
def parse_args():
    parser = argparse.ArgumentParser(
        description='MMDet test (and eval) a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--work-dir',
        help='the directory to save the file containing evaluation metrics')
    parser.add_argument(
        '--out',
        type=str,
        help='dump predictions to a pickle file for offline evaluation')
    parser.add_argument(
        '--show', action='store_true', help='show prediction results')
    parser.add_argument(
        '--show-dir',
        help='directory where painted images will be saved. '
        'If specified, it will be automatically saved '
        'to the work_dir/timestamp/show_dir')
    parser.add_argument(
        '--wait-time', type=float, default=2, help='the interval of show (s)')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--tta', action='store_true')
    # When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
    # will pass the `--local-rank` parameter to `tools/train.py` instead
    # of `--local_rank`.
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args


def get_checkpoint_files(directory):
    """Collect the paths of all .pth files under the given directory."""
    checkpoint_files = []
    for root, dirs, files in os.walk(directory):
        for file in files:
            if file.endswith('.pth'):
                checkpoint_files.append(osp.join(root, file))
    return checkpoint_files


def main():
    # Default parameters
    config_path = r'D:\mmdetection-main\configs\ssd\ssdlite_mobilenetv2-scratch_8xb24-600e_coco.py'  # replace with your config file path
    checkpoint_directory = r'D:\mmdetection-main\work_dirs'  # replace with the directory holding your .pth files
    work_dir = r'D:\mmdetection-main\work_dirs'  # replace with the directory where the .pkl files should be written

    # Collect the checkpoint paths
    checkpoint_files = get_checkpoint_files(checkpoint_directory)

    for checkpoint_file in checkpoint_files:
        # Build the arguments for this checkpoint
        args = argparse.Namespace(
            config=config_path,
            checkpoint=checkpoint_file,
            work_dir=work_dir,
            out=None,
            show=False,
            show_dir=None,
            wait_time=2,
            cfg_options=None,
            launcher='none',
            tta=False,
            local_rank=0)

        # Name the output file after the checkpoint, with a .pkl extension
        out_file = osp.splitext(checkpoint_file)[0] + '.pkl'
        args.out = out_file

        # Reduce the number of repeated compilations and improve
        # testing speed.
        setup_cache_size_limit_of_dynamo()

        # load config
        cfg = Config.fromfile(args.config)
        cfg.launcher = args.launcher
        if args.cfg_options is not None:
            cfg.merge_from_dict(args.cfg_options)

        # work_dir is determined in this priority: CLI > segment in file > filename
        if args.work_dir is not None:
            # update configs according to CLI args if args.work_dir is not None
            cfg.work_dir = args.work_dir
        elif cfg.get('work_dir', None) is None:
            # use config filename as default work_dir if cfg.work_dir is None
            cfg.work_dir = osp.join(
                './work_dirs',
                osp.splitext(osp.basename(args.config))[0])

        # Point the config at the current checkpoint
        cfg.load_from = args.checkpoint

        if args.show or args.show_dir:
            cfg = trigger_visualization_hook(cfg, args)

        if args.tta:
            if 'tta_model' not in cfg:
                warnings.warn('Cannot find ``tta_model`` in config, '
                              'we will set it as default.')
                cfg.tta_model = dict(
                    type='DetTTAModel',
                    tta_cfg=dict(
                        nms=dict(type='nms', iou_threshold=0.5),
                        max_per_img=100))
            if 'tta_pipeline' not in cfg:
                warnings.warn('Cannot find ``tta_pipeline`` in config, '
                              'we will set it as default.')
                test_data_cfg = cfg.test_dataloader.dataset
                while 'dataset' in test_data_cfg:
                    test_data_cfg = test_data_cfg['dataset']
                cfg.tta_pipeline = deepcopy(test_data_cfg.pipeline)
                flip_tta = dict(
                    type='TestTimeAug',
                    transforms=[
                        [
                            dict(type='RandomFlip', prob=1.),
                            dict(type='RandomFlip', prob=0.)
                        ],
                        [
                            dict(
                                type='PackDetInputs',
                                meta_keys=('img_id', 'img_path', 'ori_shape',
                                           'img_shape', 'scale_factor', 'flip',
                                           'flip_direction'))
                        ],
                    ])
                cfg.tta_pipeline[-1] = flip_tta
            cfg.model = ConfigDict(**cfg.tta_model, module=cfg.model)
            cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline

        # build the runner from config
        if 'runner_type' not in cfg:
            # build the default runner
            runner = Runner.from_cfg(cfg)
        else:
            # build customized runner from the registry
            # if 'runner_type' is set in the cfg
            runner = RUNNERS.build(cfg)

        # add `DumpResults` dummy metric
        if args.out is not None:
            assert args.out.endswith(('.pkl', '.pickle')), \
                'The dump file must be a pkl file.'
            runner.test_evaluator.metrics.append(
                DumpDetResults(out_file_path=args.out))

        # start testing
        runner.test()


if __name__ == '__main__':
    main()
```

The file can then be run directly.

Note that I hard-coded a few paths in main(); adjust them before running, as shown below.

```python
    # Default parameters
    config_path = r'D:\mmdetection-main\configs\ssd\ssdlite_mobilenetv2-scratch_8xb24-600e_coco.py'  # replace with your config file path
    checkpoint_directory = r'D:\mmdetection-main\work_dirs'  # replace with the directory holding your .pth files
    work_dir = r'D:\mmdetection-main\work_dirs'  # replace with the directory where the .pkl files should be written
```
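Once the paths are set, the script takes no command-line arguments; it tests every .pth it finds under checkpoint_directory and writes a matching .pkl next to each checkpoint:

```bash
python tools/test.py
```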

Next, modify tools/analysis_tools/confusion_matrix.py, replacing its entire contents with the following code:

```python
import argparse
import glob
import os

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MultipleLocator
from mmcv.ops import nms
from mmengine import Config, DictAction
from mmengine.fileio import load
from mmengine.registry import init_default_scope
from mmengine.utils import ProgressBar

from mmdet.evaluation import bbox_overlaps
from mmdet.registry import DATASETS
from mmdet.utils import replace_cfg_vals, update_data_root


def parse_args():
    # No longer needed: the parameters are hard-coded in main() below.
    pass


def calculate_confusion_matrix(dataset,
                               results,
                               score_thr=0,
                               nms_iou_thr=None,
                               tp_iou_thr=0.5):
    num_classes = len(dataset.metainfo['classes'])
    confusion_matrix = np.zeros(shape=[num_classes + 1, num_classes + 1])
    assert len(dataset) == len(results)
    prog_bar = ProgressBar(len(results))
    for idx, per_img_res in enumerate(results):
        res_bboxes = per_img_res['pred_instances']
        gts = dataset.get_data_info(idx)['instances']
        analyze_per_img_dets(confusion_matrix, gts, res_bboxes, score_thr,
                             tp_iou_thr, nms_iou_thr)
        prog_bar.update()
    return confusion_matrix


def analyze_per_img_dets(confusion_matrix,
                         gts,
                         result,
                         score_thr=0,
                         tp_iou_thr=0.5,
                         nms_iou_thr=None):
    true_positives = np.zeros(len(gts))
    gt_bboxes = []
    gt_labels = []
    for gt in gts:
        gt_bboxes.append(gt['bbox'])
        gt_labels.append(gt['bbox_label'])
    gt_bboxes = np.array(gt_bboxes)
    gt_labels = np.array(gt_labels)

    unique_label = np.unique(result['labels'].numpy())
    for det_label in unique_label:
        mask = (result['labels'] == det_label)
        det_bboxes = result['bboxes'][mask].numpy()
        det_scores = result['scores'][mask].numpy()
        if nms_iou_thr:
            det_bboxes, _ = nms(
                det_bboxes, det_scores, nms_iou_thr, score_threshold=score_thr)
        ious = bbox_overlaps(det_bboxes[:, :4], gt_bboxes)
        for i, score in enumerate(det_scores):
            det_match = 0
            if score >= score_thr:
                for j, gt_label in enumerate(gt_labels):
                    if ious[i, j] >= tp_iou_thr:
                        det_match += 1
                        if gt_label == det_label:
                            true_positives[j] += 1  # TP
                        confusion_matrix[gt_label, det_label] += 1
                if det_match == 0:  # BG FP
                    confusion_matrix[-1, det_label] += 1
    for num_tp, gt_label in zip(true_positives, gt_labels):
        if num_tp == 0:  # FN
            confusion_matrix[gt_label, -1] += 1


def plot_confusion_matrix(confusion_matrix,
                          labels,
                          save_dir=None,
                          show=True,
                          title='Normalized Confusion Matrix',
                          color_theme='plasma'):
    # normalize the confusion matrix
    per_label_sums = confusion_matrix.sum(axis=1)[:, np.newaxis]
    confusion_matrix = \
        confusion_matrix.astype(np.float32) / per_label_sums * 100

    num_classes = len(labels)
    fig, ax = plt.subplots(
        figsize=(0.5 * num_classes, 0.5 * num_classes * 0.8), dpi=180)
    cmap = plt.get_cmap(color_theme)
    im = ax.imshow(confusion_matrix, cmap=cmap)
    plt.colorbar(mappable=im, ax=ax)

    title_font = {'weight': 'bold', 'size': 12}
    ax.set_title(title, fontdict=title_font)
    label_font = {'size': 10}
    plt.ylabel('Ground Truth Label', fontdict=label_font)
    plt.xlabel('Prediction Label', fontdict=label_font)

    # draw locator
    xmajor_locator = MultipleLocator(1)
    xminor_locator = MultipleLocator(0.5)
    ax.xaxis.set_major_locator(xmajor_locator)
    ax.xaxis.set_minor_locator(xminor_locator)
    ymajor_locator = MultipleLocator(1)
    yminor_locator = MultipleLocator(0.5)
    ax.yaxis.set_major_locator(ymajor_locator)
    ax.yaxis.set_minor_locator(yminor_locator)

    # draw grid
    ax.grid(True, which='minor', linestyle='-')

    # draw label
    ax.set_xticks(np.arange(num_classes))
    ax.set_yticks(np.arange(num_classes))
    ax.set_xticklabels(labels)
    ax.set_yticklabels(labels)
    ax.tick_params(
        axis='x', bottom=False, top=True, labelbottom=False, labeltop=True)
    plt.setp(
        ax.get_xticklabels(), rotation=45, ha='left', rotation_mode='anchor')

    # draw confusion matrix value
    for i in range(num_classes):
        for j in range(num_classes):
            ax.text(
                j,
                i,
                '{}%'.format(
                    int(confusion_matrix[i, j])
                    if not np.isnan(confusion_matrix[i, j]) else -1),
                ha='center',
                va='center',
                color='w',
                size=7)

    ax.set_ylim(len(confusion_matrix) - 0.5, -0.5)  # matplotlib>3.1.1
    fig.tight_layout()
    if save_dir is not None:
        plt.savefig(
            os.path.join(save_dir, 'confusion_matrix.png'), format='png')
    if show:
        plt.show()


def main(config=None,
         prediction_path=None,
         save_dir=None,
         show=True,
         color_theme='plasma',
         score_thr=0.3,
         tp_iou_thr=0.5,
         nms_iou_thr=None,
         cfg_options=None):
    if config is None or prediction_path is None or save_dir is None:
        raise ValueError(
            'config, prediction_path, and save_dir must be provided.')
    cfg = Config.fromfile(config)
    # replace the ${key} with the value of cfg.key
    cfg = replace_cfg_vals(cfg)
    # update data root according to MMDET_DATASETS
    update_data_root(cfg)
    if cfg_options is not None:
        cfg.merge_from_dict(cfg_options)
    init_default_scope(cfg.get('default_scope', 'mmdet'))

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    dataset = DATASETS.build(cfg.test_dataloader.dataset)
    results = load(prediction_path)
    confusion_matrix = calculate_confusion_matrix(dataset, results, score_thr,
                                                  nms_iou_thr, tp_iou_thr)
    TP = np.diag(confusion_matrix)
    FP = np.sum(confusion_matrix, axis=0) - TP
    FN = np.sum(confusion_matrix, axis=1) - TP
    precision = TP / (TP + FP)
    recall = TP / (TP + FN)
    average_precision = np.mean(precision)
    average_recall = np.mean(recall)
    # index 0 picks the first class (suitable for single-class datasets)
    f1 = 2 * (precision[0] * recall[0]) / (precision[0] + recall[0])
    print('AP:', average_precision)
    print('AR:', average_recall)
    print('F1:', f1)
    print('Precision:', precision[0])
    print('Recall:', recall[0])
    # print('TP:', TP)
    # print('FP:', FP)
    # print('FN:', FN)
    output_file_path = os.path.join(save_dir, 'PRF1.txt')
    with open(output_file_path, 'a') as output_file:
        output_file.write(
            f'{prediction_path}    {precision[0]:.5f}   {recall[0]:.5f}   {f1:.5f}\n')


if __name__ == '__main__':
    config = r'D:\mmdetection-main\configs\ssd\ssdlite_mobilenetv2-scratch_8xb24-600e_coco.py'
    save_dir = r'D:\mmdetection-main\results'

    def numerical_sort(value):
        # Extract the number from file names like 'epoch_12.pkl'
        filename = os.path.basename(value)
        parts = filename.split('epoch_')
        if len(parts) > 1:
            number_part = parts[1].split('.')[0]
            try:
                return int(number_part)
            except ValueError:
                return float('inf')
        return float('inf')

    # Collect all epoch_*.pkl prediction files, sorted by epoch number
    prediction_files = sorted(
        glob.glob(r'D:\lkx\mmdetection-jiexialaidouyongzhege\mmdetection-main\work_dirs\epoch_*.pkl'),
        key=numerical_sort)
    print(config)
    print(prediction_files)
    print(save_dir)
    for prediction_path in prediction_files:
        main(config=config, prediction_path=prediction_path, save_dir=save_dir)
```

Here, change the following to match your setup:

```python
    config = r'D:\mmdetection-main\configs\ssd\ssdlite_mobilenetv2-scratch_8xb24-600e_coco.py'  # change to your config file path
    save_dir = r'D:\mmdetection-main\results'  # change to the directory where the metrics should be saved
```

along with the glob pattern that points at your epoch_*.pkl files. The default thresholds in the main() signature can also be adjusted manually as needed:

```python
def main(config=None, prediction_path=None, save_dir=None, show=True, color_theme='plasma', score_thr=0.3, tp_iou_thr=0.5, nms_iou_thr=None, cfg_options=None):
```

This file is likewise run directly, with no arguments.
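That is:

```bash
python tools/analysis_tools/confusion_matrix.py
```

It loops over all the prediction files and appends one line per epoch to save_dir/PRF1.txt.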

Part 3: Converting the output .txt metrics file to an Excel spreadsheet
```python
import pandas as pd

# Read the text file
with open('PRF1.txt', 'r') as file:
    lines = file.readlines()

# Parse each non-empty line into a row
data = []
for line in lines:
    line = line.strip()
    if line:
        row = line.split()
        data.append(row)

# Build a DataFrame from the rows
df = pd.DataFrame(data, columns=['epoch', 'Precision', 'Recall', 'F1'])

# Save as an Excel file
df.to_excel('PRF1.xlsx', index=False)
```
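Note that the 'epoch' column actually holds the full prediction-file path as written by confusion_matrix.py. If you prefer a bare epoch number, a small sketch (assuming prediction files were named like epoch_12.pkl):

```python
# Sketch: reduce the path column to a numeric epoch index,
# assuming prediction files were named like 'epoch_12.pkl'.
df['epoch'] = df['epoch'].str.extract(r'epoch_(\d+)', expand=False).astype(int)
df = df.sort_values('epoch')
df.to_excel('PRF1.xlsx', index=False)
```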

More articles are on the way, aiming for concision and accuracy. Feel free to follow me and join the discussion!

