本文是看了子豪兄视频以后做的笔记,子豪兄视频,子豪兄笔记
MMSegmentation
是语义分割框架,优点是用这一个框架能跑很多模型,且配置统一,一个数据集跑所有算法。
1-标注数据集
这部分看视频即可,重点就是需要转为掩码图像,没有json标注文件,一张图片对应一张mask图,且划分好路径
2-准备数据集配置文件和pipeline
MMSegmentation最重要的就是配置文件,主要分三步
- 在
./mmseg/datasets
里加入一个类,目的是识别目标类别和对应的 RGB配色,指定图像扩展名、标注扩展名,这里叫ZihaoDataset.py
# 同济子豪兄 2023-6-25
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset@DATASETS.register_module()
class ZihaoDataset(BaseSegDataset):# 类别和对应的 RGB配色METAINFO = {'classes':['background', 'red', 'green', 'white', 'seed-black', 'seed-white'],'palette':[[127,127,127], [200,0,0], [0,200,0], [144,238,144], [30,30,30], [251,189,8]]}# 指定图像扩展名、标注扩展名def __init__(self,seg_map_suffix='.png', # 标注mask图像的格式reduce_zero_label=False, # 类别ID为0的类别是否需要除去**kwargs) -> None:super().__init__(seg_map_suffix=seg_map_suffix,reduce_zero_label=reduce_zero_label,**kwargs)
- 在
./mmseg/datasets
的init.py文件中注册
- 在
./config/_base_/datasets
准备pipeline
文件,里面写了数据集路径和整个训练流程
# 数据处理 pipeline
# 同济子豪兄 2023-6-28# 数据集路径
# 定义数据集类型和数据根目录,适配mmsegmentation框架
# dataset_type: 数据集类名
# data_root: 数据集路径,相对于mmsegmentation主目录dataset_type = 'ZihaoDataset' # 数据集类名
data_root = 'Watermelon87_Semantic_Seg_Mask/' # 数据集路径(相对于mmsegmentation主目录)# 输入模型的图像裁剪尺寸
# crop_size: 输入图像的裁剪尺寸,建议选择128的倍数;尺寸越小,显存开销越小
crop_size = (512, 512)# 训练预处理
# train_pipeline: 定义图像数据在训练时的预处理流程
train_pipeline = [dict(type='LoadImageFromFile'), # 从文件中加载图像dict(type='LoadAnnotations'), # 加载标注信息dict(type='RandomResize', # 随机调整图像大小scale=(2048, 1024), # 调整后的最大尺寸ratio_range=(0.5, 2.0), # 缩放比例范围keep_ratio=True # 是否保持宽高比例),dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), # 随机裁剪图像dict(type='RandomFlip', prob=0.5), # 随机水平翻转dict(type='PhotoMetricDistortion'), # 应用光度失真(亮度、对比度等变化)dict(type='PackSegInputs') # 打包为分割输入格式
]# 测试预处理
# test_pipeline: 定义图像数据在测试时的预处理流程
test_pipeline = [dict(type='LoadImageFromFile'), # 从文件中加载图像dict(type='Resize', scale=(2048, 1024), keep_ratio=True), # 调整图像大小并保持比例dict(type='LoadAnnotations'), # 加载标注信息dict(type='PackSegInputs') # 打包为分割输入格式
]# TTA后处理
# tta_pipeline: 定义测试时的多尺度测试增强(Test Time Augmentation)流程
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] # 不同缩放比例# 多尺度增强流程
# TestTimeAug用于生成多尺度和翻转的增强数据
tta_pipeline = [dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),dict(type='TestTimeAug',transforms=[[dict(type='Resize', scale_factor=r, keep_ratio=True) # 不同缩放比例for r in img_ratios],[dict(type='RandomFlip', prob=0., direction='horizontal'), # 不翻转dict(type='RandomFlip', prob=1., direction='horizontal') # 水平翻转],[dict(type='LoadAnnotations')], # 加载标注信息[dict(type='PackSegInputs')] # 打包为分割输入格式])
]# 训练 Dataloader
# 定义训练数据加载器的配置
train_dataloader = dict(batch_size=2, # 每批次的样本数num_workers=2, # 数据加载的线程数persistent_workers=True, # 是否保持工作线程常驻sampler=dict(type='InfiniteSampler', shuffle=True), # 数据采样策略,使用无限采样并打乱顺序dataset=dict(type=dataset_type, # 数据集类型data_root=data_root, # 数据集根目录data_prefix=dict(img_path='img_dir/train', # 训练集图像路径seg_map_path='ann_dir/train' # 训练集标注路径),pipeline=train_pipeline # 数据预处理流水线)
)# 验证 Dataloader
# 定义验证数据加载器的配置
val_dataloader = dict(batch_size=1, # 每批次的样本数num_workers=4, # 数据加载的线程数persistent_workers=True, # 是否保持工作线程常驻sampler=dict(type='DefaultSampler', shuffle=False), # 数据采样策略,使用默认采样且不打乱顺序dataset=dict(type=dataset_type, # 数据集类型data_root=data_root, # 数据集根目录data_prefix=dict(img_path='img_dir/val', # 验证集图像路径seg_map_path='ann_dir/val' # 验证集标注路径),pipeline=test_pipeline # 数据预处理流水线)
)# 测试 Dataloader
# 定义测试数据加载器,直接复用验证数据加载器的配置
test_dataloader = val_dataloader# 验证 Evaluator
# 定义验证指标计算方式
val_evaluator = dict(type='IoUMetric', # 使用交并比(IoU)作为评估指标iou_metrics=['mIoU', 'mDice', 'mFscore'] # 包括平均IoU、Dice系数和F1分数
)# 测试 Evaluator
# 定义测试指标计算方式,复用验证指标配置
test_evaluator = val_evaluator
3-选择语义分割算法,开始训练
原理就是把你选择的算法文件和你的pipeline
文件做个融合,然后调整一些超参.这里随便选择一个,PSPNET
,其他模型大同小异
载入模型config
配置文件然后微调为新的config
# 载入config配置文件
from mmengine import Configcfg = Config.fromfile('./configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py')
dataset_cfg = Config.fromfile('./configs/_base_/datasets/ZihaoDataset_pipeline.py')
cfg.merge_from_dict(dataset_cfg)
# 类别个数
NUM_CLASS = 6
cfg.model.data_preprocessor.size = cfg.crop_size# 单卡训练时,需要把 SyncBN 改成 BN
cfg.norm_cfg = dict(type='BN', requires_grad=True)
cfg.model.backbone.norm_cfg = cfg.norm_cfg
cfg.model.decode_head.norm_cfg = cfg.norm_cfg
cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg# 模型 decode/auxiliary 输出头,指定为类别个数
cfg.model.decode_head.num_classes = NUM_CLASS
cfg.model.auxiliary_head.num_classes = NUM_CLASS# 训练 Batch Size
cfg.train_dataloader.batch_size = 4# 结果保存目录
cfg.work_dir = './work_dirs/ZihaoDataset-PSPNet'# 模型保存与日志记录
cfg.train_cfg.max_iters = 40000 # 训练迭代次数
cfg.train_cfg.val_interval = 500 # 评估模型间隔
cfg.default_hooks.logger.interval = 100 # 日志记录间隔
cfg.default_hooks.checkpoint.interval = 2500 # 模型权重保存间隔
cfg.default_hooks.checkpoint.max_keep_ckpts = 1 # 最多保留几个模型权重
cfg.default_hooks.checkpoint.save_best = 'mIoU' # 保留指标最高的模型权重# 随机数种子
cfg['randomness'] = dict(seed=0)
# 保存最终的config配置文件
cfg.dump('Zihao-Configs/ZihaoDataset_PSPNet_20230818.py')
4-开始训练
!python tools/train.py Zihao-Configs/ZihaoDataset_PSPNet_20230818.py
日志文件会在work_dirs
中,文件夹名字即为你的算法名
5-评测模型
在work_dirs文件夹中找到配置文件和想要评测的pth权重文件即可
!python tools/test.py Zihao-Configs/ZihaoDataset_PSPNet_20230818.py checkpoint/Zihao_PSPNet.pth
特别注意,在同一个benchmark上面对比性能指标,才有意义,简单说就是同一个测试集,如果我们俩测试集不同,那指标没有意义
测试结果解释
指标 | 含义 |
---|---|
IoU (Intersection over Union) | 交并比:衡量预测与真实标签的重叠程度,值越高越好。 |
Acc (Accuracy) | 准确率:模型在该类别上的预测正确率。 |
Dice (Dice Coefficient) | Dice系数:也是衡量预测与真实标签的重叠程度,适用于不均衡数据,值越高越好。 |
Fscore | F1分数:综合了Precision和Recall的调和平均,值越高越好。 |
Precision | 精确率:预测为该类别的样本中实际为该类别的比例,衡量预测结果的准确性。 |
Recall | 召回率:实际为该类别的样本中被正确预测为该类别的比例,衡量模型的覆盖能力。 |
6-用训练得到的模型预测-单张图片
6-1进入目录
import os
os.chdir('mmsegmentation')
6-2导入包
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inlinefrom mmseg.apis import init_model, inference_model, show_result_pyplot
import mmcv
import cv2
6-3选择配置文件和模型并载入
# 模型 config 配置文件
config_file = 'Zihao-Configs/ZihaoDataset_PSPNet_20230818.py'# 模型 checkpoint 权重文件
checkpoint_file = 'checkpoint/Zihao_PSPNet.pth'# device = 'cpu'
device = 'cuda:0'model = init_model(config_file, checkpoint_file, device=device)
6-4选择图片地址 并显示图片
img_path = 'Watermelon87_Semantic_Seg_Mask/img_dir/val/01bd15599c606aa801201794e1fa30.jpg'
img_bgr = cv2.imread(img_path)
plt.figure(figsize=(8, 8))
plt.imshow(img_bgr[:,:,::-1])
plt.show()
6-5推理
result = inference_model(model, img_bgr)
6-6根据dataset配置文件的颜色进行可视化
from mmseg.apis import show_result_pyplot
img_viz = show_result_pyplot(model, img_path, result, opacity=0.8, title='MMSeg', out_file='outputs/K1-4.jpg')
plt.figure(figsize=(14, 8))
plt.imshow(img_viz)
plt.show()