剪枝与重参第七课:YOLOv8剪枝

news/2024/12/21 22:23:23/

目录

  • YOLOv8剪枝
    • 前言
    • 1.Overview
    • 2.Pretrain(option)
    • 3.Constrained Training
    • 4.Prune
      • 4.1 检查BN层的bias
      • 4.2 设置阈值和剪枝率
      • 4.3 最小剪枝Conv单元的TopConv
      • 4.4 最小剪枝Conv单元的BottomConv
      • 4.5 Seq剪枝
      • 4.6 Detect-FPN剪枝
      • 4.7 完整示例代码
    • 5.YOLOv8剪枝总结
    • 总结

YOLOv8剪枝

前言

手写AI推出的全新模型剪枝与重参课程。记录下个人学习笔记,仅供自己参考。

本次课程主要讲解YOLOv8剪枝。

课程大纲可看下面的思维导图

在这里插入图片描述

1.Overview

YOLOV8剪枝的流程如下:

在这里插入图片描述

结论:在VOC2007上使用yolov8s模型进行的实验显示,预训练和约束训练在迭代50个epoch后达到了相同的mAP(:0.5)值,约为0.77。剪枝后,微调阶段需要65个epoch才能达到相同的mAP50。修建后的ONNX模型大小从43M减少到36M。

注意:我们需要将网络结构和网络权重区分开来,YOLOv8的网络结构来自yaml文件,如果我们进行剪枝后保存的权重文件的结构其实是和原始的yaml文件不符合的,需要对yaml文件进行修改满足我们的要求。

2.Pretrain(option)

步骤如下:

  • git clone https://github.com/ultralytics/ultralytics.git
  • use VOC2007, and modify the VOC.yaml(去除VOC2012的相关内容)
  • disable amp(禁用amp混合精度)
# FILE: ultralytics/yolo/engine/trainer.py
...
def check_amp(model):# Avoid using mixed precision to affect finetunereturn False # <============== modified(修改部分)device = next(model.parameters()).device  # get model deviceif device.type in ('cpu', 'mps'):return False  # AMP only used on CUDA devicesdef amp_allclose(m, im):# All close FP32 vs AMP results...

3.Constrained Training

约束训练是为了筛选哪些channel比较重要,哪些channel没有那么重要,也就是我们上节课所说的稀疏训练

  • prune the BN layer by adding L1 regularizer.
# FILE: ultralytics/yolo/engine/trainer.py
...
# Backward
self.scaler.scale(self.loss).backward()# <============ added(新增)
l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)
for k, m in self.model.named_modules():if isinstance(m, nn.BatchNorm2d):m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))# Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
if ni - last_opt_step >= self.accumulate:self.optimizer_step()last_opt_step = ni
...

注意1:在剪枝时,我们选择加载last.pt而非best.pt,因为由于迁移学习,模型的泛化性比较好,在第一个epoch时mAP值最大,但这并不是真实的,我们需要稳定下来的一个模型进行prune

注意2:我们在对Conv层进行剪枝时,我们只考虑1v1(如BottleNeck,C2f and SPPF)、1vm(如Backbone,Detect)的情形,并不考虑mv1的情形。

思考:Constrained Training的必要性?

约束训练可以使得模型更易于剪枝。在约束训练中,模型会学习到一些通道或者权重系数比较不重要的信息,而这些信息在剪枝过程中得到应用,从而达到模型压缩的效果。而如果直接进行剪枝操作,可能会出现一些问题,比如剪枝后的模型精度大幅下降、剪枝不均匀等。因此,在进行剪枝操作之前,通过稀疏训练的方式,可以更好地准确地确定哪些通道或者权重系数可以被剪掉,从而避免上述问题的发生。

4.Prune

4.1 检查BN层的bias

  • 剪枝后,确保BN层的大部分bias足够小(接近于0),否则重新进行稀疏训练
for name, m in model.named_modules():if isinstance(m, torch.nn.BatchNorm2d):w = m.weight.abs().detach()b = m.bias.abs().detach()ws.append(w)bs.append(b)print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())

4.2 设置阈值和剪枝率

  • threshold:全局或局部
  • factor:保持率,裁剪太多不推荐
factor = 0.8
ws = torch.cat(ws)
threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]
print(threshold)

4.3 最小剪枝Conv单元的TopConv

Top-Bottom Conv如下图所示:

在这里插入图片描述

TopConv剪枝的示例代码如下:

def prune_conv(conv1: Conv, conv2: Conv):gamma = conv1.bn.weight.data.detach()beta  = conv1.bn.bias.data.detach()keep_idxs = []    local_threshold = thresholdwhile len(keep_idxs) < 8:keep_idxs = torch.where(gamma.abs() >= local_threshold)[0] local_threshold = local_threshold * 0.5n = len(keep_idxs)print(n / len(gamma) * 100)  # 打印我们保留了多少的channel# pruneconv1.bn.weight.data = gamma[keep_idxs]conv1.bn.bias.data   = beta[keep_idxs]conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]conv1.bn.running_var.data  = conv1.bn.running_var.data[keep_idxs]conv1.bn.num_features   = nconv1.conv.weight.data  = conv1.conv.weight.data[keep_idxs]conv1.conv.out_channels = nif conv1.conv.bias is not None:conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]# pattern to prune
# 1. prune all 1 vs 1 TB pattern e.g. bottleneck
for name, m in model.named_modules():if isinstance(m, Bottleneck):prune_conv(m.cv1, m.cv2)

注意:由于NVIDIA的硬件加速的原因,我们保留的channels应该大于等于8,我们可以通过设置local_threshold,尽量小点,让更多的channel保留下来。

4.4 最小剪枝Conv单元的BottomConv

BottomConv剪枝的示例代码如下:

def prune_conv(conv1: Conv, conv2: Conv):...if not isinstance(conv2, list):conv2 = [conv2]for item in conv2:if item is not None:if isinstance(item, Conv):conv = item.convelse:conv = itemconv.in_channels = nconv.weight.data = conv.weight.data[:, keep_idxs]

注意BottomConv存在两种情形,一种是单个Conv,还有一种是Conv列表。TopConv需要考虑conv2d+bn+act,而BottomConv只需要考虑conv2d

4.5 Seq剪枝

Seq剪枝的示例代码如下:

def prune(m1, m2):if isinstance(m1, C2f):m1 = m1.cv2if not isinstance(m2, list):m2 = [m2]for i, item in enumerate(m2):if isinstance(item, C2f) or isinstance(item, SPPF):m2[i] = item.cv1prune_conv(m1, m2)# 2. prune sequential
seq = model.model
for i in range(3, 9):if i in [6, 4, 9]: continueprune(seq[i], seq[i+1])

注意:我们不考虑1vm的情形,因此在4,6,9的module我们是不进行剪枝的,后续head进行Concat时是对4,6,9的module进行拼接的。考虑到前几个conv的特征提取的重要性,我们也不剪枝它们。(那感觉没剪几个module呀😂)

4.6 Detect-FPN剪枝

Detect-FPN剪枝的示例代码如下:

# 3. prune FPN related block
detect: Detect = seq[-1]last_inputs = [seq[15], seq[18], seq[21]]
colasts     = [seq[16], seq[19], None]for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):prune(last_input, [colast, cv2[0], cv3[0]])prune(cv2[0], cv2[1])prune(cv2[1], cv2[2])prune(cv3[0], cv3[1])prune(cv3[1], cv3[2])for name, p in yolo.model.named_parameters():p.requires_grad = True

注意:一定要设置所有参数为需要训练。因为加载后的model会给弄成False,导致报错

4.7 完整示例代码

完整的示例代码如下:

from ultralytics import YOLO
import torch
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect# Load a model
yolo = YOLO("epoch-50-constrained_weights/last.pt")  # build a new model from scratch
model = yolo.modelws = []
bs = []for name, m in model.named_modules():if isinstance(m, torch.nn.BatchNorm2d):w = m.weight.abs().detach()b = m.bias.abs().detach()ws.append(w)bs.append(b)print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())
# keep
factor = 0.8
ws = torch.cat(ws)
threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]
print(threshold)def prune_conv(conv1: Conv, conv2: Conv):gamma = conv1.bn.weight.data.detach()beta  = conv1.bn.bias.data.detach()# if gamma.abs().min() > f or beta.abs().min() > 0.1:#     return# idxs = torch.argsort(gamma.abs() * coeff + beta.abs(), descending=True)keep_idxs = []local_threshold = thresholdwhile len(keep_idxs) < 8:keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]local_threshold = local_threshold * 0.5n = len(keep_idxs)# n = max(int(len(idxs) * 0.8), p)print(n / len(gamma) * 100)# scale = len(idxs) / nconv1.bn.weight.data = gamma[keep_idxs]conv1.bn.bias.data   = beta[keep_idxs]conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]conv1.bn.num_features = nconv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]conv1.conv.out_channels = nif conv1.conv.bias is not None:conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]if not isinstance(conv2, list):conv2 = [conv2]for item in conv2:if item is not None:if isinstance(item, Conv):conv = item.convelse:conv = itemconv.in_channels = nconv.weight.data = conv.weight.data[:, keep_idxs]def prune(m1, m2):if isinstance(m1, C2f):      # C2f as a top convm1 = m1.cv2if not isinstance(m2, list): # m2 is just one modulem2 = [m2]for i, item in enumerate(m2):if isinstance(item, C2f) or isinstance(item, SPPF):m2[i] = item.cv1prune_conv(m1, m2)for name, m in model.named_modules():if isinstance(m, Bottleneck):prune_conv(m.cv1, m.cv2)seq = model.model
for i in range(3, 9):if i in [6, 4, 9]: continueprune(seq[i], seq[i+1])detect:Detect = seq[-1]
last_inputs   = [seq[15], seq[18], seq[21]]
colasts       = [seq[16], seq[19], None]
for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):prune(last_input, [colast, cv2[0], cv3[0]])prune(cv2[0], cv2[1])prune(cv2[1], cv2[2])prune(cv3[0], cv3[1])prune(cv3[1], cv3[2])# ***step4,一定要设置所有参数为需要训练。因为加载后的model他会给弄成false。导致报错
# pipeline:
# 1. 为模型的BN增加L1约束,lambda用1e-2左右
# 2. 剪枝模型,比如用全局阈值
# 3. finetune,一定要注意,此时需要去掉L1约束。最终final的版本一定是去掉的
for name, p in yolo.model.named_parameters():p.requires_grad = True# 1. 不能剪枝的layer,其实可以不用约束
# 2. 对于低于全局阈值的,可以删掉整个module
# 3. keep channels,对于保留的channels,他应该能整除n才是最合适的,否则硬件加速比较差
#    n怎么选,一般fp16时,n为8
#                int8时,n为16
#    cp.async.cg.shared
#yolo.val()
# yolo.export(format="onnx")
# yolo.train(data="VOC.yaml", epochs=100)
print("done")

5.YOLOv8剪枝总结

关于yolov8剪枝有以下几点值得注意:

Pipeline:

    1. 为模型的BN增加L1约束,lambda用1e-2左右
    1. 剪枝模型使用的是全局阈值
    1. finetune模型时,一定要注意,此时需要去掉L1约束,最终的final的版本一定是去掉的(ultralytics/yolo/engine/trainer.py中注释)
    1. 对于yolo.model.named_parameters()循环,需要设置p.requires_gradTrue

Future work:

    1. 不能剪枝的layer,其实可以不用约束
    1. 对于低于全局阈值的,可以删掉整个module
    1. keep channels,对于保留的channels,它应该能整除n才是最合适的,否则硬件加速比较差
  • n怎么选呢?一般fp16时,n为8;int8时,n为16

总结

本次课程学习了YOLOv8的剪枝,主要是对前面剪枝课程的一个总结和实现吧,大体流程就是稀疏训练后进行剪枝最后微调,看着虽然简单,但实际细节把控还是非常多的,比如说哪些部分好剪,哪些部分不好剪,剪枝的过程中如何通过model获取想要prune的module等等,需要对YOLOv8整体网络结构和对ONNX模型的操作非常熟练,这还只是基础理论,实操部分的坑还没踩呢,在之后好好练习练习吧😄


http://www.ppmy.cn/news/47027.html

相关文章

你真的会用iPad吗,如何使iPad秒变生产力工具?在iPad上用vscode写代码搞开发

目录 前言 视频教程 1. 本地环境配置 2. 内网穿透 2.1 安装cpolar内网穿透(支持一键自动安装脚本) 2.2 创建HTTP隧道 3. 测试远程访问 4. 配置固定二级子域名 4.1 保留二级子域名 4.2 配置二级子域名 5. 测试使用固定二级子域名远程访问 6. iPad通过软件远程vscode…

Java的时代依然还在,合格的Java工程师成为紧缺人才

Java的时代依然还在&#xff0c;合格的Java工程师成为紧缺人才 编程语言的世界变化莫测&#xff0c;在其中浮浮沉沉28年的Java&#xff0c;也经历见证了很多语言的兴起和衰败。在最新的编程语言排行榜中&#xff0c;Java依旧位居前三&#xff0c;可见Java的发展后劲有多强&…

C++linux高并发服务器项目实践 day3

Clinux高并发服务器项目实践 day3 文件IO标准C库IO函数与LinuxIO函数虚拟地址空间文件描述符Linux系统IO函数open与close mode:八进制的数&#xff0c;表示用户对创建出的新的文件的操作权限 最终的权限是&#xff1a;mode & ~umask 0777 r(读) w(写) x(可执行)都有这样的权…

Linux 的 grep 命令使用大全

当你需要在Linux或Unix系统中快速搜索文件中的特定字符串时&#xff0c;grep命令是非常有用的工具。Grep是Global Regular Expression Print的缩写&#xff0c;它是一个文本搜索工具&#xff0c;可以用来在一个或多个文件中查找文本模式。在这篇博客中&#xff0c;我将会讲解gr…

FPGA基于SFP光口实现1G千兆网UDP通信 1G/2.5G Ethernet PCS/PMA or SGMII替代网络PHY芯片 提供工程源码和技术支持

目录 1、前言2、我这里已有的UDP方案3、详细设计方案4、vivado工程详解5、上板调试验证并演示6、福利&#xff1a;工程代码的获取 1、前言 目前网上的fpga实现udp基本生态如下&#xff1a; 1&#xff1a;verilog编写的udp收发器&#xff0c;但不带ping功能&#xff0c;这样的代…

Spark 实现重新分区 partitionBy、coalesce、repartition(附代码演示)

文章目录 1、partitionBy 源码中的定义&#xff08;部分&#xff09; 调用方式 2、coalesce 源码中的定义 调用方式 3、repartition 源码中的定义 调用方式 repartition和coalesce的区别 代码演示 &#xff08;跳转代码&#xff09; 实现重新分区&#xff0c;本质上…

C++ [图论算法详解] 欧拉路欧拉回路

蒟蒻还在上课&#xff0c;所以文章更新的实在慢了点 那今天就来写一篇这周刚学的欧拉路和欧拉回路吧 讲故事环节&#xff1a; 在 一个风雪交加的夜晚 18世纪初普鲁士的哥尼斯堡&#xff0c;有一条河穿过&#xff0c;河上有两个小岛&#xff0c;有七座桥把两个岛与河岸联系…

linux知识

1.vi 删除-dd i-insert 最后一行-G 第一行-g 查找-/ 替换-:s/old/new/g 2.wc -》 行数 字符数 字节数 -w 统计字数 3. sort -k 按某一列排序 -r reverse -n 按字符排 4.uniq -c 统计重复数量 5.head -4 取文件前4行 6.date --date"1 days ago" date "%Y%m%D %H…