YOLOv11 Pruning

2024/12/2 11:30:23

Approach: the C3k2 module in YOLOv11 differs from the C2f module in YOLOv8, so this procedure differs slightly from my earlier YOLOv8 pruning.
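Here is one concrete structural difference, based on the Ultralytics source. C2f always holds plain `Bottleneck`s in its `m`. A C3k2 built with `c3k=True` instead holds `C3k` blocks, and each of those wraps its own `Bottleneck`s one level deeper. Step 1 of the script finds Bottlenecks through `named_modules()`, which walks the tree to any depth, so both layouts are covered. The sketch below is only a toy stand-in: plain tuples replace `torch.nn` modules, and `named_modules`/`find` are hypothetical helpers that mimic PyTorch's dotted naming.

```python
# Toy stand-in for an nn.Module tree: each node is (type_name, {child_name: child_node}).


def named_modules(node, prefix=""):
    """Yield (dotted_name, type_name) depth-first, mirroring nn.Module.named_modules() naming."""
    type_name, children = node
    yield prefix, type_name
    for child_name, child in children.items():
        child_prefix = f"{prefix}.{child_name}" if prefix else child_name
        yield from named_modules(child, child_prefix)


def find(node, wanted):
    """Dotted names of every node whose type name equals `wanted`."""
    return [name for name, type_name in named_modules(node) if type_name == wanted]


conv = ("Conv", {})
bottleneck = ("Bottleneck", {"cv1": conv, "cv2": conv})

# C3k2 with c3k=False: Bottlenecks sit directly in m (same layout as C2f)
c3k2_plain = ("C3k2", {"cv1": conv, "cv2": conv, "m": ("ModuleList", {"0": bottleneck})})

# C3k2 with c3k=True: m holds C3k blocks, and each C3k's own m holds the Bottlenecks
c3k = ("C3k", {"cv1": conv, "cv2": conv, "cv3": conv,
               "m": ("Sequential", {"0": bottleneck, "1": bottleneck})})
c3k2_nested = ("C3k2", {"cv1": conv, "cv2": conv, "m": ("ModuleList", {"0": c3k})})

print(find(c3k2_plain, "Bottleneck"))   # ['m.0']
print(find(c3k2_nested, "Bottleneck"))  # ['m.0.m.0', 'm.0.m.1']
```

At the top level, C3k2 still exposes the same `cv1` (input) and `cv2` (output) as C2f. That is why `prune()` in the script can treat it like C2f when pruning between modules.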

Next: I will write up the complete pruning workflow and add distillation, attention modules, and loss modifications.

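Notes:

1. Change the factor in `pruning.get_threshold(yolo.model, 0.65)` (inside `do_pruning`) to get a different pruning rate. The factor sets a single global threshold on the BatchNorm |γ| values: roughly the top `factor` fraction of channels survive, so a smaller value prunes more.

2. Put this script in the same directory as your training script. The relative paths such as `data.yaml` and `runs/...` assume this.

3. Set `modelpath` and `savepath` at the bottom of the script to choose the input model and where the pruned model is saved.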
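The following plain-Python sketch reproduces the rule in `get_threshold` so you can see what the factor does. It uses hypothetical |γ| values instead of a real model and ignores the minimum-of-8 fallback in `prune_conv`.

1. All BatchNorm |γ| values are pooled and sorted in descending order.
2. The value at rank `int(len * factor)` becomes one global threshold.
3. Each pruned layer keeps the channels whose |γ| is at least that threshold.

Keep the factor below 1, or the index runs past the end of the list.

```python
def global_threshold(gammas_per_layer, factor):
    """Pool |gamma| from every layer, sort descending, take the value at rank int(len * factor)."""
    ws = sorted((abs(g) for layer in gammas_per_layer for g in layer), reverse=True)
    return ws[int(len(ws) * factor)]


def kept_channels(gammas, threshold):
    """Indices of the channels whose |gamma| reaches the global threshold."""
    return [i for i, g in enumerate(gammas) if abs(g) >= threshold]


# Hypothetical BN gammas for three small layers (10 channels in total)
layers = [[0.9, 0.05, 0.4, 0.7], [0.02, 0.6, 0.3, 0.8], [0.1, 0.5]]

t = global_threshold(layers, 0.65)  # sorted: 0.9 0.8 0.7 0.6 0.5 0.4 0.3 ... -> rank 6 -> 0.3
print(t)                                      # 0.3
print([kept_channels(l, t) for l in layers])  # [[0, 2, 3], [1, 2, 3], [1]]
print(global_threshold(layers, 0.5))          # 0.4: a smaller factor raises the threshold
```

Because the threshold is global, layers with small γ values lose more channels than layers with large ones. The pool also includes BatchNorm layers that the script never prunes. As a result, the overall pruning rate of the saved model is only approximately `1 - factor`.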

```python
from ultralytics import YOLO
import torch
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect, C3k2
from torch.nn.modules.container import Sequential
import os

# os.environ["CUDA_VISIBLE_DEVICES"] = "2"


class PRUNE():
    def __init__(self) -> None:
        self.threshold = None

    def get_threshold(self, model, factor=0.8):
        """Global threshold on |BN gamma|: roughly the top `factor` fraction of channels are kept."""
        ws = []
        bs = []
        for name, m in model.named_modules():
            if isinstance(m, torch.nn.BatchNorm2d):
                w = m.weight.abs().detach()
                b = m.bias.abs().detach()
                ws.append(w)
                bs.append(b)
                print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())
        print()
        # keep: sort all |gamma| values in descending order and take the one at rank len * factor
        ws = torch.cat(ws)
        self.threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]

    def prune_conv(self, conv1: Conv, conv2: Conv):
        """Prune conv1's output channels and the matching input channels of conv2 (a module or a list)."""
        ## Normal pruning: choose conv1's output channels by |gamma|
        gamma = conv1.bn.weight.data.detach()
        beta = conv1.bn.bias.data.detach()
        keep_idxs = []
        local_threshold = self.threshold
        min_keep = min(8, len(gamma))  # guard: a layer with < 8 channels would otherwise loop forever
        while len(keep_idxs) < min_keep:  # if fewer than 8 filters survive, halve the threshold and retry
            keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]
            local_threshold = local_threshold * 0.5
        n = len(keep_idxs)
        # n = max(int(len(idxs) * 0.8), p)
        print(n / len(gamma) * 100)  # percentage of channels kept in this layer

        conv1.bn.weight.data = gamma[keep_idxs]
        conv1.bn.bias.data = beta[keep_idxs]
        conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
        conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
        conv1.bn.num_features = n
        conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]
        conv1.conv.out_channels = n

        # Segmentation: when the last consumer is Proto, prune the input of proto.cv1 separately
        if isinstance(conv2, list) and len(conv2) > 3 and conv2[-1]._get_name() == "Proto":
            proto = conv2.pop()
            proto.cv1.conv.in_channels = n
            proto.cv1.conv.weight.data = proto.cv1.conv.weight.data[:, keep_idxs]
        if conv1.conv.bias is not None:
            conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]

        ## Regular pruning: slice the input channels of every consumer
        if not isinstance(conv2, list):
            conv2 = [conv2]
        for item in conv2:
            if item is None:
                continue
            if isinstance(item, Conv):
                conv = item.conv
            else:
                conv = item
            if isinstance(item, Sequential):
                # cv3 branch: Sequential(DWConv, Conv). The depthwise conv has one filter per input
                # channel, so its filters, BN and groups are pruned with the same indices
                dw = item[0]
                conv = item[1].conv
                dw.conv.in_channels = n
                dw.conv.out_channels = n
                dw.conv.groups = n
                dw.conv.weight.data = dw.conv.weight.data[keep_idxs, :]
                dw.bn.bias.data = dw.bn.bias.data[keep_idxs]
                dw.bn.weight.data = dw.bn.weight.data[keep_idxs]
                dw.bn.running_var.data = dw.bn.running_var.data[keep_idxs]
                dw.bn.running_mean.data = dw.bn.running_mean.data[keep_idxs]
                dw.bn.num_features = n
            conv.in_channels = n
            conv.weight.data = conv.weight.data[:, keep_idxs]

    def prune(self, m1, m2):
        """Resolve composite modules to the convs that produce (m1) and consume (m2) the channels."""
        if isinstance(m1, C3k2):  # C3k2 as the producer: its output comes from cv2
            m1 = m1.cv2
        if isinstance(m1, Sequential):  # cv3 branch: its output comes from the pointwise Conv
            m1 = m1[1]
        if not isinstance(m2, list):  # m2 is just one module
            m2 = [m2]
        for i, item in enumerate(m2):
            if isinstance(item, C3k2) or isinstance(item, SPPF):  # input enters through cv1
                m2[i] = item.cv1
        self.prune_conv(m1, m2)


def do_pruning(modelpath, savepath):
    pruning = PRUNE()

    ### 0. Load the model
    yolo = YOLO(modelpath)  # load the trained checkpoint to be pruned
    pruning.get_threshold(yolo.model, 0.65)  # 0.65: keep roughly the top 65% of BN channels; lower = prune more

    ### 1. Prune the Bottlenecks inside C3k2 (including those nested in C3k)
    for name, m in yolo.model.named_modules():
        if isinstance(m, Bottleneck):
            pruning.prune_conv(m.cv1, m.cv2)

    ### 2. Prune the filters between specific backbone modules
    seq = yolo.model.model
    for i in [3, 5, 7, 8]:
        pruning.prune(seq[i], seq[i + 1])

    ### 3. Prune the head
    # P3: the output of seq[16] feeds seq[17], detect.cv2[0] (box branch), detect.cv3[0] (class branch),
    #     detect.cv4[0] (mask-coefficient branch) and proto
    # P4: the output of seq[19] feeds seq[20], detect.cv2[1], detect.cv3[1], detect.cv4[1]
    # P5: the output of seq[22] feeds detect.cv2[2], detect.cv3[2], detect.cv4[2]
    detect: Detect = seq[-1]  # Segment head (a Detect subclass with proto and cv4)
    proto = detect.proto
    last_inputs = [seq[16], seq[19], seq[22]]
    colasts = [seq[17], seq[20], None]
    for idx, (last_input, colast, cv2, cv3, cv4) in enumerate(zip(last_inputs, colasts, detect.cv2, detect.cv3, detect.cv4)):
        if idx == 0:
            pruning.prune(last_input, [colast, cv2[0], cv3[0], cv4[0], proto])
        else:
            pruning.prune(last_input, [colast, cv2[0], cv3[0], cv4[0]])
        pruning.prune(cv2[0], cv2[1])
        pruning.prune(cv2[1], cv2[2])
        pruning.prune(cv3[0], cv3[1])
        pruning.prune(cv3[1], cv3[2])
        pruning.prune(cv4[0], cv4[1])
        pruning.prune(cv4[1], cv4[2])

    ### 4. Set gradients, validate and save
    for name, p in yolo.model.named_parameters():
        p.requires_grad = True  # make every parameter trainable again for fine-tuning

    yolo.val(data='data.yaml', batch=2, device=0, workers=0)  # check accuracy right after pruning
    torch.save(yolo.ckpt, savepath)


if __name__ == "__main__":
    modelpath = "runs/segment/Constraint/weights/best.pt"
    savepath = "runs/segment/Constraint/weights/last_prune.pt"
    do_pruning(modelpath, savepath)
```
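The core of `prune_conv` is index bookkeeping. It builds one list of kept output channels from the producer's BN γ. That list then slices the producer's filters (rows) and BN vectors, plus every consumer's weights along the input dimension (columns).

The sketch below reproduces this with nested lists standing in for tensors, with kernel dimensions dropped. `select_channels` and `prune_pair` are hypothetical helper names, and `min_keep` generalizes the hard-coded 8.

```python
def select_channels(gammas, threshold, min_keep=8):
    """Keep channels with |gamma| >= threshold; halve the threshold until at least min_keep survive."""
    min_keep = min(min_keep, len(gammas))  # never ask for more channels than exist
    keep, t = [], threshold
    while len(keep) < min_keep:
        keep = [i for i, g in enumerate(gammas) if abs(g) >= t]
        t *= 0.5
    return keep


def prune_pair(w_prod, bn_prod, w_cons, keep):
    """Slice producer rows (output channels), its BN vectors, and consumer columns (input channels)."""
    new_prod = [w_prod[i] for i in keep]                            # like conv1.conv.weight.data[keep_idxs]
    new_bn = {k: [v[i] for i in keep] for k, v in bn_prod.items()}  # BN weight, bias, running stats
    new_cons = [[row[i] for i in keep] for row in w_cons]           # like conv.weight.data[:, keep_idxs]
    return new_prod, new_bn, new_cons


gammas = [0.9, 0.01, 0.5, 0.02]
print(select_channels(gammas, 0.3, min_keep=2))  # [0, 2]
print(select_channels(gammas, 0.3, min_keep=3))  # [0, 2, 3], reached after several halvings

w_prod = [[1, 2], [3, 4], [5, 6], [7, 8]]               # 4 output channels x 2 inputs
bn_prod = {"weight": gammas, "bias": [0.1, 0.2, 0.3, 0.4]}
w_cons = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]  # 3 outputs x 4 inputs
p, bn, c = prune_pair(w_prod, bn_prod, w_cons, [0, 2])
print(p)   # [[1, 2], [5, 6]]
print(bn)  # {'weight': [0.9, 0.5], 'bias': [0.1, 0.3]}
print(c)   # [[1, 3], [5, 7], [9, 11]]
```

When a producer has several consumers (a C3k2 output feeding the next Conv and three head branches), the same `keep` list is applied to each one. A depthwise consumer (the `DWConv` at the start of a `cv3` branch) is the exception: each depthwise filter is tied to one input channel, so its filters, BN and `groups` are also cut with `keep`. Only the pointwise conv after it is sliced along its input dimension.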


http://www.ppmy.cn/news/1551739.html
