【数据集】Yolo人体关键点数据集处理

devtools/2024/10/9 8:23:38/

文章目录

  • 1、介绍
  • 2、数据集格式
  • 3 、COCO人体关键点示意图
  • 4、数据集预处理
    • 4.1、读取JSON文件
    • 4.2、可视化
    • 4.3、JSON格式转Yolo格式
    • 4.4、划分数据集
    • 4.5、验证
  • 5、完整代码
  • 6、数据集下载链接

1、介绍

 人体关键点检测(Human Keypoints Detection)又称为人体姿态估计2D Pose,是计算机视觉中一个相对基础的任务,是人体动作识别、行为分析、人机交互等的前置任务。一般情况下可以将人体关键点检测细分为单人/多人关键点检测、2D/3D关键点检测,同时有算法在完成关键点检测之后还会进行关键点的跟踪,也被称为人体姿态跟踪。

 本次要介绍的数据集是2D关键点检测数据集,数据集主要来自COCO2017,经过对COCO数据集JSON文件进行预处理提取人体关键点信息,一共提取10000张人体姿态数据集,以及对应的必要信息,已经转化为Yolo格式存储。
在这里插入图片描述

2、数据集格式

keypoint_dataset||_______images|		  |_____000000000049.jpg|		  |_____ ......|_______labels|		  |_____000000000049.txt|		  |_____ ......		  

3 、COCO人体关键点示意图

 下图中,共17个关节点(鼻子x1、眼睛x2、耳朵x2、肩部x2、肘部x2、手腕x2、髋部x2、膝关节x2、脚腕x2):
在这里插入图片描述

4、数据集预处理

 这里数据集预处理主要包括5个内容,分别是read_jsonshowjson2yolosplit_datasetvalidation

4.1、读取JSON文件

def read_json(json_path):d = defaultdict(list)for json_name in tqdm(os.listdir(json_path)):json_full_path = json_path + '/' + json_namewith open(json_full_path, 'r', encoding='utf-8') as file:  data = json.load(file)image_name = data['image_name']# category = data['category']# width = data['width']# height = data['height']# keypoints = data['key_points']# bbox = data['bbox']image_id = image_name.split('.')[0]d[int(image_id)]= data print(f'json 文件读取完成,一共{len(d.keys())}个数据')return d

4.2、可视化

def visualization(ax,keypoints,bbox):#随机颜色c = (np.random.random((1, 3))*0.6+0.4).tolist()[0]#关键点之间的连线ls = [[15,13],[13,11],[16,14],[14,12],[11,12],[5,11],[6,12],[5,6],[5,7],[6,8],[7,9],[8,10],[1,2],[0,1],[0,2],[1,3],[2,4],[3,5],[4,6]]sks = np.array(ls)#获取关键点坐标x,y,vkp = np.array(keypoints)x = kp[0::3]y = kp[1::3]v = kp[2::3]for sk in sks:if np.all(v[sk]>0):# 画点之间的连接线plt.plot(x[sk],y[sk], linewidth=1, color=c)# 画点p = plt.plot(x[v>0], y[v>0],'o',markersize=4, markerfacecolor=c, markeredgecolor='k',markeredgewidth=1)p = plt.plot(x[v>1], y[v>1],'o',markersize=4, markerfacecolor=c, markeredgecolor=c, markeredgewidth=1)# 画矩形边界,多边形填充+矩形边界:x, y, w, h = bbox[0],bbox[1],bbox[2],bbox[3]ax.add_patch(Polygon(xy=[[x, y], [x, y+h], [x+w, y+h], [x+w, y]], color='k', alpha=0.3))ax.add_patch(Rectangle(xy=(x, y), width=w, height=h, fill=False, color=c, alpha=1))plt.plot(x + w/2,y+h/2,'*',markersize=5, markerfacecolor=c, markeredgecolor=c, markeredgewidth=1)def show(data,image_root):for image_id in data.keys():img = io.imread('%s/%s' % (image_root, data[image_id]['image_name']))plt.axis('off')ax = plt.gca()for i in range(len(data[image_id]['key_points'])):visualization(ax, data[image_id]['key_points'][i], data[image_id]['bbox'][i])plt.imshow(img)plt.axis('off')plt.show()

4.3、JSON格式转Yolo格式

 yolo标注格式为:类别 、标注框中心点(x,y)、长和宽(w,h),关键点坐标以及可见度(kx1,ky1,kv1,kx2,ky2,kv2…),然后并根据图片长宽进行归一化处理

def json2yolo(data,txt_path):if not os.path.exists(txt_path):os.mkdir(txt_path)for image_id in tqdm(data.keys()):txt_name = data[image_id]['image_name'].split('.')[0]category = data[image_id]['category']width = data[image_id]['width']height = data[image_id]['height']keypoints = data[image_id]['key_points']bbox = data[image_id]['bbox']'''yolo标注格式类别 、标注框中心点、长和宽,关键点坐标以及可见度根据图片长高进行归一化处理'''for i in range(len(keypoints)):#对标注框进行预处理box = bbox[i]keypoint = keypoints[i]# print(keypoint)#x,y为左上角坐标,w,h为宽高x,y,w,h = box[0],box[1],box[2],box[3]xx = x + w/2yy = y + h/2#归一化xx,yy,ww,hh = xx/width,yy/height,w/width,h/height#关键点归一化proc_keypoint = []for p in range(0,len(keypoint),3):x,y,v = keypoint[p],keypoint[p+1],keypoint[p+2]kx,ky = x/width, y/heightproc_keypoint.extend([kx,ky,v])#写入txt文件yolo_str = f'{category} {xx:.6f} {yy:.6f} {ww:.6f} {hh:.6f} 'yolo_str = yolo_str + ' '.join([f'{i:.6f}' for i in proc_keypoint])with open(f'{txt_path}/{txt_name}.txt','a+') as f:f.write(yolo_str + '\n')print('转化为yolo txt格式完成!')

4.4、划分数据集

 划分数据集格式为yolo需要的格式

def split_dataset(image_root,txt_root,img_targe_file,label_target_file,split_ratio=0.6):imgs = os.listdir(image_root)import randomrandom.seed(2024)random.shuffle(imgs)#这里仅仅取了1万张图片进行测试,由于显存限制,ims = imgs[:10000]random.shuffle(ims)train_num = int(len(ims)*split_ratio)val_num = int(len(ims)*(1-split_ratio)/2)train_set = ims[:train_num]val_set = ims[train_num:train_num+val_num]test_set = ims[-val_num:]move_file(train_set,image_root,txt_root,img_targe_file,label_target_file,mode='train')move_file(val_set,image_root,txt_root,img_targe_file,label_target_file,mode='val')move_file(test_set,image_root,txt_root,img_targe_file,label_target_file,mode='test')def move_file(dataset,image_root,txt_root,img_targe_file,label_target_file,mode='train'):for img in tqdm(dataset):img_path = image_root + '/' + imgtxt_path = txt_root + '/' + img.replace('jpg','txt')img_targe_file_full = img_targe_file + '/' + modelabel_target_file_full = label_target_file + '/' + modeshutil.copy(img_path, img_targe_file_full)shutil.copy(txt_path, label_target_file_full)

4.5、验证

def validation():images = r'E:\datasets\keypoint_dataset\datasets\images'labels = r'E:\datasets\keypoint_dataset\datasets\labels'for mode in ['train','val','test']:img_path = images + '/' + modelabel_path = labels + '/' + modeassert len(os.listdir(img_path)) == len(os.listdir(label_path)),f'{mode}数据划分验证失败'for img,lb in zip(os.listdir(img_path),os.listdir(label_path)):img_name = img.split('.')[0]lb_name = lb.split('.')[0]if img_name != lb_name:assert '数据划分验证失败'print('数据划分验证成功!')pass

5、完整代码

import json
from collections import defaultdict
import numpy as np
import os
from tqdm import tqdm
import shutil
from matplotlib import pyplot as plt
import skimage.io as iofrom matplotlib.patches import Polygon,Rectangle
def visualization(ax,keypoints,bbox):#随机颜色c = (np.random.random((1, 3))*0.6+0.4).tolist()[0]#关键点之间的连线ls = [[15,13],[13,11],[16,14],[14,12],[11,12],[5,11],[6,12],[5,6],[5,7],[6,8],[7,9],[8,10],[1,2],[0,1],[0,2],[1,3],[2,4],[3,5],[4,6]]sks = np.array(ls)#获取关键点坐标x,y,vkp = np.array(keypoints)x = kp[0::3]y = kp[1::3]v = kp[2::3]for sk in sks:if np.all(v[sk]>0):# 画点之间的连接线plt.plot(x[sk],y[sk], linewidth=1, color=c)# 画点p = plt.plot(x[v>0], y[v>0],'o',markersize=4, markerfacecolor=c, markeredgecolor='k',markeredgewidth=1)p = plt.plot(x[v>1], y[v>1],'o',markersize=4, markerfacecolor=c, markeredgecolor=c, markeredgewidth=1)# 画矩形边界,多边形填充+矩形边界:x, y, w, h = bbox[0],bbox[1],bbox[2],bbox[3]ax.add_patch(Polygon(xy=[[x, y], [x, y+h], [x+w, y+h], [x+w, y]], color='k', alpha=0.3))ax.add_patch(Rectangle(xy=(x, y), width=w, height=h, fill=False, color=c, alpha=1))plt.plot(x + w/2,y+h/2,'*',markersize=5, markerfacecolor=c, markeredgecolor=c, markeredgewidth=1)def read_json(json_path):d = defaultdict(list)for json_name in tqdm(os.listdir(json_path)):json_full_path = json_path + '/' + json_namewith open(json_full_path, 'r', encoding='utf-8') as file:  data = json.load(file)image_name = data['image_name']# category = data['category']# width = data['width']# height = data['height']# keypoints = data['key_points']# bbox = data['bbox']image_id = image_name.split('.')[0]d[int(image_id)]= data print(f'json 文件读取完成,一共{len(d.keys())}个数据')return d
def show(data,image_root):for image_id in data.keys():img = io.imread('%s/%s' % (image_root, data[image_id]['image_name']))plt.axis('off')ax = plt.gca()for i in range(len(data[image_id]['key_points'])):visualization(ax, data[image_id]['key_points'][i], data[image_id]['bbox'][i])plt.imshow(img)plt.axis('off')plt.show()
def json2yolo(data,txt_path):if not os.path.exists(txt_path):os.mkdir(txt_path)for image_id in tqdm(data.keys()):txt_name = data[image_id]['image_name'].split('.')[0]category = data[image_id]['category']width = data[image_id]['width']height = data[image_id]['height']keypoints = data[image_id]['key_points']bbox = data[image_id]['bbox']'''yolo标注格式类别 、标注框中心点、长和宽,关键点坐标以及可见度根据图片长高进行归一化处理'''for i in range(len(keypoints)):#对标注框进行预处理box = bbox[i]keypoint = keypoints[i]# print(keypoint)#x,y为左上角坐标,w,h为宽高x,y,w,h = box[0],box[1],box[2],box[3]xx = x + w/2yy = y + h/2#归一化xx,yy,ww,hh = xx/width,yy/height,w/width,h/height#关键点归一化proc_keypoint = []for p in range(0,len(keypoint),3):x,y,v = keypoint[p],keypoint[p+1],keypoint[p+2]kx,ky = x/width, y/heightproc_keypoint.extend([kx,ky,v])#写入txt文件yolo_str = f'{category} {xx:.6f} {yy:.6f} {ww:.6f} {hh:.6f} 'yolo_str = yolo_str + ' '.join([f'{i:.6f}' for i in proc_keypoint])with open(f'{txt_path}/{txt_name}.txt','a+') as f:f.write(yolo_str + '\n')print('转化为yolo txt格式完成!')def split_dataset(image_root,txt_root,img_targe_file,label_target_file,split_ratio=0.6):imgs = os.listdir(image_root)import randomrandom.seed(2024)random.shuffle(imgs)#这里仅仅取了1万张图片进行测试,由于显存限制,ims = imgs[:10000]random.shuffle(ims)train_num = int(len(ims)*split_ratio)val_num = int(len(ims)*(1-split_ratio)/2)train_set = ims[:train_num]val_set = ims[train_num:train_num+val_num]test_set = ims[-val_num:]move_file(train_set,image_root,txt_root,img_targe_file,label_target_file,mode='train')move_file(val_set,image_root,txt_root,img_targe_file,label_target_file,mode='val')move_file(test_set,image_root,txt_root,img_targe_file,label_target_file,mode='test')def move_file(dataset,image_root,txt_root,img_targe_file,label_target_file,mode='train'):for img in tqdm(dataset):img_path = image_root + '/' + imgtxt_path = txt_root + '/' + img.replace('jpg','txt')img_targe_file_full = img_targe_file + '/' + modelabel_target_file_full = label_target_file + '/' + modeshutil.copy(img_path, img_targe_file_full)shutil.copy(txt_path, label_target_file_full)
def validation():images = r'E:\datasets\keypoint_dataset\datasets\images'labels = r'E:\datasets\keypoint_dataset\datasets\labels'for mode in ['train','val','test']:img_path = images + '/' + modelabel_path = labels + '/' + modeassert len(os.listdir(img_path)) == len(os.listdir(label_path)),f'{mode}数据划分验证失败'for img,lb in zip(os.listdir(img_path),os.listdir(label_path)):img_name = img.split('.')[0]lb_name = lb.split('.')[0]if img_name != lb_name:assert '数据划分验证失败'print('数据划分验证成功!')passif __name__ == '__main__':root = 'E:/datasets/keypoint_dataset'image_root = 'E:/datasets/keypoint_dataset/images'json_root = 'E:/datasets/keypoint_dataset/labels'txt_root = root + '/'+ 'txt'#读取json文件data = read_json(json_root)#可视化# show(data,image_root)#json2yolo# json2yolo(data,txt_root)#split dataset# img_target_file = root + '/'+ 'datasets' + '/images'# label_target_file = root + '/'+ 'datasets' + '/labels'# split_dataset(image_root,txt_root,img_target_file,label_target_file,split_ratio=0.7)#validation验证validation()

6、数据集下载链接

数据集下载链接


http://www.ppmy.cn/devtools/91382.html

相关文章

使用 Gunicorn 部署 Flask 项目

使用 Gunicorn 部署 Flask 项目 1. 简介 Flask 自带的 web 服务器仅适用于开发环境,无法满足生产环境的性能需求。在使用 app.run(host0.0.0.0, port5000) 启动时,Flask 会发出警告:WARNING: This is a development server. Do not use it …

c++----初识模板

大家好,这篇博客想与大家分享一些我们c中比较好用的知识点。模板。首先咧,我们都知道模板嘛,就是以前人的经验总结出来的知识。方便我们使用。这里的模板也是一样的。当我们学习过后,对于一些在c中的自定义函数,我们在…

sql 中的group by 与 聚合函数

聚合函数 MAX( )函数取指定字段的最大值; MIN( )函数取指定字段的最小值; SUM( ) 函数对指定字段的值进行求和; COUNT( ) 函数计算某个分组内数据的条数; AVG( ) 函数指定字段的值求平均数。 举例: …

探索人工智能大模型在工业领域的应用与发展

探索人工智能大模型在工业领域的应用与发展 前言测评总结 前言 人工智能大模型在工业领域的应用正逐渐展现出其巨大的潜力。大模型能够在工业知识问答、工程建模、数据分析、文档生成和代码理解等多个场景中发挥重要作用。 例如,在工业知识问答方面,大…

React管理系统整合Cesium避坑指南

花费了一周时间将React 升级到了最新版本18,同时整合Cesium三维模块到系统中,其中遇到了react 版本升级后模块删改,按照原来的引入方式无法使用的问题,以及Cesium 放入子路由一直404等问题 文章目录 一、系统版本依赖二、系统预览…

回归预测|基于雪消融优化极端梯度提升树的数据回归预测Matlab程序SAO-XGBoost多特征输入单输出 含基础模型

回归预测|基于雪消融优化极端梯度提升树的数据回归预测Matlab程序SAO-XGBoost多特征输入单输出 含基础模型 文章目录 前言回归预测|基于雪消融优化极端梯度提升树的数据回归预测Matlab程序SAO-XGBoost多特征输入单输出 含基础模型 一、SAO-XGBoost模型二、实验结果三、核心代码…

2024年AWS云服务器选择哪个区域最好?

在选择2024年AWS云服务器区域时,您需要根据您的业务需求、目标用户群体的位置、数据合规性要求、延迟需求以及成本预算等因素综合考虑。以下是九河云针对不同需求的建议: 北美区域 优势:北美区域,尤其是弗吉尼亚北部&#xff0c…

模拟退火的

题目链接 体验乱调参数而看天意的奇特体验 #include<bits/stdc.h> using namespace std; typedef long long ll; typedef unsigned long long ull; typedef pair<ll,ll> pii; const int inf0x3f3f3f3f; const int N1e510; const int mod1e97; //#define int long…