深度学习代码分析——自用

embedded/2025/3/5 11:28:59/

代码来自:https://github.com/ChuHan89/WSSS-Tissue?tab=readme-ov-file

借助了一些人工智能

1_train_stage1.py

代码功能总览

该代码是弱监督语义分割(WSSS)流程的 Stage1 训练与测试脚本,核心任务是通过 多标签分类模型 生成图像级标签,为后续生成伪掩码(Pseudo-Masks)提供基础。代码分为 train_phase 和 test_phase 两个阶段,支持 渐进式Dropout注意力(PDA) 和 Visdom可视化监控

1. 依赖库导入

import os
import numpy as np
import argparse
import importlib
from visdom import Visdom  # 可视化工具import torch
import torch.nn.functional as F
from torch.backends import cudnn  # CUDA加速
from torch.utils.data import DataLoader
from torchvision import transforms  # 数据预处理
from tool import pyutils, torchutils  # 自定义工具包
from tool.GenDataset import Stage1_TrainDataset  # 自定义数据集类
from tool.infer_fun import infer  # 测试阶段推理函数cudnn.enabled = True  # 启用CUDA加速(自动优化卷积算法)
  • 关键细节

    • cudnn.enabled=True:启用cuDNN加速,自动选择最优卷积实现。

    • pyutils 和 torchutils:项目自定义工具模块(包含优化器、计时器等)。

    • Visdom:用于实时可视化训练过程中的损失和准确率曲线。

2. 辅助函数 compute_acc

def compute_acc(pred_labels, gt_labels):pred_correct_count = 0for pred_label in pred_labels:  # 遍历预测标签if pred_label in gt_labels:  # 判断是否在真实标签中pred_correct_count += 1union = len(gt_labels) + len(pred_labels) - pred_correct_count  # 并集大小acc = round(pred_correct_count/union, 4)  # 交并比(IoU)式准确率return acc
  • 功能:计算预测标签与真实标签的 交并比准确率(IoU-like Accuracy)。

  • 数学公式

    Acc=预测正确的标签数预测标签数+真实标签数−预测正确的标签数Acc=预测标签数+真实标签数−预测正确的标签数预测正确的标签数​
  • 示例

    • 预测标签:[0, 2],真实标签:[2, 3]

    • 正确数:1(标签2),并集:2 + 2 - 1 = 3 → Acc = 1/3 ≈ 0.333

3. 训练阶段 train_phase

3.1 初始化与模型加载
def train_phase(args):viz = Visdom(env=args.env_name)  # 创建Visdom环境(用于可视化)model = getattr(importlib.import_module(args.network), 'Net')(args.init_gama, n_class=args.n_class)print(vars(args))  # 打印所有输入参数
  • 关键细节

    • 动态模型加载:通过 importlib 从字符串 args.network(如 "network.resnet38_cls")动态加载模型类 Net

    • PDA参数args.init_gama 控制渐进式Dropout注意力的初始强度(值越大,注意力区域越集中)。

    • Visdom环境:通过 env=args.env_name 隔离不同实验的可视化结果。

3.2 数据增强与加载
    transform_train = transforms.Compose([transforms.RandomHorizontalFlip(p=0.5),  # 50%概率水平翻转transforms.RandomVerticalFlip(p=0.5),    # 50%概率垂直翻转transforms.ToTensor()                    # 转为Tensor(范围[0,1])]) train_dataset = Stage1_TrainDataset(data_path=args.trainroot,  # 训练集路径(如'datasets/BCSS-WSSS/train/')transform=transform_train, dataset=args.dataset       # 数据集标识(如'bcss'))train_data_loader = DataLoader(train_dataset,batch_size=args.batch_size,  # 批大小(默认20)shuffle=True,                # 打乱数据顺序num_workers=args.num_workers,  # 数据加载子进程数(默认10)pin_memory=False,             # 不锁页内存(适用于小批量数据)drop_last=True                # 丢弃最后不足一个batch的数据)
  • 关键细节

    • 数据增强策略:仅使用翻转操作,避免复杂变换干扰分类模型的学习。

    • 自定义数据集类Stage1_TrainDataset 需实现图像和标签的加载逻辑(如解析XML或CSV文件)。

3.3 优化器配置
    max_step = (len(train_dataset) // args.batch_size) * args.max_epoches  # 总迭代次数param_groups = model.get_parameter_groups()  # 获取模型参数分组(通常按网络层分组)optimizer = torchutils.PolyOptimizer([{'params': param_groups[0], 'lr': args.lr, 'weight_decay': args.wt_dec},  # 主干网络(低学习率){'params': param_groups[1], 'lr': 2*args.lr, 'weight_decay': 0},         # 中间层(较高学习率){'params': param_groups[2], 'lr': 10*args.lr, 'weight_decay': args.wt_dec},  # 分类头(高学习率){'params': param_groups[3], 'lr': 20*args.lr, 'weight_decay': 0}          # 特殊模块(最高学习率)], lr=args.lr, weight_decay=args.wt_dec, max_step=max_step  # 控制学习率衰减)
  • 关键细节

    • 参数分组:不同网络层(如ResNet38的卷积层、全连接层)使用不同的学习率,分类头通常需要更高学习率以快速适应新任务。

    • Poly学习率衰减:学习率按公式 lr=base_lr×(1−stepmax_step)powerlr=base_lr×(1−max_stepstep​)power 衰减,默认 power=0.9

3.4 加载预训练权重
    if args.weights[-7:] == '.params':  # MXNet格式权重(如'init_weights/ilsvrc-cls_rna-a1_cls1000_ep-0001.params')import network.resnet38dweights_dict = network.resnet38d.convert_mxnet_to_torch(args.weights)  # 转换权重格式model.load_state_dict(weights_dict, strict=False)  # 非严格加载(允许部分参数不匹配)elif args.weights[-4:] == '.pth':   # PyTorch格式权重weights_dict = torch.load(args.weights)model.load_state_dict(weights_dict, strict=False)else:print('random init')  # 随机初始化(无预训练)
  • 关键细节

    • MXNet转换:项目可能基于早期MXNet实现,需将预训练权重转换为PyTorch格式。

    • strict=False:允许模型结构与权重文件部分不匹配(如分类头维度不同)。

3.5 训练循环
    model = model.cuda()  # 将模型移至GPUavg_meter = pyutils.AverageMeter('loss', 'avg_ep_EM', 'avg_ep_acc')  # 统计训练指标timer = pyutils.Timer("Session started: ")  # 计时器(计算剩余时间)for ep in range(args.max_epoches):  # 遍历每个epochmodel.train()args.ep_index = ep  # 当前epoch索引(可能用于回调)ep_count = 0        # 当前epoch累计样本数ep_EM = 0           # 完全匹配(Exact Match)次数ep_acc = 0           # 累计准确率for iter, (filename, data, label) in enumerate(train_data_loader):  # 遍历每个batchimg = data  # 图像数据(未使用filename)label = label.cuda(non_blocking=True)  # 标签移至GPU(异步传输)# 控制PDA的启用(前3个epoch禁用)enable_PDA = 1 if ep > 2 else 0# 前向传播(返回分类输出、特征图、概率)x, feature, y = model(img.cuda(), enable_PDA)# 转换为CPU numpy数组以计算指标prob = y.cpu().data.numpy()  # 预测概率(shape=[batch_size, n_class])gt = label.cpu().data.numpy()  # 真实标签(shape=[batch_size, n_class])# 遍历batch内每个样本计算指标for num, one in enumerate(prob):ep_count += 1pass_cls = np.where(one > 0.5)[0]  # 预测标签(概率>0.5的类别)true_cls = np.where(gt[num] == 1)[0]  # 真实标签(one-hot编码中为1的类别)# 统计Exact Match(完全匹配)if np.array_equal(pass_cls, true_cls):ep_EM += 1# 计算交并比式准确率acc = compute_acc(pass_cls, true_cls)ep_acc += acc# 计算当前batch的平均指标avg_ep_EM = round(ep_EM / ep_count, 4)avg_ep_acc = round(ep_acc / ep_count, 4)# 计算多标签分类损失loss = F.multilabel_soft_margin_loss(x, label)  # x为模型原始输出(未经过sigmoid)# 更新统计指标avg_meter.add({'loss': loss.item(),'avg_ep_EM': avg_ep_EM,'avg_ep_acc': avg_ep_acc})# 反向传播与优化optimizer.zero_grad()  # 清空梯度loss.backward()        # 计算梯度optimizer.step()       # 更新参数torch.cuda.empty_cache()  # 清理GPU缓存(防止内存泄漏)# 每100步打印日志并更新Visdomif (optimizer.global_step) % 100 == 0 and (optimizer.global_step) != 0:timer.update_progress(optimizer.global_step / max_step)  # 更新剩余时间估计print('Epoch:%2d' % (ep),'Iter:%5d/%5d' % (optimizer.global_step, max_step),'Loss:%.4f' % (avg_meter.get('loss')),'avg_ep_EM:%.4f' % (avg_meter.get('avg_ep_EM')),'avg_ep_acc:%.4f' % (avg_meter.get('avg_ep_acc')),'lr: %.4f' % (optimizer.param_groups[0]['lr']), 'Fin:%s' % (timer.str_est_finish()),flush=True)# 更新Visdom图表viz.line([avg_meter.pop('loss')],[optimizer.global_step],win='loss',update='append',opts=dict(title='loss'))# 同理更新 'Acc_exact' 和 'Acc' 图表...# 每epoch后调整PDA的gama参数if model.gama > 0.65:model.gama = model.gama * 0.98  # 逐步衰减注意力强度print('Gama of progressive dropout attention is: ', model.gama)# 保存最终模型torch.save(model.state_dict(), os.path.join(args.save_folder, 'stage1_checkpoint_trained_on_'+args.dataset+'.pth'))
  • 关键细节

    • 渐进式Dropout注意力(PDA)

      • 前3个epoch禁用(enable_PDA=0),让模型初步学习基础特征。

      • gama 初始值为1,逐渐衰减(gama *= 0.98),控制注意力区域的聚焦程度。

    • 损失函数F.multilabel_soft_margin_loss 结合Sigmoid和交叉熵,适用于多标签分类。

    • 指标计算

      • Exact Match (EM):预测标签与真实标签完全一致的样本比例(严格指标)。

      • IoU式准确率:反映预测与真实标签的重合程度(宽松指标)。

    • Visdom集成:实时可视化损失和准确率曲线,便于监控训练状态。

4. 测试阶段 test_phase

def test_phase(args):# 加载生成CAM的模型变体(Net_CAM)model = getattr(importlib.import_module(args.network), 'Net_CAM')(n_class=args.n_class)model = model.cuda()# 加载训练阶段保存的权重args.weights = os.path.join(args.save_folder, 'stage1_checkpoint_trained_on_'+args.dataset+'.pth')weights_dict = torch.load(args.weights)model.load_state_dict(weights_dict, strict=False)model.eval()  # 设置为评估模式(禁用Dropout和BatchNorm的随机性)# 调用自定义推理函数(评估模型在测试集上的性能)score = infer(model, args.testroot, args.n_class)print(score)  # 输出评估结果(如mAP、IoU等)# 可选:保存最终模型(可能包含CAM生成能力)torch.save(model.state_dict(), ...)
  • 关键细节

    • 模型变体Net_CAM 可能修改了网络结构以输出类别激活图(Class Activation Map)。

    • 评估指标infer 函数内部可能计算mAP(平均精度)、像素级IoU等指标。

    • 严格模式strict=False 允许加载部分权重(如分类头维度不同)。

5. 主函数与参数解析

if __name__ == '__main__':parser = argparse.ArgumentParser()# 训练参数parser.add_argument("--batch_size", default=20, type=int)parser.add_argument("--max_epoches", default=20, type=int)parser.add_argument("--network", default="network.resnet38_cls", type=str)parser.add_argument("--lr", default=0.01, type=float)parser.add_argument("--num_workers", default=10, type=int)parser.add_argument("--wt_dec", default=5e-4, type=float)  # 权重衰减(L2正则化)# 实验命名与可视化parser.add_argument("--session_name", default="Stage 1", type=str)  # 实验名称(日志标识)parser.add_argument("--env_name", default="PDA", type=str)          # Visdom环境名parser.add_argument("--model_name", default='PDA', type=str)        # 模型保存名称# 数据集与模型结构parser.add_argument("--n_class", default=4, type=int)               # 类别数(如BCSS为4类)parser.add_argument("--weights", default='init_weights/ilsvrc-cls_rna-a1_cls1000_ep-0001.params', type=str)parser.add_argument("--trainroot", default='datasets/BCSS-WSSS/train/', type=str)parser.add_argument("--testroot", default='datasets/BCSS-WSSS/test/', type=str)parser.add_argument("--save_folder", default='checkpoints/', type=str)# PDA参数parser.add_argument("--init_gama", default=1, type=float)  # 初始注意力强度# 数据集标识parser.add_argument("--dataset", default='bcss', type=str)  # 数据集缩写(影响保存文件名)args = parser.parse_args()train_phase(args)  # 执行训练test_phase(args)   # 执行测试
  • 关键参数说明

    • --network:模型定义文件路径(如 network.resnet38_cls 对应 network/resnet38_cls.py)。

    • --init_gama:PDA的初始强度,影响注意力机制的随机丢弃率。

    • --weights:预训练权重路径(支持MXNet和PyTorch格式)。


http://www.ppmy.cn/embedded/170174.html

相关文章

【数据挖掘]Ndarray数组的创建

在 NumPy 中,ndarray(N-dimensional array)是最核心的数据结构,创建 ndarray 数组的方式有多种,主要包括以下几类: 目录 1. 通过列表或元组创建 2. 使用 NumPy 内置的创建函数 (1&#xff0…

详解DeepSeek模型底层原理及和ChatGPT区别点

一、DeepSeek大模型原理 架构基础 DeepSeek基于Transformer架构,Transformer架构主要由编码器和解码器组成,在自然语言处理任务中,通常使用的是Transformer的解码器部分。它的核心是自注意力机制(Self - Attention),这个机制允许模型在处理输入序列时,关注序列中不同位…

DeepSeek集成到VScode工具,让编程更高效

DeepSeek与VScode的强强联合,为编程效率树立了新标杆。 DeepSeek,一款卓越的代码搜索引擎,以其精准的索引和高速的检索能力,助力开发者在浩瀚的代码海洋中迅速定位关键信息。 集成至VScode后,开发者无需离开熟悉的编辑…

Excel文件中物件PPT文档如何保存到本地

以下是Excel中嵌入的PPT文档保存到本地的详细方法,综合了多种适用场景的解决方案: 方法一:直接通过对象功能另存为 定位嵌入的PPT对象 在Excel中双击打开嵌入的PPT文档,进入编辑模式后,右键点击PPT对象边框&#xff0…

【分布式】Hadoop完全分布式的搭建(零基础)

Hadoop完全分布式的搭建 环境准备: (1)VMware Workstation Pro17(其他也可) (2)Centos7 (3)FinalShell (一)模型机配置 0****)安…

【数据分析】上市公司市场势力数据测算+dofile(1992-2023年)

市场势力通常指的是公司在市场中的相对竞争力和定价能力。具有较强市场势力的公司通常能够控制价格、影响市场规则,并在竞争中占据主导地位。A股公司市场势力数据是对中国资本市场中公司竞争力的深入分析,A股市场中,公司市场势力的强弱不仅影…

HarmonyOS学习第13天:布局进阶,从嵌套到优化

布局嵌套初体验 在 HarmonyOS 应用开发中,布局嵌套是构建复杂界面的重要手段。就像搭建一座高楼,布局嵌套能让各个界面元素有序组合,构建出功能丰富、层次分明的用户界面。我们以日常使用的电商 APP 为例,在商品展示区&#xff0c…

软件工程应试复习(考试折磨版)

针对学校软件工程考试,参考教材《软件工程导论(第6版)》1-8章 学习的艺术:不断地尝试,我一定会找到高效用的方法,让学习变成一门艺术,从应试备考中解救出我的时间同胞们。 好嘞!既然…