YOLOv5 分类模型 数据集加载 3

news/2025/1/12 21:37:55/

YOLOv5 分类模型 数据集加载 3 自定义类别

flyfish

YOLOv5 分类模型 数据集加载 1 样本处理
YOLOv5 分类模型 数据集加载 2 切片处理
YOLOv5 分类模型的预处理(1) Resize 和 CenterCrop
YOLOv5 分类模型的预处理(2)ToTensor 和 Normalize
YOLOv5 分类模型 Top 1和Top 5 指标说明
YOLOv5 分类模型 Top 1和Top 5 指标实现

之前的处理方式是类别名字是文件夹名字,类别ID是按照文件夹名字的字母顺序
现在是类别名字是文件夹名字,按照文件列表名字顺序 例如

classes_name=['n02086240', 'n02087394', 'n02088364', 'n02089973', 'n02093754', 
'n02096294', 'n02099601', 'n02105641', 'n02111889', 'n02115641']

n02086240 类别ID是0
n02087394 类别ID是1
代码处理是

if classes_name is None or not classes_name:classes, class_to_idx = self.find_classes(self.root)print("not classes_name")else:classes = classes_nameclass_to_idx ={cls_name: i for i, cls_name in enumerate(classes)}print("is classes_name")

完整

import time
from models.common import DetectMultiBackend
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
import cv2
import numpy as npimport torch
from PIL import Image
import torchvision.transforms as transformsimport sysclasses_name=['n02086240', 'n02087394', 'n02088364', 'n02089973', 'n02093754', 'n02096294', 'n02099601', 'n02105641', 'n02111889', 'n02115641']class DatasetFolder:def __init__(self,root: str,) -> None:self.root = rootif classes_name is None or not classes_name:classes, class_to_idx = self.find_classes(self.root)print("not classes_name")else:classes = classes_nameclass_to_idx ={cls_name: i for i, cls_name in enumerate(classes)}print("is classes_name")print("classes:",classes)print("class_to_idx:",class_to_idx)samples = self.make_dataset(self.root, class_to_idx)self.classes = classesself.class_to_idx = class_to_idxself.samples = samplesself.targets = [s[1] for s in samples]@staticmethoddef make_dataset(directory: str,class_to_idx: Optional[Dict[str, int]] = None,) -> List[Tuple[str, int]]:directory = os.path.expanduser(directory)if class_to_idx is None:_, class_to_idx = self.find_classes(directory)elif not class_to_idx:raise ValueError("'class_to_index' must have at least one entry to collect any samples.")instances = []available_classes = set()for target_class in sorted(class_to_idx.keys()):class_index = class_to_idx[target_class]target_dir = os.path.join(directory, target_class)if not os.path.isdir(target_dir):continuefor root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):for fname in sorted(fnames):path = os.path.join(root, fname)if 1:  # 验证:item = path, class_indexinstances.append(item)if target_class not in available_classes:available_classes.add(target_class)empty_classes = set(class_to_idx.keys()) - available_classesif empty_classes:msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "return instancesdef find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())if not classes:raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}return classes, class_to_idxdef __getitem__(self, index: int) -> Tuple[Any, Any]:path, target = self.samples[index]sample = self.loader(path)return sample, targetdef __len__(self) -> int:return len(self.samples)def loader(self, path):print("path:", path)#img = cv2.imread(path)  # BGR HWCimg=Image.open(path).convert("RGB") # RGB HWCreturn imgdef time_sync():return time.time()#sys.exit() 
dataset = DatasetFolder(root="/media/a/flyfish/source/yolov5/datasets/imagewoof/val")#image, label=dataset[7]#
weights = "/home/a/classes.pt"
device = "cpu"
model = DetectMultiBackend(weights, device=device, dnn=False, fp16=False)
model.eval()
print(model.names)
print(type(model.names))mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]
def preprocess(images):#实现 PyTorch Resizetarget_size =224img_w = images.widthimg_h = images.heightif(img_h >= img_w):# hwresize_img = images.resize((target_size, int(target_size * img_h / img_w)), Image.BILINEAR)else:resize_img = images.resize((int(target_size * img_w  / img_h),target_size), Image.BILINEAR)#实现 PyTorch CenterCropwidth = resize_img.widthheight = resize_img.heightcenter_x,center_y = width//2,height//2left = center_x - (target_size//2)top = center_y- (target_size//2)right =center_x +target_size//2bottom = center_y+target_size//2cropped_img = resize_img.crop((left, top, right, bottom))#实现 PyTorch ToTensor Normalizeimages = np.asarray(cropped_img)print("preprocess:",images.shape)images = images.astype('float32')images = (images/255-mean)/stdimages = images.transpose((2, 0, 1))# HWC to CHWprint("preprocess:",images.shape)images = np.ascontiguousarray(images)images=torch.from_numpy(images)#images = images.unsqueeze(dim=0).float()return imagespred, targets, loss, dt = [], [], 0, [0.0, 0.0, 0.0]
# current batch size =1
for i, (images, labels) in enumerate(dataset):print("i:", i)im = preprocess(images)images = im.unsqueeze(0).to("cpu").float()print(images.shape)t1 = time_sync()images = images.to(device, non_blocking=True)t2 = time_sync()# dt[0] += t2 - t1y = model(images)y=y.numpy()#print("y:", y)t3 = time_sync()# dt[1] += t3 - t2#batch size >1 图像推理结果是二维的#y: [[     4.0855     -1.1707     -1.4998      -0.935     -1.9979      -2.258     -1.4691     -1.0867     -1.9042    -0.99979]]tmp1=y.argsort()[:,::-1][:, :5]#batch size =1 图像推理结果是一维的, 就要处理下argsort的维度#y: [     3.7441      -1.135     -1.1293     -0.9422     -1.6029     -2.0561      -1.025     -1.5842     -1.3952     -1.1824]#print("tmp1:", tmp1)pred.append(tmp1)#print("labels:", labels)targets.append(labels)#print("for pred:", pred)  # list#print("for targets:", targets)  # list# dt[2] += time_sync() - t3pred, targets = np.concatenate(pred), np.array(targets)
print("pred:", pred)
print("pred:", pred.shape)
print("targets:", targets)
print("targets:", targets.shape)
correct = ((targets[:, None] == pred)).astype(np.float32)
print("correct:", correct.shape)
print("correct:", correct)
acc = np.stack((correct[:, 0], correct.max(1)), axis=1)  # (top1, top5) accuracy
print("acc:", acc.shape)
print("acc:", acc)
top = acc.mean(0)
print("top1:", top[0])
print("top5:", top[1])

http://www.ppmy.cn/news/1236880.html

相关文章

火电厂电气部分设计

摘要 本文首先根据任务书上所给系统与线路及所有负荷的参数,分析负荷发展趋势。从负荷增长方面阐明了建站的必要性,然后通过对拟建变电站的概括以及出线方向来考虑,并通过对负荷资料的分析,安全,经济及可靠性方面考虑…

C++二分向量算法:最多可以参加的会议数目 II

本题的其它解法 C二分算法:最多可以参加的会议数目 II 本文涉及的基础知识点 二分查找算法合集 题目 给你一个 events 数组,其中 events[i] [startDayi, endDayi, valuei] ,表示第 i 个会议在 startDayi 天开始,第 endDayi …

重新使用hbase前

启动关闭Hadoop和HBase的顺序一定是: 启动Hadoop—>启动HBase—>关闭HBase—>关闭Hadoop 1.挂载共享文件夹到挂载点 sudo mount -t vboxsf virtualmachineShare /mnt/shared2.进入hadoop目录下启动hadoop cd /usr/local/hadoop/ ./sbin/start-all.sh …

windows11上enable WSL

Windows电脑上要配置linux(这里指ubuntu)开发环境,主要有三种方式: 1)在windows上装个虚拟机(比如vmware)。缺点是vmware加载ubuntu后系统会变慢很多,而且需要通过samba来实现window…

【ARM 嵌入式 编译系列 2.3 -- GCC 中指定 ARMv8-M 的 Thumb 指令集参数详细介绍】

请阅读【ARM GCC 编译专栏导读】 上篇文章:【ARM 嵌入式 编译系列 2.2 – 如何在Makefile 中添加编译时间 | 编译作者| 编译 git id】 下篇文章:【ARM 嵌入式 C 入门及渐进 3 – GCC attribute((weak)) 弱符号使用】 文章目录 ARMv8-M 架构Thumb 指令集ARMv8-M 与 Thumb-mth…

机器学习笔记 - 复杂任务的CNN组合

基础CNN架构可通过多种方式进行组合和扩展,从而解决更多、更复杂的任务。 1. 分类和定位 在分类和定位任务中,你不仅需要说出在图像中找到的物体的类别,而且还需指出物体显现在图像中的边界框坐标。这类任务假设在图像中只有一个物体实例。 这个任务可通过在典型的分类网络…

Jmeter+influxdb+grafana监控平台在windows环境的搭建

原理:Jmeter采集的数据存储在infuxdb数据库中,grafana将数据库中的数据在界面上进行展示 一、grafana下载安装 Download Grafana | Grafana Labs 直接选择zip包下载,下载后解压即可,我之前下载过比较老的版本,这里就…

C语言第二十五弹--打印菱形

C语言打印菱形 思路&#xff1a;想要打印一个菱形&#xff0c;可以分为上下两部分&#xff0c;通过观察可以发现上半部分星号的规律是 1 3 5 7故理解为 2对应行数 1 &#xff0c;空格是4 3 2 1故理解为 行数-对应行数-1。 上半部分代码如下 for (int i 0;i < line;i){//上…