Reid strong baseline 代码详解

news/2024/11/29 20:48:22/

本项目是对Reid strong baseline代码的详解。项目暂未加入目标检测部分,后期会不定时更新,请持续关注。

本相比Reid所用数据集为Markt1501,支持Resnet系列作为训练的baseline网络。训练采用表征学习+度量学习的方式。

目录

训练参数

训练代码

create_supervised_trainer(创建训练函数)

create_supervised_evaluator(创建测试函数)

do_train代码 

训练期权重的保存

获得Loss和acc

获取初始epoch 

学习率的调整

loss和acc的打印

时间函数的打印

测试结果的打印以及权重的保存

完整代码 

测试

Reid相关资料学习链接

项目代码:

后期计划更新


训练参数

--last stride:作为Resnet 最后一层layer的步长,默认为1;

--model_path:预训练权重

--model_name:模型名称,支持Resnet系列【详见readme】

--neck:bnneck

--neck_feat:after

--INPUT_SIZE:[256,128],输入大小

--INPUT_MEAN:

--INPUT_STD:

--PROB: 默认0.5

--padding:默认10

--num_workers:默认4【根据自己电脑配置来】

--DATASET_NAME:markt1501,数据集名称

--DATASET_ROOT_DIR:数据集根目录路径

--SAMPLER: 现仅支持softmax_triplet

--IMS_PER_BATCH:训练时的batch size

--TEST_IMS_PER_BATCH:测试时的batch size

--NUM_INSTANCE:一个batch中每个ID用多少图像,默认为4

--OPTIMIZER_NAME:优化器名称,默认为Adam,支持SGD

--BASE_LR:初始学习率,默认0.00035

--WEIGHT_DECAY:权重衰减

--MARGIN:用于tripletloss,默认0.3

--IF_LABELSMOOTH:标签平滑

--OUTPUT_DIR:权重输出路径

--DEVICE:cuda or cpu

--MAX_EPOCHS:训练迭代次数,默认120

训练代码

def train(args):# 数据集train_loader, val_loader, num_query, num_classes = make_data_loader(args)# modelmodel = build_model(args, num_classes)# 优化器optimizer = make_optimizer(args, model)# lossloss_func = make_loss(args, num_classes)start_epoch = 0scheduler = WarmupMultiStepLR(optimizer, args.STEPS, args.GAMMA, args.WARMUP_FACTOR,args.WARMUP_ITERS, args.WARMUP_METHOD)print('ready train~')do_train(args,model,train_loader,val_loader,optimizer,scheduler,loss_func,num_query,start_epoch)

上述代码中所用处理数据集函数make_data_loader可以参考我另一篇文章:

Reid数据集处理代码详解

在看do_train前需要先看以下内容

log_period表示为打印Log的周期,默认为1;

checkpoint_period:表示为保存权重周期,默认为1;

output_dir:输出路径

device:cuda or cpu

epochs:训练迭代轮数

代码中的create_supervised_trainercreate_supervised_evaluator两个函数,是分别是用来创建监督训练和测试的,是对ignite.engine内训练和测试方法的重写。

create_supervised_trainer(创建训练函数)

规则是在内部实现一个def _update(engine,batch)方法,最后返回Engine(_update)。代码如下。

'''
ignite是一个高级的封装训练和测试库
'''
def create_supervised_trainer(model, optimizer, loss_fn, device=None):""":param model:  (nn.Module) reid model to train:param optimizer:Adam or SGD:param loss_fn: loss function:param device: gpu or cpu:return: Engine"""if device:if torch.cuda.device_count() > 1:model = nn.DataParallel(model)model.to(device)def _update(engine, batch):model.train()optimizer.zero_grad()img, target = batchimg = img.to(device) if torch.cuda.device_count() >= 1 else imgtarget = target.to(device) if torch.cuda.device_count() >= 1 else targetscore, feat = model(img)  # 采用表征+度量loss = loss_fn(score, feat, target)  # 传入三个值,score是fc层后的(hard),feat是池化后的特征,target是标签loss.backward()optimizer.step()# compute accacc = (score.max(1)[1] == target).float().mean()return loss.item(), acc.item()return Engine(_update)

create_supervised_evaluator(创建测试函数)

同理,测试代码也是一样,如下,其中metrics是我们需要计算的评价指标

# 重写create_supervised_evaluator,传入model和metrics,metrics是一个字典用来存储需要度量的指标
def create_supervised_evaluator(model, metrics, device=None):if device:if torch.cuda.device_count() > 1:model = nn.DataParallel(model)model.to(device)def _inference(engine, batch):model.eval()with torch.no_grad():data, pids, camids = batchdata = data.to(device) if torch.cuda.is_available() else datafeat = model(data)return feat, pids, camidsengine = Engine(_inference)for name, metric in metrics.items():metric.attach(engine, name)return engine

do_train代码 

然后看一下do_train中的代码。

训练期权重的保存

这里的trainer就是我们前面创建的监督训练的函数,给该实例添加事件,事件为在每次epoch结束的时候保存一次权重[注意这里保存的权重是将模型的完整结构以及优化器权重都保存下来了]

trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model,'optimizer': optimizer})

获得Loss和acc

# average metric to attach on trainer
RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')

获取初始epoch 

训练前获取开始的epoch,默认为0;

    @trainer.on(Events.STARTED)def start_training(engine):engine.state.epoch = start_epoch

学习率的调整

在训练期间每个epoch开始的时候,会调整学习率

    @trainer.on(Events.EPOCH_STARTED)def adjust_learning_rate(engine):scheduler.step()

loss和acc的打印

该事件发生在每个iteration完成时,而不是epoch完成时。

    @trainer.on(Events.ITERATION_COMPLETED)def log_training_loss(engine):global ITERITER += 1if ITER % log_period == 0:logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}".format(engine.state.epoch, ITER, len(train_loader),engine.state.metrics['avg_loss'], engine.state.metrics['avg_acc'],scheduler.get_lr()[0]))if len(train_loader) == ITER:ITER = 0

时间函数的打印

该函数是用来在每个epoch完成的时候打印一下用了多长时间

    # adding handlers using `trainer.on` decorator API@trainer.on(Events.EPOCH_COMPLETED)def print_times(engine):logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'.format(engine.state.epoch, timer.value() * timer.step_count,train_loader.batch_size / timer.value()))logger.info('-' * 10)timer.reset()

测试结果的打印以及权重的保存

该函数用来打印测试结果,比如mAP,Rank,测试后的权重会保存在logs下。命名形式为mAP_xx.pth。【注意这里保存我权重和上面保存的权重是不一样的,这里仅仅保存权重,不包含网络结构和优化器权重】

     @trainer.on(Events.EPOCH_COMPLETED)def log_validation_results(engine):if engine.state.epoch % eval_period == 0:evaluator.run(val_loader)cmc, mAP = evaluator.state.metrics['r1_mAP']logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))text = "mAP:{:.1%}".format(mAP)# logger.info("mAP: {:.1%}".format(mAP))logger.info(text)for r in [1, 5, 10]:logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))torch.save(state_dict, 'logs/mAP_{:.1%}.pth'.format(mAP))return cmc, mAP

完整代码 

def do_train(cfg,model,train_loader,val_loader,optimizer,scheduler,loss_fn,num_query,start_epoch
):log_period = 1checkpoint_period = 1eval_period = 1output_dir = cfg.OUTPUT_DIRdevice = cfg.DEVICEepochs = cfg.MAX_EPOCHSprint("Start training~")trainer = create_supervised_trainer(model, optimizer, loss_fn, device)evaluator = create_supervised_evaluator(model,metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm='yes')},device=device)checkpointer = ModelCheckpoint(output_dir, cfg.model_name, checkpoint_period, n_saved=10, require_empty=False)state_dict = model.state_dict()timer = Timer(average=True)trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model,'optimizer': optimizer})timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)# average metric to attach on trainerRunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')@trainer.on(Events.STARTED)def start_training(engine):engine.state.epoch = start_epoch@trainer.on(Events.EPOCH_STARTED)def adjust_learning_rate(engine):scheduler.step()@trainer.on(Events.ITERATION_COMPLETED)def log_training_loss(engine):global ITERITER += 1if ITER % log_period == 0:logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}".format(engine.state.epoch, ITER, len(train_loader),engine.state.metrics['avg_loss'], engine.state.metrics['avg_acc'],scheduler.get_lr()[0]))if len(train_loader) == ITER:ITER = 0# adding handlers using `trainer.on` decorator API@trainer.on(Events.EPOCH_COMPLETED)def print_times(engine):logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'.format(engine.state.epoch, timer.value() * timer.step_count,train_loader.batch_size / timer.value()))logger.info('-' * 10)timer.reset()@trainer.on(Events.EPOCH_COMPLETED)def log_validation_results(engine):if engine.state.epoch % eval_period == 0:evaluator.run(val_loader)cmc, mAP = evaluator.state.metrics['r1_mAP']logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))text = "mAP:{:.1%}".format(mAP)# logger.info("mAP: {:.1%}".format(mAP))logger.info(text)for r in [1, 5, 10]:logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))torch.save(state_dict, 'logs/mAP_{:.1%}.pth'.format(mAP))return cmc, mAPtrainer.run(train_loader, max_epochs=epochs)

训练命令如下: 

python tools/train.py --model_name resnet50_ibn_a --model_path weights/ReID_resnet50_ibn_a.pth --IMS_PER_BATCH 8 --TEST_IMS_PER_BATCH 4 --MAX_EPOCHS 120

会出现如下形式: 

 

=> Market1501 loaded
Dataset statistics:----------------------------------------subset   | # ids | # images | # cameras----------------------------------------train    |   751 |    12936 |         6query    |   750 |     3368 |         6gallery  |   751 |    15913 |         6----------------------------------------2023-05-15 14:30:55.603 | INFO     | engine.trainer:log_training_loss:119 - Epoch[1] Iteration[227/1484] Loss: 6.767, Acc: 0.000, Base Lr: 3.82e-05
2023-05-15 14:30:55.774 | INFO     | engine.trainer:log_training_loss:119 - Epoch[1] Iteration[228/1484] Loss: 6.761, Acc: 0.000, Base Lr: 3.82e-05
2023-05-15 14:30:55.946 | INFO     | engine.trainer:log_training_loss:119 - Epoch[1] Iteration[229/1484] Loss: 6.757, Acc: 0.000, Base Lr: 3.82e-05
2023-05-15 14:30:56.134 | INFO     | engine.trainer:log_training_loss:119 - Epoch[1] Iteration[230/1484] Loss: 6.760, Acc: 0.000, Base Lr: 3.82e-05
2023-05-15 14:30:56.305 | INFO     | engine.trainer:log_training_loss:119 - Epoch[1] Iteration[231/1484] Loss: 6.764, Acc: 0.000, Base Lr: 3.82e-05

每个epoch训练完成后会测试一次mAP:

我这里第一个epoch的mAP达到75.1%,Rank-1:91.7%, Rank-5:97.2%, Rank-10:98.2%。

测试完成后会在log文件下保存一个pth权重,名称为mAPxx.pth,也是用该权重进行测试。

 

2023-05-15 14:35:59.753 | INFO     | engine.trainer:print_times:128 - Epoch 1 done. Time per batch: 261.820[s] Speed: 45.4[samples/s]
2023-05-15 14:35:59.755 | INFO     | engine.trainer:print_times:129 - ----------
The test feature is normalized
2023-05-15 14:39:51.025 | INFO     | engine.trainer:log_validation_results:137 - Validation Results - Epoch: 1
2023-05-15 14:39:51.048 | INFO     | engine.trainer:log_validation_results:140 - mAP:75.1%
2023-05-15 14:39:51.051 | INFO     | engine.trainer:log_validation_results:142 - CMC curve, Rank-1  :91.7%
2023-05-15 14:39:51.051 | INFO     | engine.trainer:log_validation_results:142 - CMC curve, Rank-5  :97.2%
2023-05-15 14:39:51.052 | INFO     | engine.trainer:log_validation_results:142 - CMC curve, Rank-10 :98.2%

测试

测试代码在tools/test.py中,代码和train.py差不多,这里不再细说,该代码是可对评价指标进行测试复现。

命令如下:其中TEST_IMS_PER_BATCH是测试时候的batch size,model_name是网络名称,model_path是你训练好的权重路径。

python tools/test.py --TEST_IMS_PER_BATCH 4 --model_name [your model name] --model_path [your weight path]

Reid相关资料学习链接

Reid损失函数理论讲解:Reid之损失函数理论学习讲解_爱吃肉的鹏的博客-CSDN博客

Reid度量学习Triplet loss代码讲解:Reid度量学习Triplet loss代码解析。_爱吃肉的鹏的博客-CSDN博客

yolov5 reid项目(支持跨视频检索):yolov5_reid【附代码,行人重识别,可做跨视频人员检测】_yolov5行人重识别_爱吃肉的鹏的博客-CSDN博客

yolov3 reid项目(支持跨视频检索):ReID行人重识别(训练+检测,附代码),可做图像检索,陌生人检索等项目_爱吃肉的鹏的博客-CSDN博客


预权重链接:

链接:百度网盘 请输入提取码 提取码:yypn


项目代码:

GitHub - YINYIPENG-EN/reid_strong_baselineContribute to YINYIPENG-EN/reid_strong_baseline development by creating an account on GitHub.https://github.com/YINYIPENG-EN/reid_strong_baseline

后期计划更新

​ 1.引入知识蒸馏训练

​ 2.加入YOLOX进行跨视频检测


http://www.ppmy.cn/news/88967.html

相关文章

URLConnection(三)

文章目录 1. 配置连接2. protected URL url3. protected boolean connected4. protected boolean allowUserInteraction5. protected boolean doInput5. protected boolean doOutput6. protected boolean isModifiedSince7. protected boolean useCaches8. 超时 1. 配置连接 U…

[PyTorch][chapter 35][经典卷积神经网络-1 ]

前言&#xff1a; ILSVRC&#xff08;ImageNet Large Scale Visual Recognition Challenge&#xff09;是近年来机器视觉领域最受追捧也是最具权威的学术竞赛之一&#xff0c;代表了图像领域的最高水平。 ImageNet数据集是ILSVRC竞赛使用的是数据集&#xff0c;由斯坦福大学李…

手机号码在网时长 API 实现广告投放和精准营销案例分析

引言 手机在网时长是指用户在移动网络上的在线时间&#xff0c;包括用户接入网络的时间和断开网络的时间。手机在网时长 API 是一种提供手机在网时长数据的编程接口&#xff0c;为开发者和服务提供商提供了获取和利用这些数据的能力。 本文旨在深入探讨手机在网时长 API 的技…

ElasticSearch 快速上手教程(一)—— ES 的安装

简介 ElasticSearch 是一个开源的搜索引擎&#xff0c;基于 Lucene 开发与构建&#xff0c;是当前流行的企业级搜索引擎&#xff0c;在许多应用场景当中都有使用&#xff0c;如商品全文检索&#xff0c;书籍关键字查询等。在这个系列的文章&#xff0c;会带你从零到一&#xf…

DATAV通过配置nginx代理实现https访问

DATAV通过配置nginx代理实现https访问 首先要确保你的 datav 和 datav_proxy 的界面能用http正常访问 在nginx中添加datav配置 server {listen 8181 ssl;server_name localhost;ssl_certificate server.crt;ssl_certificate_key server.key;ssl_session_cache …

RestCloud荣膺广东省优秀软件产品奖,引领国内数据集成领域!

近日&#xff0c;“2022年广东软件风云榜”名单公布&#xff0c;“谷云ETL数据交换软件”凭借其在助力企业数字化转型升级过程中的卓越表现&#xff0c;荣获由羊城晚报报业集团、广东软件行业协会、广东省大数据协会联合颁发的“优秀软件产品和解决方案”奖。 数字化转型是推动…

04_GC垃圾回收

面试题&#xff1a; JVM内存模型以及分区&#xff0c;需要详细到每个区放什么 堆里面的分区&#xff1a;Eden&#xff0c;survival from to&#xff0c;老年代&#xff0c;各自的特点。 GC的三种收集方法&#xff1a;标记清除、标记整理、复制算法的原理与特点&#xff0c;分…

如何在华为OD机试中获得满分?Java实现【寻找相似单词】一文详解!

✅创作者&#xff1a;陈书予 &#x1f389;个人主页&#xff1a;陈书予的个人主页 &#x1f341;陈书予的个人社区&#xff0c;欢迎你的加入: 陈书予的社区 &#x1f31f;专栏地址: Java华为OD机试真题&#xff08;2022&2023) 文章目录 1. 题目描述2. 输入描述3. 输出描述…