联合目标检测与图像分类提升数据不平衡场景下的准确率

news/2024/12/21 20:48:46/

联合目标检测与图像分类提升数据不平衡场景下的准确率

在一些数据不平衡的场景下,使用单一的目标检测模型很难达到99%的准确率。为了优化这一问题,适当将其拆解为目标检测模型图像分类模型的组合,可以更有效地控制最终效果,尤其是在添加焦点损失(focal loss)、调整超参数和数据预处理无效的情况下。以下是具体的实现方式及联合两个模型的推理代码。

整体功能概述

这段代码的主要功能包括:

  1. 加载目标检测分类模型:使用两个 Ultralytics YOLO(YOLOv8/YOLOv11均可) 模型进行目标检测分类
  2. 处理图像:遍历指定输入文件夹中的所有图像,进行目标检测分类
  3. 绘制检测框和分类标签:在图像上绘制检测到的对象的边界框,并在框上方添加分类名称和置信度。
  4. 可选保存裁剪的对象图像:根据设置,裁剪检测到的对象区域并保存为单独的图像文件,文件名包含类别名称、置信度和坐标信息(便于调试)。

实现细节

1. 加载模型

代码加载了两个 YOLO 模型:

  • 目标检测模型:一个单一类别的 YOLO 模型,用于检测主体对象。
  • 图像分类模型:一个多类别的 YOLO 模型,用于对检测到的对象进行分类

2. 处理图像

脚本处理输入文件夹中的每一张图像,步骤如下:

  • 目标检测:使用目标检测模型检测图像中的对象。
  • 裁剪检测到的对象:根据检测到的边界框坐标,裁剪出感兴趣的区域。
  • 图像分类:对裁剪出的对象区域进行分类
  • 数据增强或欠采样:根据任务需求,对裁剪出的子图像进行数据增强或欠采样,以平衡数据集。

3. 绘制检测框和标签

对于每一个检测到的对象,脚本会:

  • 在图像上绘制一个边界框。
  • 在边界框上方添加分类名称和置信度标签。

4. 保存裁剪的对象图像

可选地,脚本会保存裁剪出的对象图像,文件名包含以下信息:

  • 分类名称
  • 置信度
  • 边界框坐标

这对于调试和分析特定的检测结果非常有帮助。

推理代码

import os
import cv2
import numpy as np
from pathlib import Path
from ultralytics import YOLO
import randomdef generate_random_color_from_name(name):"""根据类别名生成可重复的颜色。"""random.seed(name)  # 使用类别名作为随机种子return tuple(random.randint(0, 255) for _ in range(3))def generate_class_colors(names):"""为每个类别生成一个固定的颜色。"""class_colors = {}for class_name in names:class_colors[class_name] = generate_random_color_from_name(class_name)return class_colorsdef draw_box_on_image(image, box, color=(0, 255, 0), thickness=2):"""在图像上绘制检测框。"""x1, y1, x2, y2 = map(int, box)cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness)def add_classification_to_box(image, box, class_name, confidence, color=(0, 255, 0)):"""在边界框上方添加分类名称和置信度。"""x1, y1, x2, y2 = map(int, box)label = f"{class_name}: {confidence:.2f}"cv2.putText(image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2, cv2.LINE_AA)def save_cropped_object(image, box, cls_class_name, confidence, output_folder, image_name):"""将裁剪的对象区域保存为图像到子文件夹中,文件名包含类别名、置信度和坐标。"""x1, y1, x2, y2 = map(int, box)cropped_img = image[y1:y2, x1:x2]# 为当前图像创建一个以图像文件名命名的子文件夹image_subfolder = Path(output_folder) / Path(image_name).stemimage_subfolder.mkdir(parents=True, exist_ok=True)# 为裁剪的对象创建文件名(class_name_confidence_x1_y1_x2_y2.jpg)# 确保置信度格式安全,使用两位小数,并用下划线分隔cropped_img_name = f"{cls_class_name}_{confidence:.2f}_{x1}_{y1}_{x2}_{y2}.jpg"cropped_img_path = image_subfolder / cropped_img_namecv2.imwrite(str(cropped_img_path), cropped_img)print(f"已保存裁剪对象: {cropped_img_path}")def process_image_with_detection_and_classification(model_det, model_cls, img_path, names, class_colors, output_folder, save_cropped=False, detection_size=1280, classification_size=640):"""处理单张图像:执行对象检测,分类每个对象,并返回处理后的图像。:param model_det: 检测模型:param model_cls: 分类模型:param img_path: 图像路径:param names: 类别名称列表:param class_colors: 类别颜色映射字典:param output_folder: 输出文件夹路径:param save_cropped: 是否保存裁剪的对象图像:param detection_size: 检测模型输入图像大小:param classification_size: 分类模型输入图像大小:return: 处理后的图像"""img = cv2.imread(str(img_path))if img is None:print(f"无法读取图像: {img_path}")return None# 创建图像副本用于绘制(不修改原始图像)img_copy = img.copy()# 执行对象检测results_det = model_det.predict(str(img_path), imgsz=detection_size, conf=0.25, iou=0.45)# 处理每个检测结果(每个检测框)for r in results_det:boxes = r.boxes.xyxy.cpu().numpy()  # xyxy 格式classes = r.boxes.cls.cpu().numpy()confidences = r.boxes.conf.cpu().numpy()for box, cls_id, confidence in zip(boxes, classes, confidences):# 检测模型的类别名det_class_name = names[int(cls_id)]# 使用检测到的类别名对应的颜色(该颜色是全局唯一的)color = class_colors.get(det_class_name, (255, 255, 255))# 裁剪对象区域x1, y1, x2, y2 = map(int, box)object_region = img[y1:y2, x1:x2]# 将对象区域调整为分类模型的输入大小object_region = cv2.resize(object_region, (classification_size, classification_size))# 执行分类results_cls = model_cls.predict(object_region, imgsz=classification_size)for result in results_cls:try:# 获取Top1预测结果classification_confidence = result.probs.cpu().numpy().top1conftop1_index = result.probs.top1cls_class_name = names[top1_index]# 根据分类结果的类别名设置颜色final_color = class_colors.get(cls_class_name, color)add_classification_to_box(img_copy, box, cls_class_name, classification_confidence, color=final_color)# 如果启用了保存裁剪对象,则保存if save_cropped:save_cropped_object(img, box, cls_class_name, classification_confidence, output_folder, img_path.name)except Exception as e:print(f"分类时出错: {e}")# 在图像副本上绘制检测框draw_box_on_image(img_copy, box, color=color)return img_copydef process_images(model_det, model_cls, input_folder, output_folder, names, class_colors, save_cropped=False, detection_size=1280, classification_size=640):"""处理输入文件夹中的图像,执行对象检测和分类,并保存处理后的图像。:param model_det: 检测模型:param model_cls: 分类模型:param input_folder: 输入文件夹路径:param output_folder: 输出文件夹路径:param names: 类别名称列表:param class_colors: 类别颜色映射字典:param save_cropped: 是否保存裁剪的对象图像:param detection_size: 检测模型输入图像大小:param classification_size: 分类模型输入图像大小"""Path(output_folder).mkdir(parents=True, exist_ok=True)image_extensions = ['*.png', '*.jpg', '*.jpeg', '*.webp']for ext in image_extensions:for img_path in Path(input_folder).glob(ext):print(f"正在处理: {img_path}")processed_img = process_image_with_detection_and_classification(model_det, model_cls, img_path, names, class_colors, output_folder, save_cropped, detection_size, classification_size)if processed_img is not None:output_image_path = Path(output_folder) / f"{img_path.stem}_with_boxes_and_classification.jpg"cv2.imwrite(str(output_image_path), processed_img)print(f"已保存处理后的图像: {output_image_path}")else:print(f"跳过图像: {img_path} (无法处理)")if __name__ == '__main__':# 设置是否保存裁剪的对象图像(默认不保存)SAVE_CROPPED = True  # 设置为 True 以启用保存裁剪对象# 加载检测和分类模型model_det = YOLO('runs/device_train/exp9/weights/best.pt')model_cls = YOLO('runs/cls_99.4%_exp14/weights/best.pt')# 设置输入和输出文件夹路径input_folder = 'test1'output_folder = 'infer-1216'# 获取类别名(用于生成一致的类别颜色映射)# 这里使用一张全白的图像来获取类别名black_image = 255 * np.ones((224, 224, 3), dtype=np.uint8)results = model_cls.predict(source=black_image)name_dict = results[0].namesnames = list(name_dict.values())# 只在这里生成一次类别颜色映射class_colors = generate_class_colors(names)# 开始处理图像process_images(model_det, model_cls, input_folder, output_folder,names, class_colors,save_cropped=SAVE_CROPPED,detection_size=1280,classification_size=224)

执行完后的结果
在这里插入图片描述

下面贴一下目标检测和图像分类的ultralytics的训练代码

目标检测训练代码

注意把single_cls=False改成True,变成单类训练

# nohup python -m torch.distributed.launch --nproc_per_node=4 --master_port=25643 det_train.py > output-lane-1212.txt 2>&1 &
# nohup python -m torch.distributed.launch --nproc_per_node=5 --master_port=25698 det_train.py > output-lane-1212.txt 2>&1 &
from ultralytics import YOLOif __name__ == '__main__':# 加载模型model = YOLO("checkpoints/yolo11l.pt")  # 使用预训练权重训练# 训练参数 ----------------------------------------------------------------------------------------------model.train(data='/home/lizhijun/01.det/ultralytics-8.3.23/datasets/device_1212_yolo_without_vdd/config.yaml',epochs=150,  # (int) 训练的周期数patience=50,  # (int) 等待无明显改善以进行早期停止的周期数batch=16,  # (int) 每批次的图像数量(-1 为自动批处理)imgsz=1280,  # (int) 输入图像的大小,整数或w,hsave=True,  # (bool) 保存训练检查点和预测结果save_period=-1,  # (int) 每x周期保存检查点(如果小于1则禁用)cache=False,  # (bool) True/ram、磁盘或False。使用缓存加载数据device='1,2,3,5',  # (int | str | list, optional) 运行的设备,例如 cuda device=0 或 device=0,1,2,3 或 device=cpuworkers=8,  # (int) 数据加载的工作线程数(每个DDP进程)project='runs/device_train',  # (str, optional) 项目名称name='exp',  # (str, optional) 实验名称,结果保存在'project/name'目录下exist_ok=False,  # (bool) 是否覆盖现有实验pretrained=True,  # (bool | str) 是否使用预训练模型(bool),或从中加载权重的模型(str)optimizer='auto',  # (str) 要使用的优化器,选择=[SGD,Adam,Adamax,AdamW,NAdam,RAdam,RMSProp,auto]verbose=True,  # (bool) 是否打印详细输出seed=0,  # (int) 用于可重复性的随机种子deterministic=True,  # (bool) 是否启用确定性模式single_cls=False,  # (bool) 将多类数据训练为单类rect=False,  # (bool) 如果mode='train',则进行矩形训练,如果mode='val',则进行矩形验证cos_lr=True,  # (bool) 使用余弦学习率调度器close_mosaic=10,  # (int) 在最后几个周期禁用马赛克增强resume=False,  # (bool) 从上一个检查点恢复训练amp=True,  # (bool) 自动混合精度(AMP)训练,选择=[True, False],True运行AMP检查fraction=1.0,  # (float) 要训练的数据集分数(默认为1.0,训练集中的所有图像)profile=False,  # (bool) 在训练期间为记录器启用ONNX和TensorRT速度freeze= None,  # (int | list, 可选) 在训练期间冻结前 n 层,或冻结层索引列表。# 超参数 ----------------------------------------------------------------------------------------------lr0=0.01,  # (float) 初始学习率(例如,SGD=1E-2,Adam=1E-3)lrf=0.01,  # (float) 最终学习率(lr0 * lrf)momentum=0.937,  # (float) SGD动量/Adam beta1weight_decay=0.0005,  # (float) 优化器权重衰减 5e-4warmup_epochs=3.0,  # (float) 预热周期(分数可用)warmup_momentum=0.8,  # (float) 预热初始动量warmup_bias_lr=0.1,  # (float) 预热初始偏置学习率box=6,  # (float) 盒损失增益cls=1.5,  # (float) 类别损失增益(与像素比例)dfl=1.5,  # (float) dfl损失增益pose=12.0,  # (float) 姿势损失增益kobj=1.0,  # (float) 关键点对象损失增益label_smoothing=0.05,  # (float) 标签平滑(分数)nbs=64,  # (int) 名义批量大小hsv_h=0.015,  # (float) 图像HSV-Hue增强(分数)hsv_s=0.7,  # (float) 图像HSV-Saturation增强(分数)hsv_v=0.4,  # (float) 图像HSV-Value增强(分数)degrees=90.0,  # (float) 图像旋转(+/- deg)translate=0.5,  # (float) 图像平移(+/- 分数)scale=0.5,  # (float) 图像缩放(+/- 增益)shear=0.4,  # (float) 图像剪切(+/- deg)perspective=0.0,  # (float) 图像透视(+/- 分数),范围为0-0.001flipud=0.5,  # (float) 图像上下翻转(概率)fliplr=0.5,  # (float) 图像左右翻转(概率)mosaic=1.0,  # (float) 图像马赛克(概率)mixup=0.0,  # (float) 图像混合(概率)copy_paste=0.0,  # (float) 分割复制-粘贴(概率))

图像分类训练代码

from ultralytics import YOLOmodel = YOLO("checkpoints/yolo11l-cls.pt")
model.train(data='/home/lizhijun/01.det/ultralytics-8.3.23/datasets/device_cls_merge_manual_with_21w_1218_train_val_224_truncate_grid_110%', project='runs/cls_train',  # (str, optional) 项目名称name='exp',  # (str, optional) 实验名称,结果保存在'project/name'目录下epochs=20, batch=1024,device='1,2,3,5',erasing=0.0,crop_fraction=1.0,augment=False,auto_augment=False,hsv_h=0.015,  # (float) 图像HSV-Hue增强(分数)hsv_s=0.7,  # (float) 图像HSV-Saturation增强(分数)hsv_v=0.4,  # (float) 图像HSV-Value增强(分数)degrees=0.0,  # (float) 图像旋转(+/- deg)translate=0.0,  # (float) 图像平移(+/- 分数)scale=0.0,  # (float) 图像缩放(+/- 增益)shear=0.0,  # (float) 图像剪切(+/- deg)perspective=0.0,  # (float) 图像透视(+/- 分数),范围为0-0.001flipud=0.5,  # (float) 图像上下翻转(概率)fliplr=0.5,  # (float) 图像左右翻转(概率)mosaic=1.0,  # (float) 图像马赛克(概率)mixup=0.0)  # (float) 图像混合(概率))

http://www.ppmy.cn/news/1557014.html

相关文章

Python面试常见问题及答案10

1. 问题:如何在Python中对列表进行排序? 答案: 可以使用列表的sort()方法,它会直接修改原始列表。例如: my_list [3, 1, 4, 1, 5, 9, 2, 6, 5, 3] my_list.sort() print(my_list)也可以使用sorted()函数&#xff0c…

Python 写的 《监控视频存储计算器》

代码: import tkinter as tk from tkinter import ttk import math from tkinter.font import Fontclass StorageCalculator:def __init__(self, root):self.root rootself.root.title("监控视频存储计算器")self.root.geometry("600x800")s…

01背包:模板题+实战题

一、01背包的定义 我们有一个背包,背包的容积有限,最多只能装下总体积为V的物品。现在给定我们N个物品,第i个物品的体积vi,对应的价值是wi( 1 ≤ i ≤ N 1 \leq i \leq N 1≤i≤N)。每个物品有且仅有一个。…

【JavaWeb】Ajax

目录 一、什么是Ajax? 二、同步与异步 三、Ajax工作原理 四、Ajax实现步骤 五、Ajax应用场景 六、Ajax常见问题 1.缓存问题 2.跨域问题 3.请求超时与网络异常 4.取消请求 七、常见Ajax三种请求方式 1.jQuery请求 2.Axios请求 3.Fetch请求 一、什么是A…

mac编译ijkplayer遇到问题

问题:./init-android.sh git version 2.44.0 pull ffmpeg base : command not founde.sh: line 2: : command not founde.sh: line 5: : command not founde.sh: line 6: tools/pull-repo-base.sh: line 9: syntax error near unexpected token elif ools/pull-re…

【第九节】Git 服务器搭建

目录 前言 一、 使用裸存储库搭建 Git 服务器 1.1 安装 Git 1.2 创建裸存储库 1.3 配置 SSH 访问 1.4 克隆仓库 二、 使用 GitLab 搭建 Git 服务器 2.1 安装 GitLab 2.2 配置 GitLab 2.3 创建项目 2.4 生成 SSH 密钥 2.5 添加 SSH Key 三、 使用 GitLab 管理项目 …

智源大模型通用算子库FlagGems四大能力升级 持续赋能AI系统开源生态

FlagGems是由智源研究院于2024年6月推出的面向多种AI芯片的开源大模型通用算子库。FlagGems使用Triton语言开发,在Triton生态开源开放的基础上,为多种AI芯片提供开源、统一、高效的算子层生态接入方案。FlagGems沿着统一的中间语言、统一的算子接口和统一…

条款24:若所有参数皆需类型转换,请为此采用非成员函数

条款24:若所有参数皆需类型转换,请为此采用非成员函数 设计一个表示有理数的类时,允许从整数隐式转换为有理数是有用的: class Rational { public:Rational(int numerator 0, // 该构造函数没有explicit限制;int denominator …