将COCO格式的物体检测数据集划分训练集、验证集和测试集

embedded/2025/3/18 23:12:47/

目录

导入所需库

定义数据集路径

创建输出目录

读取JSON注释文件

随机打乱图像列表

计算划分大小

复制图像到相应文件夹

完整代码


导入所需库

我们需要以下Python库:

os:处理文件路径。

json:读取和写入JSON文件。

numpy:随机打乱图像列表。

shutil:复制图像文件。

import os
import json
import numpy as np
import shutil

定义数据集路径

设置数据集的根目录、图像文件夹和注释文件路径。

根目录:"D:\\dataset"

图像文件夹:"D:\\dataset\\images"

注释文件:"D:\\dataset\\annotations.json"

# 数据集路径(请根据实际情况修改)
dataset_root = "D:\\dataset"
images_folder = os.path.join(dataset_root, "images")
annotations_path = os.path.join(dataset_root, "annotations.json")

 

创建输出目录

在根目录下创建output文件夹,并在其中创建out_train、out_val和out_test子文件夹。

# 输出路径
output_root = os.path.join(dataset_root, "output")
os.makedirs(output_root, exist_ok=True)

train_folder = os.path.join(output_root, "out_train")
val_folder = os.path.join(output_root, "out_val")
test_folder = os.path.join(output_root, "out_test")
os.makedirs(train_folder, exist_ok=True)
os.makedirs(val_folder, exist_ok=True)
os.makedirs(test_folder, exist_ok=True)

读取JSON注释文件

加载COCO格式的JSON文件,提取images(图像信息)、annotations(标注信息)和categories(类别信息)。

# 读取注释文件
with open(annotations_path, "r") as f:
    annotations_data = json.load(f)

# 提取数据
images = annotations_data["images"]
annotations = annotations_data["annotations"]
categories = annotations_data["categories"]

随机打乱图像列表

使用numpy随机打乱图像列表,确保划分的随机性。

# 随机打乱图像列表
np.random.shuffle(images)

计算划分大小

根据图像总数和比例计算训练集和测试集的大小:

假设图像总数为N。

训练集:N * 0.8。

验证集:N * 0.0 = 0。

测试集:N * 0.2。

# 定义划分比例
train_ratio, val_ratio, test_ratio = 0.8, 0, 0.2

# 计算大小
num_images = len(images)
num_train = int(num_images * train_ratio)
num_val = int(num_images * val_ratio)  # 将为0

# 划分图像
train_images = images[:num_train]
val_images = images[num_train:num_train + num_val]  # 空列表
test_images = images[num_train + num_val:]

复制图像到相应文件夹

将训练集和测试集的图像复制到对应的文件夹。

# 复制图像
for img in train_images:
    shutil.copy(os.path.join(images_folder, img["file_name"]), 
                os.path.join(train_folder, img["file_name"]))

for img in val_images:  # 不会执行
    shutil.copy(os.path.join(images_folder, img["file_name"]), 
                os.path.join(val_folder, img["file_name"]))

for img in test_images:
    shutil.copy(os.path.join(images_folder, img["file_name"]), 
                os.path.join(test_folder, img["file_name"]))

完整代码

以下是完整的Python脚本:

import os
import json
import numpy as np
import shutil# 数据集路径(请根据实际情况修改)
dataset_root = "D:\\dataset"
images_folder = os.path.join(dataset_root, "images")
annotations_path = os.path.join(dataset_root, "annotations.json")# 输出路径
output_root = os.path.join(dataset_root, "output")
os.makedirs(output_root, exist_ok=True)train_folder = os.path.join(output_root, "out_train")
val_folder = os.path.join(output_root, "out_val")
test_folder = os.path.join(output_root, "out_test")
os.makedirs(train_folder, exist_ok=True)
os.makedirs(val_folder, exist_ok=True)
os.makedirs(test_folder, exist_ok=True)# 读取注释文件
with open(annotations_path, "r") as f:annotations_data = json.load(f)# 提取数据
images = annotations_data["images"]
annotations = annotations_data["annotations"]
categories = annotations_data["categories"]# 随机打乱图像列表
np.random.shuffle(images)# 定义划分比例
train_ratio, val_ratio, test_ratio = 0.8, 0, 0.2# 计算大小
num_images = len(images)
num_train = int(num_images * train_ratio)
num_val = int(num_images * val_ratio)# 划分图像
train_images = images[:num_train]
val_images = images[num_train:num_train + num_val]
test_images = images[num_train + num_val:]# 复制图像
for img in train_images:shutil.copy(os.path.join(images_folder, img["file_name"]), os.path.join(train_folder, img["file_name"]))for img in val_images:shutil.copy(os.path.join(images_folder, img["file_name"]), os.path.join(val_folder, img["file_name"]))for img in test_images:shutil.copy(os.path.join(images_folder, img["file_name"]), os.path.join(test_folder, img["file_name"]))# 函数:过滤注释
def filter_annotations(annotations, image_ids):return [ann for ann in annotations if ann["image_id"] in image_ids]# 获取image_ids
train_image_ids = [img["id"] for img in train_images]
val_image_ids = [img["id"] for img in val_images]
test_image_ids = [img["id"] for img in test_images]# 过滤注释
train_ann = filter_annotations(annotations, train_image_ids)
val_ann = filter_annotations(annotations, val_image_ids)
test_ann = filter_annotations(annotations, test_image_ids)# 创建JSON字典
train_json = {"images": train_images, "annotations": train_ann, "categories": categories}
val_json = {"images": val_images, "annotations": val_ann, "categories": categories}
test_json = {"images": test_images, "annotations": test_ann, "categories": categories}# 写入JSON文件
with open(os.path.join(output_root, "out_train.json"), "w") as f:json.dump(train_json, f)
with open(os.path.join(output_root, "out_val.json"), "w") as f:json.dump(val_json, f)
with open(os.path.join(output_root, "out_test.json"), "w") as f:json.dump(test_json, f)print("数据集划分完成!")


http://www.ppmy.cn/embedded/173701.html

相关文章

QT:文件读取

问题: 在文件读取,判断md5值时,遇到py文件读取转String后,再转byte,md5前后不一致问题。 解决方法: python文件读取要使用QTextStream,避免\t 、\r、\n的换行符跨平台问题(window…

MySQL 锁

MySQL中最常见的锁有全局锁、表锁、行锁。 全局锁 全局锁用于锁住当前库中的所有实例,也就是说会将所有的表都锁住。一般用于做数据库备份的时候就需要添加全局锁,数据库备份的时候是一个表一个表备份,如果没有加锁的话在备份的时候会有其他的…

选择最佳加密软件:IPguard vs Ping32——企业级安全方案评估

在当前数字化快速发展的背景下,企业信息的安全性变得尤为重要。为了有效保护企业的核心数据和知识产权,选择合适的加密软件成为众多企业关注的焦点。本文将对两款市场上广受好评的加密软件——Ping32与IPguard进行详细对比分析,旨在帮助企业根…

解决从deepseek接口获取的流式响应输出到前端都是undefined的问题

你的前端 EventSource 代码遇到了 undefined 连续输出 的问题,通常是因为: AI 返回的内容被拆成了单个字符,导致前端 JSON.parse(event.data).content 获取到的是单个字符,而 undefined 可能是因为某些数据块没有 content 字段。…

Flutter 按钮组件 ElevatedButton 详解

目录 1. 引言 2. ElevatedButton 的基本用法 3. 主要属性 4. 自定义按钮样式 4.1 修改背景颜色和文本颜色 4.2 修改按钮形状和边框 4.3 修改按钮大小 4.4 阴影控制 4.5 水波纹效果 5. 结论 相关推荐 1. 引言 在 Flutter 中,ElevatedButton 是一个常用的…

低空经济腾飞:无人机送货、空中通勤,未来已来

近年来,低空经济逐渐成为社会关注的焦点。从无人机送货到“空中的士”,再到飞行培训的火热进行,低空经济正迎来前所未有的发展机遇。随着技术进步和政策支持,这一曾经看似遥远的未来场景,正逐步变为现实。 低空经济如何…

音视频处理的“瑞士军刀”与“积木”:FFmpeg 与 GStreamer 的深度揭秘

一、发展历史与生态演进对比 FFmpeg的成长轨迹 诞生背景:2000年由Fabrice Bellard创建,最初为解决视频编码标准化问题而生。早期版本仅支持MPEG-1编码,但凭借开源社区协作,迅速扩展为全格式编解码工具。技术扩张:2004年…

【RHCE实验】搭建主从DNS、WEB等服务器

目录 需求 环境搭建 配置nfs服务器 配置web服务器 配置主从dns服务器 主dns服务器 从dns服务器 配置客户端 客户端测试 需求 客户端通过访问 www.nihao.com 后,能够通过 dns 域名解析,访问到 nginx 服务中由 nfs 共享的首页文件,内容…