使用自制数据集训练 YOLO 目标检测算法前,需要先对数据集进行划分。以下代码会把数据集中的图片及其对应的标签文件按比例(默认 7:2:1)随机划分为训练集(train)、验证集(val)和测试集(test),并分别复制到各自的 images 和 labels 目录中:
import os
import shutil
import numpy as np
try:
    from tqdm import tqdm
except ImportError:
    # The progress bar is purely cosmetic, so degrade to plain iteration
    # instead of failing when tqdm is not installed.
    def tqdm(iterable, **kwargs):
        return iterable


def _copy_split(image_names, images_dir, labels_dir, split_dir, label_ext):
    """Copy one split's images, and any matching label files, under *split_dir*."""
    target_image_dir = os.path.join(split_dir, 'images')
    target_label_dir = os.path.join(split_dir, 'labels')
    os.makedirs(target_image_dir, exist_ok=True)
    os.makedirs(target_label_dir, exist_ok=True)

    # Label the progress bar with the split name (train / val / test) so the
    # three bars can be told apart.
    split_name = os.path.basename(os.path.normpath(split_dir))
    for image_file in tqdm(image_names, desc=f"Copying files to {split_name}"):
        shutil.copy(
            os.path.join(images_dir, image_file),
            os.path.join(target_image_dir, image_file),
        )
        # A label shares the image's base name; images without a label file
        # are still copied, only the label copy is skipped.
        label_file = os.path.splitext(image_file)[0] + label_ext
        label_path = os.path.join(labels_dir, label_file)
        if os.path.exists(label_path):
            shutil.copy(label_path, os.path.join(target_label_dir, label_file))


def split_dataset(images_dir, labels_dir, train_ratio=0.7, val_ratio=0.2, test_ratio=0.1,
                  random_seed=None, output_dir='.', label_ext='.json'):
    """Randomly split a dataset into train / val / test folders.

    Every regular file in ``images_dir`` is treated as an image. The files are
    copied (not moved) into ``<output_dir>/{train,val,test}/images`` and the
    label file with the same base name, if present in ``labels_dir``, is copied
    into the sibling ``labels`` directory.

    Args:
        images_dir: Directory containing the image files.
        labels_dir: Directory containing the label files.
        train_ratio: Fraction of images for the training split; the count is
            ``int(num_images * train_ratio)``.
        val_ratio: Fraction of images for the validation split; the count is
            ``int(num_images * val_ratio)``.
        test_ratio: Not used in the computation. The test split always
            receives every image left over after train and val are drawn.
            Kept so existing callers keep working.
        random_seed: Seed for a reproducible split. When ``None`` the global
            NumPy random state is used.
        output_dir: Directory under which ``train``, ``val`` and ``test`` are
            created. Defaults to the current working directory.
        label_ext: Extension (including the dot) of the label files.

    Raises:
        ValueError: Propagated from NumPy when the requested train/val sizes
            exceed the number of available images or are negative.
    """
    # os.listdir returns entries in arbitrary order; sort them so that a given
    # seed always maps to the same files.
    image_files = sorted(
        name for name in os.listdir(images_dir)
        if os.path.isfile(os.path.join(images_dir, name))
    )
    num_images = len(image_files)

    # With a seed, use a private generator so the caller's global NumPy random
    # state is not overwritten.
    rng = np.random if random_seed is None else np.random.RandomState(random_seed)

    # Draw train first, then val from what is left; test is the remainder.
    train_indices = rng.choice(num_images, size=int(num_images * train_ratio), replace=False)
    remaining_indices = np.setdiff1d(np.arange(num_images), train_indices)
    val_indices = rng.choice(remaining_indices, size=int(num_images * val_ratio), replace=False)
    test_indices = np.setdiff1d(remaining_indices, val_indices)

    splits = (
        ('train', train_indices),
        ('val', val_indices),
        ('test', test_indices),
    )
    for split_name, indices in splits:
        _copy_split(
            [image_files[index] for index in indices],
            images_dir,
            labels_dir,
            os.path.join(output_dir, split_name),
            label_ext,
        )


if __name__ == "__main__":
    # Point these two paths at the original, unsplit dataset.
    images_dir = os.path.join('dataset', 'images')
    labels_dir = os.path.join('dataset', 'labels')
    split_dataset(images_dir, labels_dir, random_seed=42)
使用时,只需将下面两个变量修改为原始(尚未划分)数据集的图片目录和标签目录即可:
images_dir = 'dataset\\images'
labels_dir = 'dataset\\labels'
今天不学习,明天变垃圾!