从零学习大模型(十一)-----Lottery Ticket Hypothesis剪枝

ops/2024/10/31 3:06:38/

Lottery Ticket Hypothesis(LTH)是由 Frankle 和 Carbin 在 2019 年提出的一种剪枝方法,其核心思想是神经网络中存在可以单独训练的小型子网络(即"中奖票"),这些子网络可以在保持原始模型性能的情况下有效地训练。通过找到这些子网络,我们可以实现大模型的剪枝,从而减少模型的计算复杂度和存储需求。

实现过程

  1. 初始训练
    • 对于一个大型神经网络,首先对其进行完全训练,得到一个经过充分训练的基准模型。
    • 在此阶段,所有权重都将参与训练,并且模型逐渐逼近最优状态。
  2. 权重重要性评估与剪枝
    • 对训练后的模型,使用权重的重要性度量方法(如权重的绝对值大小)来评估每个权重在模型中的贡献。权重的重要性度量是基于这样一个假设:权重的绝对值越大,其对模型预测的贡献就越大。
    • 权重绝对值大小:在神经网络中,权重的绝对值大小可以用来衡量其对输出的影响程度。通常情况下,较大的权重对神经元的激活产生更显著的影响,因此对最终的预测结果也具有更大的贡献。反之,绝对值较小的权重对输出的影响较小,可以被认为是冗余的。
    • 具体步骤
      1. 计算权重的绝对值:对于每个神经网络层中的权重,计算其绝对值。
      2. 排序和选择:根据权重的绝对值大小进行排序,将绝对值较小的权重标记为不重要。
      3. 剪枝:剪去这些不重要的权重,使模型变得更加稀疏。
    • 剪枝后会得到一个稀疏子网络,这个子网络保留了大部分重要的连接,同时大大减少了参数数量。
  3. 重置权重和再训练
    • 剪枝后的子网络的权重重置为它们在初始随机化时的值。这个步骤的目的是希望子网络能够独立训练,而不是依赖于剪枝前的已训练权重。
    • 通过将权重重置为初始状态,可以验证这些被称为"中奖票"的子网络是否具有足够的表达能力,能够单独训练达到与原始大模型相似的性能。
  4. 迭代剪枝
    • 剪枝后的子网络进行再训练,直到它能够达到与原始模型相近的性能。
    • 如果目标是进一步减少模型大小,可以多次进行剪枝和再训练的过程,直到达到所需的压缩比例。
    • 每次迭代剪枝都会进一步减少不重要的权重,逐步形成一个稀疏、可训练的小型子网络。

代码实现Lottery Ticket Hypothesis的剪枝全过程

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import copy
from tqdm import tqdm# 检查是否可以使用GPU(针对MacBook M3芯片)
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")# 定义一个简单的神经网络
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.fc1 = nn.Linear(64 * 16 * 16, 256)self.fc2 = nn.Linear(256, 128)self.fc3 = nn.Linear(128, 10)def forward(self, x):x = torch.relu(self.conv1(x))x = torch.relu(self.conv2(x))x = torch.max_pool2d(x, 2)x = x.view(x.size(0), -1)  # 修改view以确保batch size匹配x = torch.relu(self.fc1(x))x = torch.relu(self.fc2(x))x = self.fc3(x)return x# 初始化模型
model = SimpleNet().to(device)
initial_state_dict = copy.deepcopy(model.state_dict())  # 保存初始权重# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)# 数据预处理和数据加载器
transform = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.RandomCrop(32, padding=4),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])trainset = torchvision.datasets.CIFAR10(root='../datasets', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)# 模型初始训练
def train(model, optimizer, criterion, dataloader, epochs=5):model.train()for epoch in range(epochs):running_loss = 0.0for i, (inputs, labels) in enumerate(dataloader, 0):inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()# 在更新之前再次应用掩码,确保剪枝的权重不会被更新apply_pruning_mask(model)optimizer.step()running_loss += loss.item()if i % 100 == 99:  # 每100个批次打印一次损失print(f'Epoch [{epoch + 1}], Step [{i + 1}], Loss: {running_loss / 100:.4f}')running_loss = 0.0# 剪枝函数
def prune_by_magnitude(model, amount=0.2):print("Starting prune_by_magnitude...")# 计算每个参数的绝对值并排序all_weights = []for param in model.parameters():if len(param.data.size()) != 1:  # 忽略偏置项all_weights.extend(param.cpu().data.abs().numpy().flatten())threshold = torch.tensor(sorted(all_weights)[int(len(all_weights) * amount)])# 根据阈值剪枝,并保存掩码with torch.no_grad():for name, param in tqdm(list(model.named_parameters()), desc="Applying pruning mask", ncols=100, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]'):if "weight" in name:mask = (torch.abs(param) > threshold).float().to(device)param.mul_(mask)model.register_buffer(f"mask_{name.replace('.', '_')}", mask)  # 保存掩码用于冻结权重# 重置权重函数
def reset_weights(model, initial_state_dict):state_dict = {k: v for k, v in initial_state_dict.items() if "mask" not in k}model.load_state_dict(state_dict, strict=False)# 修改优化器以冻结被剪枝的权重
def apply_pruning_mask(model):with torch.no_grad():for name, param in model.named_parameters():if "weight" in name and hasattr(model, f"mask_{name.replace('.', '_')}"):mask = getattr(model, f"mask_{name.replace('.', '_')}")param.mul_(mask)  # 确保被剪枝的权重保持为零if __name__ == "__main__":# 初始训练train(model, optimizer, criterion, trainloader, epochs=5)# 剪枝并重置权重prune_by_magnitude(model, amount=0.2)reset_weights(model, initial_state_dict)# 再次训练前应用剪枝掩码,以确保被剪枝的权重保持为零apply_pruning_mask(model)# 再次训练train(model, optimizer, criterion, trainloader, epochs=5)# 迭代剪枝和再训练的过程可以继续进行,直到达到所需的压缩比例

优点

  1. 高效压缩:LTH 方法可以找到一个非常稀疏的子网络,使得模型的计算量和存储需求大幅降低,同时性能基本不受影响。
  2. 理论支持:LTH 提出了一个关于神经网络可训练性的理论假设,即在大型神经网络中存在一个子网络(中奖票),如果将其权重重置为初始值并独立训练,这个子网络可以达到与原始模型相近的性能。具体来说,LTH 假设在初始随机权重中已经存在一个可以有效训练的稀疏子网络,这个子网络在训练时具备足够的表示能力和学习能力。因此,通过找到这个子网络并重置其权重,可以在保持模型性能的前提下减少不必要的参数,从而实现模型的压缩。
  3. 适用于多种架构:这种方法可以应用于不同类型的神经网络架构,包括卷积神经网络(CNN)和 Transformer 等。

缺点

  1. 计算开销大:LTH 方法需要多次反复地训练、剪枝和重置权重,因此训练过程相对耗时且计算资源需求较高。
  2. 剪枝策略依赖于初始权重剪枝后的模型性能与初始权重的选择关系密切,存在一定的随机性,这可能导致最终的剪枝效果不稳定。

应用场景

  • 移动设备和嵌入式系统:LTH 可以用于在内存和计算能力有限的设备上部署深度学习模型,例如移动设备、边缘计算设备等,通过找到稀疏子网络来实现模型压缩和加速。
  • 加速推理:对于需要实时推理的应用,剪枝后的稀疏子网络可以减少计算量,从而加速模型推理。

相关文献

  • Frankle, J., & Carbin, M. (2019). The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks. ICLR 2019.

http://www.ppmy.cn/ops/129764.html

相关文章

给哔哩哔哩bilibili电脑版做个手机遥控器

前言 bilibili电脑版可以在电脑屏幕上观看bilibili视频。然而&#xff0c;电脑版的bilibili不能通过手机控制视频翻页和调节音量&#xff0c;这意味着观看视频时需要一直坐在电脑旁边。那么&#xff0c;有没有办法制作一个手机遥控器来控制bilibili电脑版呢&#xff1f; 首先…

logdata-anomaly-miner:一款安全日志解析与异常检测工具

关于logdata-anomaly-miner logdata-anomaly-miner是一款安全日志解析与异常检测工具&#xff0c;该工具旨在以有限的资源和尽可能低的权限运行分析&#xff0c;以使其适合生产服务器使用。 为了确保 logdata-anomaly-miner的正常运行&#xff0c;推荐安装了python > 3.6的…

ubuntu 22.04网线连接无ip、网络设置无有线网界面(netplan修复)

目前遇到过树莓派和其他设备安装 ubuntu22.04&#xff0c; 使用有线网络一段时间&#xff08;可能有其他软件安装导致&#xff09;造成有线网络未启动无ip分配的问题。 1、动态分配 通过命令行启动dhcpclient实现 网络eth0存在异常&#xff0c;网口灯电源和信号灯均点亮&am…

唤醒车机时娱乐屏出现黑屏,卡顿的案例分享

1. 背景 测试在正常操作车机的时候&#xff0c;出现了&#xff1a;唤醒车机时娱乐屏出现连续两次黑屏&#xff0c;且发现系统有卡顿的现象。 2. log分析 low_memory_killer 杀掉adj100的visible进程&#xff0c;是因为memory 不足导致的。 行 16839: 10-26 16:22:11.900235 …

STL---map与set前言(红黑树)

文章目录 红黑树概念性质AVL树和红黑树的对比红黑树的代码实现节点结构红黑树的结构红黑树的插入逻辑红黑树的插入的代码实现其他接口验证红黑树的正确性红黑树完整代码 红黑树 概念 红黑树是一种搜索二叉树&#xff0c;但红黑树在每个节点上增加一个存储位表示节点的颜色&am…

UART-通用异步收发器

1. UART的基本工作原理 UART通信主要有两个部分构成&#xff1a;发送器和接收器&#xff0c;也就是我们常见的&#xff08;RX接收&#xff0c;TX发送&#xff09;两个独立的线路来实现数据的双向传输&#xff0c;由于是异步的&#xff0c;UART并不需要时钟信号&#xff0c;而是…

VS中MFC的使用-学习笔记

IsWindow 函数 (winuser.h):确定指定的窗口句柄是否标识现有窗口。使用Visual C从文件读取XML数据 tinyxml2库 CMarkup类源代码下载地址 使用方法&#xff0c;将头文件和源文件拷贝添加到工程中即可&#xff0c;若编译错误&#xff0c;在源文件中添加#include “stdafx.h” 使用…

Java面试题——微服务篇

1.微服务的拆分原则/怎么样才算一个有效拆分 单一职责原则&#xff1a;每个微服务应该具有单一的责任。这意味着每个服务只关注于完成一项功能&#xff0c;并且该功能应该是独立且完整的。最小化通信&#xff1a;尽量减少服务之间的通信&#xff0c;服务间通信越少&#xff0c…