【目标检测】模型验证:K-Fold 交叉验证

embedded/2025/2/7 11:36:26/

K-Fold 交叉验证

  • 1、引言
    • 1.1 K 折交叉验证概述
  • 2、配置
    • 2.1 数据集
    • 2.2 安装包
  • 3、 实战
    • 3.1 生成物体检测数据集的特征向量
    • 3.2 K 折数据集拆分
    • 3.3 保存记录
    • 3.4 使用 K 折数据分割训练YOLO
  • 4、总结

1、引言

我们将利用YOLO 检测格式和关键的Python 库(如 sklearn、pandas 和 PyYaml),完成必要的设置、生成特征向量的过程以及 K-Fold 数据集拆分的执行。

1.1 K 折交叉验证概述

无论你的项目涉及水果检测数据集还是自定义数据源,都可以使用 K 折交叉验证,
以提高项目的可靠性和稳健性。

书说简短,闲言少叙,咱进入正题
在这里插入图片描述

2、配置

2.1 数据集

该数据集共包含 8479 幅图像。
它包括 6 个类别标签,每个标签的实例总数如下:

类别计数
苹果7049
葡萄7202
菠萝1613
橙色15549
香蕉3536
西瓜1976

2.2 安装包

必要的Python 软件包包括

  • ultralytics
  • sklearn
  • pandas
  • pyyaml

这次实例中,我们使用 k=5 折叠次数

3、 实战

3.1 生成物体检测数据集的特征向量

具体步骤如下:

  • 1、首先创建一个新的 demo.py Python 文件来执行下面的步骤。

  • 2、继续检索数据集的所有标签文件。

from pathlib import Pathdataset_path = Path("./Fruit-detection")  # replace with 'path/to/dataset' for your custom data
labels = sorted(dataset_path.rglob("*labels/*.txt"))  # all data in 'labels'
  • 3、现在,读取数据集 YAML 文件的内容并提取类标签的索引。
yaml_file = "path/to/data.yaml"  # your data YAML with data directories and names dictionary
with open(yaml_file, "r", encoding="utf8") as y:classes = yaml.safe_load(y)["names"]
cls_idx = sorted(classes.keys())
  • 4、初始化一个空的 pandas DataFrame.
import pandas as pdindex = [label.stem for label in labels]  # uses base filename as ID (no extension)
labels_df = pd.DataFrame([], columns=cls_idx, index=index)
  • 5、计算注释文件中每个类别标签的实例数。
from collections import Counterfor label in labels:lbl_counter = Counter()with open(label, "r") as lf:lines = lf.readlines()for line in lines:# classes for YOLO label uses integer at first position of each linelbl_counter[int(line.split(" ")[0])] += 1labels_df.loc[label.stem] = lbl_counterlabels_df = labels_df.fillna(0.0)  # replace `nan` values with `0.0`
  • 6、以下是已填充 DataFrame 的示例视图:
                                                       0    1    2    3    4    5
'0000a16e4b057580_jpg.rf.00ab48988370f64f5ca8ea4...'  0.0  0.0  0.0  0.0  0.0  7.0
'0000a16e4b057580_jpg.rf.7e6dce029fb67f01eb19aa7...'  0.0  0.0  0.0  0.0  0.0  7.0
'0000a16e4b057580_jpg.rf.bc4d31cdcbe229dd022957a...'  0.0  0.0  0.0  0.0  0.0  7.0
'00020ebf74c4881c_jpg.rf.508192a0a97aa6c4a3b6882...'  0.0  0.0  0.0  1.0  0.0  0.0
'00020ebf74c4881c_jpg.rf.5af192a2254c8ecc4188a25...'  0.0  0.0  0.0  1.0  0.0  0.0...                                                  ...  ...  ...  ...  ...  ...
'ff4cd45896de38be_jpg.rf.c4b5e967ca10c7ced3b9e97...'  0.0  0.0  0.0  0.0  0.0  2.0
'ff4cd45896de38be_jpg.rf.ea4c1d37d2884b3e3cbce08...'  0.0  0.0  0.0  0.0  0.0  2.0
'ff5fd9c3c624b7dc_jpg.rf.bb519feaa36fc4bf630a033...'  1.0  0.0  0.0  0.0  0.0  0.0
'ff5fd9c3c624b7dc_jpg.rf.f0751c9c3aa4519ea3c9d6a...'  1.0  0.0  0.0  0.0  0.0  0.0
'fffe28b31f2a70d4_jpg.rf.7ea16bd637ba0711c53b540...'  0.0  6.0  0.0  0.0  0.0  0.0

解析

  • 行是标签文件的索引,每个标签文件对应数据集中的一幅图像,列则对应类标签索引。
  • 每一行代表一个伪特征向量,其中包含数据集中每个类标签的计数。
  • 这种数据结构可以将 K 折交叉验证应用于对象检测数据集。

3.2 K 折数据集拆分

  • 1、使用 KFold 从 sklearn.model_selection 以产生 k 对数据集进行分割。

    • 敲黑板:
      • 设置 shuffle=True 确保了分班中班级的随机分布。
      • 通过设置 random_state=M 其中 M 是一个选定的整数,这样就可以得到可重复的结果。
from sklearn.model_selection import KFoldksplit = 5
kf = KFold(n_splits=ksplit, shuffle=True, random_state=20)  # setting random_state for repeatable resultskfolds = list(kf.split(labels_df))
  • 2、数据集现已分为 k 折叠,每个折叠都有一个 train 和 val 指数。我们将构建一个 DataFrame 来更清晰地显示这些结果。
folds = [f"split_{n}" for n in range(1, ksplit + 1)]
folds_df = pd.DataFrame(index=index, columns=folds)for i, (train, val) in enumerate(kfolds, start=1):folds_df[f"split_{i}"].loc[labels_df.iloc[train].index] = "train"folds_df[f"split_{i}"].loc[labels_df.iloc[val].index] = "val"
  • 3、将计算每个褶皱的类别标签分布,并将其作为褶皱中出现的类别的比率。
fold_lbl_distrb = pd.DataFrame(index=folds, columns=cls_idx)for n, (train_indices, val_indices) in enumerate(kfolds, start=1):train_totals = labels_df.iloc[train_indices].sum()val_totals = labels_df.iloc[val_indices].sum()# To avoid division by zero, we add a small value (1E-7) to the denominatorratio = val_totals / (train_totals + 1e-7)fold_lbl_distrb.loc[f"split_{n}"] = ratio
最理想的情况是,每次分割和不同类别的所有类别比率都相当相似。不过,这取决于数据集的具体情况。
  • 4、为每个分割创建目录和数据集 YAML 文件。
import datetimesupported_extensions = [".jpg", ".jpeg", ".png"]# Initialize an empty list to store image file paths
images = []# Loop through supported extensions and gather image files
for ext in supported_extensions:images.extend(sorted((dataset_path / "images").rglob(f"*{ext}")))# Create the necessary directories and dataset YAML files (unchanged)
save_path = Path(dataset_path / f"{datetime.date.today().isoformat()}_{ksplit}-Fold_Cross-val")
save_path.mkdir(parents=True, exist_ok=True)
ds_yamls = []for split in folds_df.columns:# Create directoriessplit_dir = save_path / splitsplit_dir.mkdir(parents=True, exist_ok=True)(split_dir / "train" / "images").mkdir(parents=True, exist_ok=True)(split_dir / "train" / "labels").mkdir(parents=True, exist_ok=True)(split_dir / "val" / "images").mkdir(parents=True, exist_ok=True)(split_dir / "val" / "labels").mkdir(parents=True, exist_ok=True)# Create dataset YAML filesdataset_yaml = split_dir / f"{split}_dataset.yaml"ds_yamls.append(dataset_yaml)with open(dataset_yaml, "w") as ds_y:yaml.safe_dump({"path": split_dir.as_posix(),"train": "train","val": "val","names": classes,},ds_y,)
  • 5、最后,将图像和标签复制到每个分割的相应目录("train "或 “val”)中。
import shutilfor image, label in zip(images, labels):for split, k_split in folds_df.loc[image.stem].items():# Destination directoryimg_to_path = save_path / split / k_split / "images"lbl_to_path = save_path / split / k_split / "labels"# Copy image and label files to new directory (SamefileError if file already exists)shutil.copy(image, img_to_path / image.name)shutil.copy(label, lbl_to_path / label.name)

3.3 保存记录

将 K 折分割和标签分布数据框的记录保存为 CSV 文件。

folds_df.to_csv(save_path / "kfold_datasplit.csv")
fold_lbl_distrb.to_csv(save_path / "kfold_label_distribution.csv")

3.4 使用 K 折数据分割训练YOLO

  • 首先,加载YOLO 模型。
from ultralytics import YOLOweights_path = "path/to/weights.pt"
model = YOLO(weights_path, task="detect")
  • 其次,遍历数据集 YAML 文件以运行训练。结果将保存到由 project 和 name 参数。默认情况下,该目录为 “exp/runs#”,其中 # 为整数索引。
results = {}# Define your additional arguments here
batch = 16
project = "kfold_demo"
epochs = 100for k in range(ksplit):dataset_yaml = ds_yamls[k]model = YOLO(weights_path, task="detect")model.train(data=dataset_yaml, epochs=epochs, batch=batch, project=project)  # include any train argumentsresults[k] = model.metrics  # save output metrics for further analysis

4、总结

这篇小鱼使用了 K 折交叉验证来训练YOLO 物体检测模型的过程。

还创建报告 DataFrames 的程序,以可视化数据拆分和标签在这些拆分中的分布,清楚地了解训练集和验证集的结构。

此外,还保存了记录,这在大型项目或排除模型性能故障时尤为有用。

最后,在一个循环中使用每个拆分来执行实际的模型训练,保存训练结果,以便进一步分析和比较。

这种 K 折交叉验证技术是充分利用可用数据的一种稳健方法,有助于确保模型在不同数据子集中的性能是可靠和一致的。这将产生一个更具通用性和可靠性的模型,从而减少对特定数据模式的过度拟合。

我是小鱼

  • CSDN 博客专家
  • 阿里云 专家博主
  • 51CTO博客专家
  • 企业认证金牌面试官
  • 多个名企认证&特邀讲师等
  • 名企签约职场面试培训、职场规划师
  • 多个国内主流技术社区的认证专家博主
  • 多款主流产品(阿里云等)评测一等奖获得者

关注小鱼,学习【人工智能&大模型】/【深度学习&机器学习】领域最新最全的知识。


http://www.ppmy.cn/embedded/160286.html

相关文章

langchain教程-3.OutputParser/输出解析

前言 该系列教程的代码: https://github.com/shar-pen/Langchain-MiniTutorial 我主要参考 langchain 官方教程, 有选择性的记录了一下学习内容 这是教程清单 1.初试langchain2.prompt3.OutputParser/输出解析4.model/vllm模型部署和langchain调用5.DocumentLoader/多种文档…

旋转变压器工作及解调原理

旋转变压器 旋转变压器是一种精密的位置、速度检测装置,广泛应用在伺服控制、机器人、机械工具、汽车、电力等领域。但是,旋转变压器在使用时并不能直接提供角度或位置信息,需要特殊的激励信号和解调、计算措施,才能将旋转变压器…

C++语法·十伞

目录 仿函数 1.定义 2.作用 3.实现 deque(双端队列) 优点: 缺点: stack(栈) 1.使用 2.模拟实现 queue(队列) 1.使用 2.模拟实现 priority_queue(优先级队列…

PVE 中 Debian 虚拟机崩溃后,硬盘数据怎么恢复

问题 在 PVE 中给 Debian 虚拟机新分配硬盘后,通过 Debian 虚拟机开启 Samba 共享该硬盘。如果这个 Debian 虚拟机崩溃后,怎么恢复 Samba 共享硬盘数据。 方法 开启 Samba 共享相关知识:挂载硬盘和开启Samba共享。 新建一个虚拟机&#xf…

数据库------------

一 mysql ----数据库就相当于一个端口 1. 三层结构 1)数据库中 表的本质仍然是文件 1.1 mysql常用数据类型---(即 mysql列类型) 1) 数值类型 2) 文本类型 3) 二进制数据类型 4)日期类型 2. sq…

Linux提权--John碰撞密码提权

​John the Ripper​(简称 John)是一个常用的密码破解工具,可以通过暴力破解、字典攻击、规则攻击等方式,尝试猜解用户密码。密码的弱度是提权攻击中的一个重要因素,如果某个用户的密码非常简单或是默认密码&#xff0…

python:如何播放 .spx 声音文件

.spx 是 Speex音频编解码器的文件扩展名,它是一种开源的、免费的音频编解码器,主要用于语音压缩和语音通信领域。spx 文件通常用于语音记录、VoIP应用、语音信箱等场景。 .mp3 是一种广泛使用的音频格式,它采用了有损压缩算法,可…

langchain教程-12.Agent/工具定义/Agent调用工具/Agentic RAG

前言 该系列教程的代码: https://github.com/shar-pen/Langchain-MiniTutorial 我主要参考 langchain 官方教程, 有选择性的记录了一下学习内容 这是教程清单 1.初试langchain2.prompt3.OutputParser/输出解析4.model/vllm模型部署和langchain调用5.DocumentLoader/多种文档…