DepGraph: Towards Any Structural Pruning
Abstract
- 结构剪枝通过从神经网络中删除结构分组的参数来实现模型加速。然而,参数分组模式在不同模型之间差异很大,因此依赖于手动设计的分组方案的架构特定剪枝器无法推广到新的架构。在这项工作中,我们研究了一项极具挑战性但几乎没有被探索过的任务——任意结构剪枝,以解决 CNN、RNN、GNN 和 Transformers 等任意架构的一般结构剪枝。实现这一目标的最突出障碍在于结构耦合,它不仅迫使不同的层同时被剪枝,而且还期望所有被删除的参数始终不重要,从而避免剪枝后出现结构问题和性能显著下降。为了解决这个问题,我们提出了一种通用的全自动方法——依赖图 (DepGraph),以明确地模拟层之间的依赖关系并全面分组耦合参数以进行剪枝。在这项工作中,我们在几种架构和任务上对我们的方法进行了广泛的评估,包括用于图像的 ResNe(X)t、DenseNet、MobileNet 和 Vision Transformer、用于图形的 GAT、用于 3D 点云的 DGCNN 以及用于语言的 LSTM,并证明,即使使用简单的基于规范的标准,所提出的方法也能始终如一地产生令人满意的性能。项目地址:[GitHub - VainF/Torch-Pruning: CVPR 2023] DepGraph: Towards Any Structural Pruning
- 论文地址:[2301.12900] DepGraph: Towards Any Structural Pruning
- 提出了一种非深度图算法DepGraph,实现了架构通用的结构化剪枝,适用于CNNs, Transformers, RNNs, GNNs,LLM大语言模型等网络。该算法能够自动地分析复杂的结构耦合,从而正确地移除参数实现网络加速。基于DepGraph算法,我们开发了PyTorch结构化剪枝框架 Torch-Pruning。不同于依赖Masking实现的“模拟剪枝”,该框架能够实际地移除参数和通道,降低模型推理成本。在DepGraph的帮助下,研究者和工程师无需再与复杂的网络结构斗智斗勇,可以轻松完成复杂模型的一键剪枝。CVPR 2023 | DepGraph 通用结构化剪枝 - 知乎
- Torch-Pruning(TP)是一个结构化剪枝库,与现有框架(例如torch.nn.utils.prune)最大的区别在于,TP会物理地移除参数,同时自动裁剪其他依赖层。TP是一个纯 PyTorch 项目,实现了内置的计算图的追踪(Tracing)、依赖图(DepenednecyGraph, 见论文)、剪枝器等功能,同时支持 PyTorch 1.x 和 2.0 版本。为了进一步验证分组的有效性,作者在不同的卷积网络上评估了不同的分组策略。策略主要包括:
- 不分组:稀疏学习和重要性评估在单个卷积层上独立进行;
- 仅卷积分组:组内的所有卷积层都以一致的方式稀疏化。
- 完全分组:一个组内的所有可训练层,如卷积、批处理归一化和全连接层,都是一致稀疏的。
- 当忽略神经网络中的分组信息并孤立地稀疏每一层时,本文的方法的性能将显著下降,在某些情况下,甚至由于过度剪枝而崩溃。根据仅conv分组的结果,在组中包含更多的参数有利于最终的性能,但在组中仍然省略了一些有用的信息。通过实现全分组策略,可以进一步提高剪枝的精度。
Introduction
-
近年来,边缘计算应用的兴起呼唤着深度神经压缩的必要性 。在众多网络压缩范例中,剪枝已被证明是一种高效实用的方法。网络剪枝的目标是从给定网络中删除冗余参数,以减轻网络规模并潜在地加快推理速度。主流剪枝方法大致可分为两种方案:结构剪枝 和非结构剪枝。两者的核心区别在于,结构剪枝通过物理移除分组参数来改变神经网络的结构,而非结构剪枝则在不修改网络结构的情况下对部分权重进行归零。与非结构化剪枝相比,结构化剪枝不依赖特定的AI加速器或软件来减少内存消耗和计算成本,从而在实践中找到更广泛的应用领域。
-
然而,结构修剪本身的性质使其本身成为一项具有挑战性的任务,尤其是对于具有耦合和复杂内部结构的现代深度神经网络而言。其基本原理在于,深度神经网络建立在大量基本模块(如卷积、规范化或激活)之上,但这些模块(无论是否参数化)都通过错综复杂的连接在本质上耦合在一起 。因此,即使我们试图从图 1(a)所示的 CNN 中仅移除一个通道,我们也必须同时处理它与所有层的相互依赖关系,否则我们最终会得到一个损坏的网络。确切地说,残差连接要求两个卷积层的输出共享相同数量的通道,从而迫使它们一起修剪 。对于其他架构(如 Transformers、RNN 和 GNN)(如图 1(b-d)所示)的结构修剪也是如此。
-
图 1. 不同层的参数在网络架构中本质上相互依赖,这迫使多个层必须同时修剪。例如,要修剪 (a) 中的 Conv2,块内的所有其他层 {Conv1、BN1、BN2} 也必须修剪。在这项工作中,我们引入了一种通用方案,称为依赖图,以明确考虑此类依赖关系并以全自动方式执行任意架构的修剪。
-
不幸的是,依赖性不仅出现在残差结构中,在现代模型中,残差结构可能无限复杂 。现有的结构方法在很大程度上依赖于逐案分析来处理网络中的依赖性 。尽管取得了有希望的结果,但这种针对网络的修剪方法非常耗费精力。此外,这些方法不能直接推广,这意味着手动设计的分组方案不能转移到其他网络系列,甚至不能转移到同一系列中的网络架构,这反过来极大地限制了它们的工业应用。
-
在本文中,我们致力于为任何结构修剪制定一个通用方案,其中任意网络架构上的结构修剪都以自动方式执行,我们方法的核心是估计依赖图(DepGraph),它明确地模拟了神经网络中成对层之间的相互依赖关系。我们引入 DepGraph 进行结构修剪的动机源于这样的观察:某一层的结构修剪会有效地“触发”相邻层的修剪,这进一步导致类似 {BN2 <- Conv2 -> BN1 -> Conv1} 的链式效应,如图 1(a) 所示。因此,为了追踪不同层之间的依赖关系,我们可以分解和建模依赖链作为一个递归过程,这自然归结为在图中找到最大连通分量的问题,并且可以通过图遍历以 O(N) 复杂度解决。
-
还值得注意的是,在结构修剪中,分组层会同时修剪,这意味着同一组中所有被删除的参数始终不重要。这给现有的针对单个层设计的重要性标准带来了一定的困难。确切地说,由于与其他参数化层的纠缠,单个层中的参数重要性不再显示正确的重要性。为了解决这个问题,我们充分利用 DepGraph 提供的依赖关系建模的综合能力来设计一个“组级”重要性标准,它可以学习组内的一致稀疏性,这样就可以安全地删除那些零化的组,而不会造成太多的性能下降。
-
为了验证 DepGraph 的有效性,我们将所提出的方法应用于几种流行的架构,包括 CNN 、Transformers 、RNN 和 GNN ,与最先进的方法相比,其性能具有竞争力 。对于 CNN 剪枝,我们的方法获得了 2:57 倍加速的 ResNet-56 模型,在 CIFAR 上的准确率为 93.64%,甚至优于未剪枝模型的 93.53% 准确率。而在 ImageNet-1k 上,我们的算法在 ResNet-50 上实现了 2 倍以上的加速,而性能仅损失了 0.32%。
-
更重要的是,我们的方法可以很容易地转移到各种流行的网络,包括 ResNe(X)t 、DenseNet 、VGG 、MobileNet 、GoogleNet 和 Vision Transformer ,并显示出令人满意的结果。此外,我们还对非图像神经网络进行了进一步的实验,包括用于文本分类的 LSTM 、用于 3D 点云的 DGCNN 和用于图形数据的 GAT ,其中我们的方法实现了从 8 倍到 16 倍的加速,而性能没有显著下降。
-
总之,我们的贡献是一种针对任何结构修剪的通用修剪方案,称为依赖图(DepGraph),它允许自动参数分组并有效提高结构修剪在各种网络架构上的通用性,包括 CNN、RNN、GNN 和 Vision Transformers。
Related Work
-
结构化和非结构化剪枝。剪枝在网络加速领域取得了巨大进展,文献中的各种研究证明了这一点 。主流剪枝方法大致可分为两类:结构化剪枝 和非结构化剪枝 。 结构化剪枝旨在从物理上删除一组参数,从而减小神经网络的大小。相比之下,非结构化剪枝涉及在不改变网络结构的情况下将特定权重归零。在实践中,非结构化剪枝尤其易于实现,并且本质上适用于各种网络。 然而,它通常需要专门的 AI 加速器或模型加速软件 。相反,结构化剪枝通过从网络中物理删除参数来改善推理开销,从而找到更广泛的应用领域 。在文献中,剪枝算法的设计空间涵盖了一系列方面,包括剪枝方案 、参数选择、层稀疏性和训练技术。近年来,已经引入了许多稳健的标准,例如基于幅度的标准 、基于梯度的标准 和学习稀疏性 。最近,还进行了一项全面的研究,以评估各种标准的有效性并提供公平的基准 。
-
修剪分组参数。在复杂的网络结构中,参数组之间可能会出现依赖关系,因此需要同时修剪它们。耦合参数的修剪自结构修剪的早期以来一直是研究的重点。例如,在修剪两个连续的卷积层时,从第一层移除一个过滤器需要修剪后续层中的相关内核 。虽然手动分析参数依赖关系是可行的,但当应用于复杂网络时,这一过程可能极其耗费人力,正如许多先前的研究所示。此外,这种手动方案本质上无法转移到新的架构中,这严重限制了修剪的应用。最近,一些试点工作已经提出来解释层之间的复杂关系 。不幸的是,现有技术仍然依赖于经验规则或预定义的架构模式,因此它们对于所有结构修剪应用来说都不够通用。在本研究中,我们提出了一种解决这一挑战的通用方法,证明了解决参数依赖性可以有效地在各种网络中推广结构修剪,从而在多项任务上取得令人满意的性能。
Method
Dependency in Neural Networks
- 在本研究中,我们专注于在参数依赖性限制下对任何神经网络进行结构修剪。 在不失一般性的情况下,我们在全连接 (FC) 层上开发了我们的方法。让我们从一个由三个连续层组成的线性神经网络开始,如图 2 (a) 所示,分别由二维权重矩阵 wl、wl+1 和 wl+2 参数化。这个简单的神经网络可以通过移除神经元的结构修剪而变得纤细。在这种情况下,很容易发现参数之间出现了一些依赖关系,表示为 wl、wl+1,这迫使 wl 和 wl+1 同时被修剪。具体来说,为了修剪连接 wl 和 wl+1 的第 k 个神经元,wl [k; :] 和 wl+1[:; k] 都将被删除。在文献中,研究人员使用手动设计和特定于模型的方案处理层依赖性并在深度神经网络上实现结构修剪。然而,依赖关系种类繁多,如图 2(b-d)所示。手动逐个分析所有依赖关系非常困难,更不用说简单的依赖关系可以嵌套或组合以形成更复杂的模式。为了解决结构修剪中的依赖关系问题,我们在本文中引入了依赖图,它为依赖关系建模提供了一种通用且全自动的机制。
-
图 2. 不同结构中相互依赖的分组参数。所有突出显示的参数必须同时修剪。
Dependency Graph
-
分组。为了实现结构修剪,我们首先需要根据层的相互依赖性对它们进行分组。正式来说,我们的目标是找到一个分组矩阵 G ∈ R L × L G \in R^{L×L} G∈RL×L,其中 L 表示待修剪网络中的层数, G i j = 1 G_{ij} = 1 Gij=1 表示第 i 层和第 j 层之间存在依赖关系。为了方便起见,我们让 D i a g ( G ) = 1 1 × L Diag(G) = 1 ^{1×L} Diag(G)=11×L 实现自依赖性。使用分组矩阵,可以直接找到与第 i 层相互依赖的所有耦合层,表示为 g(i):
- g ( i ) = { j ∣ G i j = 1 } ( 1 ) g(i) = \{j|G_{ij} = 1\}~~ (1) g(i)={j∣Gij=1} (1)
-
然而,由于现代深度网络可能包含数千层复杂的连接,从而产生庞大而复杂的分组矩阵 G,因此从神经网络中估计分组模式并非易事。在这个矩阵中,Gij 不仅由第 i 层和第 j 层决定,还受到它们之间的中间层的影响。因此,在大多数情况下,这种非局部和隐式关系无法用简单的规则处理。 为了克服这一挑战,我们不直接估计分组矩阵 G,而是提出一种等效但易于估计的依赖关系建模方法,即依赖图,从中可以有效地得出 G。
-
依赖关系图。首先考虑一个组 g = { w 1 ; w 2 ; w 3 } g = \{w1; w2; w3\} g={w1;w2;w3},它具有依赖关系 w 1 ⇔ w 2 w1\Leftrightarrow w2 w1⇔w2、$w2\Leftrightarrow w3 $和 w 1 ⇔ w 3 w1\Leftrightarrow w3 w1⇔w3。仔细检查此依赖关系建模,我们可以发现存在一些冗余。例如,依赖关系 w 1 ⇔ w 3 w1\Leftrightarrow w3 w1⇔w3可以通过递归过程从 w 1 ⇔ w 2 w1\Leftrightarrow w2 w1⇔w2 和 w 2 ⇔ w 3 w2\Leftrightarrow w3 w2⇔w3 派生而来。最初,我们以 w1 为起点,并检查其对其他层(例如 w 1 ⇔ w 2 w1\Leftrightarrow w2 w1⇔w2)的依赖关系。然后,w2 为递归扩展依赖关系提供了一个新的起点,这反过来又“触发”了 w 2 ⇔ w 3 w2\Leftrightarrow w3 w2⇔w3。这个递归过程最终以传递关系 w 1 ⇔ w 2 ⇔ w 3 w1\Leftrightarrow w2\Leftrightarrow w3 w1⇔w2⇔w3 结束。在这种情况下,我们只需要两个依赖关系即可描述组 g 中的关系。 类似地,第 3.2 节中讨论的分组矩阵对于依赖关系建模也是多余的,因此可以压缩为具有更少边的更紧凑形式,同时保留相同的信息。我们证明,一个度量相邻层之间局部相互依赖关系的新图 D(称为依赖图)可以有效简化分组矩阵 G。依赖图与 G 的不同之处在于,它只记录具有直接连接的相邻层之间的依赖关系。图 D 可以看作是 G 的传递约简 ,它包含与 G 相同的顶点,但边尽可能少。正式构造 D 使得对于所有 Gij = 1,D 中顶点 i 和 j 之间存在一条路径。因此,可以通过检查 D 中顶点 i 和 j 之间是否存在路径来得出 Gij。
-
网络分解。然而,我们发现在层级构建依赖关系图在实践中可能会出现问题,因为一些基本层(如全连接层)可能有两种不同的修剪方案,如第 3.1 节中讨论的 w[k; :] 和 w[:; k],它们分别压缩输入和输出的维度。此外,网络还包含非参数化操作(如跳过连接),这也会影响层之间的依赖关系 。为了解决这些问题,我们提出了一种新的符号,将网络 F(x; w) 分解为更精细、更基本的组件,表示为 F = { f 1 ; f 2 ; : : : ; f L } F =\{ f_1; f_2; :::; f_L\} F={f1;f2;:::;fL},其中每个组件 f 指的是参数化层(如卷积)或非参数化操作(如残差添加)。我们不是在层级建模关系,而是专注于层的输入和输出之间的依赖关系。具体而言,我们分别将组件 fi 的输入和输出表示为 f −i 和 f +i。对于任何网络,最终分解都可以形式化为 F = { f 1 − ; f 1 + ; : : : ; f L − ; f L + } F =\{ f ^− _1 ; f ^+ _1 ; :::; f ^− _L ; f ^+ _L \} F={f1−;f1+;:::;fL−;fL+}。这种符号有助于更轻松地进行依赖关系建模,并允许对同一层使用不同的修剪方案。
-
依赖关系建模。利用这种符号,我们将神经网络重新绘制为公式 2,其中可以辨别出两种主要类型的依赖关系,即层间依赖关系和层内依赖关系,如下所示:
-
符号 ↔ \leftrightarrow ↔ 表示两个相邻层之间的连接。检查这两个依赖关系可以得出依赖关系建模的简单但通用的规则:
- 层间依赖性:依赖性 f i − ⇔ f j + f ^− _i \Leftrightarrow f ^+ _j fi−⇔fj+ 始终出现在连接的层中,其中 f i − ↔ f j + f ^− _i \leftrightarrow f ^+ _j fi−↔fj+ 。
- 层内依赖性:如果 f − i 和 f + i 共享相同的修剪方案,则存在依赖性 f i − ⇔ f j + f ^− _i \Leftrightarrow f ^+ _j fi−⇔fj+,记为 s c h ( f i − ) = s c h ( f i + ) sch(f ^− _i ) = sch(f ^+ _i ) sch(fi−)=sch(fi+)。
-
首先,如果已知网络的拓扑结构,就可以轻松估计层间依赖关系。对于具有 f i − ↔ f j + f ^− _i \leftrightarrow f ^+_j fi−↔fj+ 的连通层,依赖关系始终存在,因为 f − i 和 f + j 对应于网络的相同中间特征。后续步骤涉及阐明层内依赖性。层内依赖性要求同时修剪单个层的输入和输出。许多网络层满足此条件,例如批量标准化,其输入和输出共享相同的修剪方案,表示为 s c h ( f i − ) = s c h ( f i + ) sch(f ^− _i ) = sch(f ^+ _i ) sch(fi−)=sch(fi+),因此将同时修剪,如图 3 所示。相反,像卷积这样的层对其输入和输出具有不同的修剪方案,即 w [ : ; k ; : ; : ] ≠ w [ k ; : ; : ; : ] w[:; k; :; :] \neq w[k; :; :; :] w[:;k;:;:]=w[k;:;:;:] 如图 3 所示,导致 s c h ( f i − ) ≠ s c h ( f i + ) sch(f ^− _i ) \neq sch(f ^+ _i ) sch(fi−)=sch(fi+)。在这种情况下,卷积层的输入和输出之间没有依赖关系。根据上述规则,我们可以正式建立依赖关系模型,如下所示:
-
其中 ∪ 和 ^ 表示逻辑“或”和“与”运算,1 是指示函数,当条件成立时返回“True”。第一项检查由网络连接引起的层间依赖性,而第二项检查由层输入和输出之间的共享剪枝方案引入的层内依赖性。值得注意的是,DepGraph 是一个对称矩阵, D ( f i − ; f j + ) = D ( f j + ; f i − ) D(f ^− _i ; f ^+ _j ) = D(f ^+ _j ; f ^− _i ) D(fi−;fj+)=D(fj+;fi−)。因此,我们可以检查所有输入- 输出对来估计依赖关系图。 在图 3 中,我们可视化了具有残差连接的 CNN 块的 DepGraph。算法 1 和 2 总结了依赖关系建模和分组的算法。
-
图 3. 层分组是通过 DepGraph 上的递归传播实现的,从 f + 4 开始。在此示例中,由于上面说明的发散剪枝方案,卷积输入 f − 4 和输出 f + 4 之间没有层内依赖关系。
Group-level Pruning
-
在前面的部分中,我们开发了一种分析神经网络内依赖关系的通用方法,这自然会引出组级剪枝问题。评估分组参数的重要性对剪枝提出了重大挑战,因为它涉及多个耦合层。在本节中,我们利用一个简单的基于范数的标准 来建立一种组级剪枝的实用方法。给定一个参数组 g = { w 1 ; w 2 ; : : : ; w ∣ g ∣ } g =\{ w1; w2; :::; w|g|\} g={w1;w2;:::;w∣g∣},现有标准(如 L2 范数重要性 I ( w ) = ∣ ∣ w ∣ ∣ 2 ) I(w) = ||w||2) I(w)=∣∣w∣∣2) 可以为每个 w ∈ g w \in g w∈g 生成独立分数。估计组重要性的自然方法是计算聚合分数 P I ( g ) = ∑ w ∈ g I ( w ) P I(g) =\sum _{w\in g} I(w) PI(g)=∑w∈gI(w)。
-
不幸的是,由于分布和量级的差异,在不同层上独立估计的重要性得分很可能不具有加性,因此毫无意义。为了使这种简单的聚合适用于重要性估计,我们提出了一种稀疏训练方法,以在组级别稀疏化参数,如图 4 © 所示,以便可以安全地从网络中删除那些零化组。具体来说,对于每个具有 K 个可修剪维度(由 w[k] 索引)的参数 w,我们引入了一个简单的稀疏训练正则化项,定义为:
-
其中 I g ; k = ∑ w ∈ g ∣ ∣ w [ k ] ∣ ∣ 2 2 I_{g;k} = \sum _{w\in g} ||w[k]|| ^2 _2 Ig;k=∑w∈g∣∣w[k]∣∣22 表示第 k 个可剪枝维度的重要性,γk 表示应用于这些参数的收缩强度。我们使用可控指数策略来确定 γk,如下所示:
-
γ k = 2 α ( I g m a x − I g , k ) / ( I g m a x − I g m i n ) \gamma_k=2^{\alpha(I^{max}_g-I_{g,k})/(I^{max}_g-I^{min}_g)} γk=2α(Igmax−Ig,k)/(Igmax−Igmin)
-
其中,使用归一化分数来控制收缩强度 αk,在 [ 2 0 ; 2 α ] [2 ^0 ; 2 ^α] [20;2α] 的范围内变化。在本文中,我们对所有实验都使用常量超参数 α = 4。在稀疏训练之后,我们进一步使用简单的相对分数 I ^ g ; k = N ⋅ I g ; k / ∑ { T o p N ( I g ) } \hat I_{g;k} = N · I_{g;k}/\sum\{TopN(I_g)\} I^g;k=N⋅Ig;k/∑{TopN(Ig)} 来识别和删除不重要的参数。在实验部分,我们展示了这种简单的修剪方法与一致的稀疏训练相结合时,可以实现与现代方法相当的性能。
Experiments
Settings
- 本文专注于分类任务,并在各种数据集上进行了大量实验,例如用于图像分类的 CIFAR 和 ImageNet 、用于图分类的 PPI 、用于 3D 分类的 ModelNet 和用于文本分类的 AGNews 。对于每个数据集,我们在几种流行的架构上评估了我们的方法,包括 ResNe(X)t 、VGG 、DenseNet 、MobileNet 、GoogleNet 、Vision Transformers 、LSTM 、DGCNNs 和 Graph Attention Networks 。为了进行 ImageNet 实验,我们使用 Torchvision 中的现成模型作为原始模型。剪枝后,所有模型都将按照与预训练阶段类似的协议进行微调,学习率较小,迭代次数较少。
Results on CIFAR
-
性能。CIFAR 是一个小型图像数据集,被广泛用于验证剪枝算法的有效性。我们遵循现有工作 在 CIFAR-10 上剪枝 ResNet-56,在 CIFAR100 上剪枝 VGG 网络。如表 1 所示。我们报告了剪枝模型的准确率及其理论加速比,定义为 S p e e d U p = M A C s ( b a s e ) M A C s ( p r u n e d ) Speed Up =\frac {MACs(base)} {MACs(pruned)} SpeedUp=MACs(pruned)MACs(base)。请注意,ResRep 、GReg 等基线也部署了稀疏训练进行剪枝。我们的算法与现有基于稀疏性的算法之间的一个主要区别是,我们的剪枝器在所有分组层、卷积、批量归一化和全连接层中一致地促进稀疏性。通过这种改进,我们能够充分利用组结构来学习更好的稀疏性,从而提高剪枝模型的准确性。
-
表 1. CIFAR-10 和 CIFAR-100 上的修剪结果。
-
组稀疏度分布。如前所述,一致的稀疏度对于结构修剪非常重要,因为它会强制所有修剪后的参数始终不重要。 在图 5 中,我们可视化了图 4 © 和 (b) 中一致和不一致策略学习到的分组参数的范数。很容易发现,我们的方法在组级别产生了很强的稀疏性,这有利于识别不重要的参数。然而,在不同层上独立工作的不一致方法无法在各层之间产生一致的重要性,这可能导致组级别的非稀疏范数。
-
图 4. 学习不同的稀疏方案来估计分组参数的重要性。方法 (a) 用于非结构化剪枝,仅关注单个权重的重要性。方法 (b) 学习结构稀疏层 ,但忽略其他层中的耦合权重。我们的方法(如 © 所示)学习组稀疏性,将所有耦合参数强制为零,以便可以通过简单的幅度方法轻松区分它们。
Ablation Study
-
分组策略。为了进一步验证分组的有效性,我们在几个卷积网络上评估了不同的策略。策略主要包括:1)无分组:在单个卷积层上独立进行稀疏学习和重要性评估;2)仅卷积分组:组内的所有卷积层以一致的方式稀疏化。3)完全分组:组内的所有参数化层,例如卷积、批量归一化和全连接层,都一致地稀疏化。如表 2 所示,当我们忽略神经网络中的分组信息并单独稀疏每一层时,我们的方法的性能将显着下降,在某些情况下甚至会因过度修剪而崩溃。
-
表 2. 针对不同分组策略和稀疏度配置对 CIFAR-100 进行的消融研究。所提出的策略(完全分组)在稀疏训练期间考虑了所有参数化层,而其他策略仅利用部分层。报告了具有均匀层稀疏度或学习层稀疏度的修剪模型的准确率(%)。+:在某些情况下,我们的方法会将某些维度过度修剪为 1,这严重损害了最终的准确性。
-
Conv-only设置的结果表明,对部分参数进行分组有利于最终的性能,但是仍然忽略了组中的一些有用信息。 因此,通过全分组策略进一步提高剪枝精度是可行的。
-
学习稀疏性。层稀疏性也是修剪的重要因素,它决定了修剪后的神经网络的最终结构。表 2 提供了一些有关层稀疏性的结果。这项工作主要关注两种类型的稀疏性,即均匀稀疏性和学习稀疏性。对于均匀稀疏性,相同的修剪率将应用于不同的层,假设冗余在网络中均匀分布。然而,图 5 中的先前实验表明,不同的层并不是同样可修剪的。在大多数情况下,学习稀疏性优于均匀稀疏性,尽管有时它可能会过度修剪某些层,导致准确性下降。
-
图 5. 通过有分组和无分组的稀疏学习获得的组级稀疏性直方图,分别对应图 4 中的策略 © 和 (b)。
-
DepGraph 的通用性。表 2 中的结果也证明了我们的框架的通用性,它能够处理各种卷积神经网络。 此外,我们强调我们的方法与包含密集连接和并行结构的 DenseNet 和 GoogleNet 兼容。在以下部分中,我们将进一步展示我们的框架对更多架构的能力。
Towards Any Structural Pruning
-
DepGraph 的可视化。由于参数分组过程复杂,因此大型神经网络的修剪面临着相当大的挑战。但是,通过使用 DepGraph,可以轻松获得所有耦合组。我们在图 6 中提供了 DenseNet-121 、ResNet-18 和 Vision Transformers 的 DepGraph D 和派生的分组矩阵 G 的可视化。分组矩阵是从算法 2 中概述的 DepGraph 派生而来的,其中 G[i; j] = 1 表示第 i 层与第 j 层属于同一组。 DenseNet-121 在同一密集块内的各层之间表现出很强的相关性,从而导致在结构修剪期间出现较大的组。所提出的依赖图在处理复杂网络时被证明是有用的,因为手动分析此类网络中的所有依赖关系确实是一项艰巨的任务。
-
图 6. DenseNet-121、ResNet-18 和 ViT-Base 的依赖图(顶部)和派生的分组方案(底部)。
-
ImageNet。表 3 展示了几种架构在 ImageNet 上的剪枝结果,包括 ResNet、DenseNet、MobileNet、ResNeXt 和 Vision Transformers。这项工作的目标不是为各种模型提供最先进的结果,因此我们只使用最基本的重要性标准。我们表明,一个简单的基于规范的标准与依赖关系建模相结合,可以实现与使用强大标准 和 训练技术 的现代方法相当的性能。
-
表 3.ImageNet 上的修剪结果。
-
文本、3D 点云、图形等。除了 CNN 和 Transformer 之外,我们的方法也很容易应用于其他架构。此部分包括对各种数据的实验,包括文本、图形和 3D 点云,如表 4 所示。我们利用 LSTM 进行文本分类,研究 DepGraph 在递归结构上的有效性,其中参数化层由于元素操作而耦合。DepGraph 还在包含 3D 点云聚合操作的动态图 CNN 上进行了测试。此外,我们还对图形数据进行了实验,这需要与用于其他任务的架构完全不同的架构。在这个实验中,我们专注于图形注意网络的加速,每个 GNN 层内都有几个耦合层。考虑到这些数据集上缺乏有关修剪的工作,我们将 DepGraph 与 CNN 中的一些经典修剪方法相结合以建立我们的基线。结果表明,我们的方法确实可以推广到各种各样的架构。
-
表 4. 非图像数据的修剪神经网络,包括 AGNews(文本)、ModelNet(3D 点云)和 PPI(图形)。我们报告了 AGNews 和 ModelNet 的修剪模型的分类准确率(%)以及 PPI 的微 F1 分数。
Conclusion
-
在这项工作中,我们引入了依赖图,以便在各种神经网络上进行任何结构修剪。 据我们所知,我们的工作是首次尝试开发一种可应用于 CNN、RNN、GNN 和 Transformer 等架构的通用算法。
-
ResNet-18!这一类网络想必各位已经非常熟悉了,层与层之间存在大量残差连接,整体模型结构还是比较复杂的。在Torch-Pruning中,我们提供了许多基本层的剪枝函数,首先我们尝试一下剪枝ResNet-18的第一层,即model.conv1.
-
from torchvision.models import resnet18 import torch_pruning as tp model = resnet18(pretrained=True).eval() tp.prune_conv_out_channels(model.conv1, idxs=[0,1]) # 剪枝前两个通道 output = model(torch.randn(1,3,224,224)) # 尝试运行剪枝后的网络 # 然后报错 RuntimeError: running_mean should contain 62 elements not 64
-
错误信息显示running_mean应该是62维而不是64维,这是什么原因导致的呢?其实通过打印剪枝后网络结构我们就能很快定位到问题:由于我们移除了model.conv1的2个输出通道,model.bn1(64通道)与model.conv1(62通道)已经不适配了。剪枝后的结构如下所示:
-
ResNet((conv1): Conv2d(3, 62, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)(layer1): Sequential((0): BasicBlock((conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))
-
-
为了解决这一问题,我们就需要对
model.bn1
做同样的修剪。然而进一步分析可以发现,这类依赖关系不止于此,例如紧邻的下一个卷积层的输入通道也需要修正,即model.layer1[0].conv1
,因为此时它仅接收64通道特征作为输入。我们继续对剪枝代码进行适当修改:-
from torchvision.models import resnet18 import torch_pruning as tp model = resnet18(pretrained=True).eval() tp.prune_conv_out_channels(model.conv1, idxs=[0,1]) # 剪枝前两个通道 tp.prune_batchnorm_out_channels(model.bn1, idxs=[0,1]) # 尝试修复bn tp.prune_batchnorm_in_channels(model.layer1[0].conv1, idxs=[0,1]) # 尝试修复紧邻的conv output = model(torch.randn(1,3,224,224)) # 尝试运行剪枝后的网络 # 报错 RuntimeError: The size of tensor a (64) must match the size of tensor b (62) at non-singleton dimension 1
-
可以发现问题出在残差操作上。残差的相加操作要求传入的两个tensor具有相同的空间尺寸,也就意味着剪枝后的Tensor通道数62和另一个tensor的通道数64不再匹配了。到这里,想必各位已经没有耐心继续研究这没完没了的剪枝了,也是时候搬出Torch-Pruning真正核心的功能,即依赖图**
DependencyGraph
**了。
-
-
DependencyGraph是Torch-Pruning框架的底层算法,它主要作用是“自动寻找耦合层“。在下面的例子中,我们用DependencyGraph来实现1.1中的剪枝程序:
-
import torch from torchvision.models import resnet18 import torch_pruning as tp model = resnet18(pretrained=True).eval() # 1. 构建依赖图 DG = tp.DependencyGraph() DG.build_dependency(model, example_inputs=torch.randn(1,3,224,224)) # 2. 获取与model.conv1存在依赖的所有层,并指定需要剪枝的通道索引(此处我们剪枝第[2,6,9]个通道) group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9] ) # 3. 执行剪枝操作 if DG.check_pruning_group(group): # 避免将通道剪枝到0group.prune() output = model(torch.randn(1,3,224,224)) # 尝试运行剪枝后的网络
-
上述过程一共三步: 1) 对网络进行依赖图构建; 2) 选取需要剪枝的层,指定剪枝通道,获得分组group;3)执行剪枝操作,按组移除通道。运行程序可以发现这次没有任何报错,那么上述过程发生了什么呢?我们可以通过打印
group
来一探究竟: -
--------------------------------Pruning Group -------------------------------- [0] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), idxs=[2, 6, 9] (Pruning Root) [1] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9] [2] prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp(ReluBackward0), idxs=[2, 6, 9] [3] prune_out_channels on _ElementWiseOp(ReluBackward0) => prune_out_channels on _ElementWiseOp(MaxPool2DWithIndicesBackward0), idxs=[2, 6, 9] [4] prune_out_channels on _ElementWiseOp(MaxPool2DWithIndicesBackward0) => prune_out_channels on _ElementWiseOp(AddBackward0), idxs=[2, 6, 9] [5] prune_out_channels on _ElementWiseOp(MaxPool2DWithIndicesBackward0) => prune_in_channels on layer1.0.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9] [6] prune_out_channels on _ElementWiseOp(AddBackward0) => prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9] [7] prune_out_channels on _ElementWiseOp(AddBackward0) => prune_out_channels on _ElementWiseOp(ReluBackward0), idxs=[2, 6, 9] [8] prune_out_channels on _ElementWiseOp(ReluBackward0) => prune_out_channels on _ElementWiseOp(AddBackward0), idxs=[2, 6, 9] [9] prune_out_channels on _ElementWiseOp(ReluBackward0) => prune_in_channels on layer1.1.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9] [10] prune_out_channels on _ElementWiseOp(AddBackward0) => prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9] [11] prune_out_channels on _ElementWiseOp(AddBackward0) => prune_out_channels on _ElementWiseOp(ReluBackward0), idxs=[2, 6, 9] [12] prune_out_channels on _ElementWiseOp(ReluBackward0) => prune_in_channels on layer2.0.downsample.0 (Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)), idxs=[2, 6, 9] [13] prune_out_channels on _ElementWiseOp(ReluBackward0) => prune_in_channels on layer2.0.conv1 (Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), idxs=[2, 6, 9] [14] prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.1.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9] [15] prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.0.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9] --------------------------------
-
从Group内的信息可以得知,DepGraph所做的事情和我们在1.1中人工进行的分析非常相似。例如,规则[1]描述了“对model.conv1的剪枝会触发model.bn1的剪枝"这一现象。除此以外, group中还存在许多我们在1.1中没考虑到的依赖,这些依赖关系过于复杂,导致结构化剪枝变得很困难。我们可以打印剪枝后的网络来观察来哪些层被自动剪枝了:
-
ResNet((conv1): Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)(bn1): BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)(layer1): Sequential((0): BasicBlock((conv1): Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(64, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(1): BasicBlock((conv1): Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(64, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(layer2): Sequential((0): BasicBlock((conv1): Conv2d(61, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(downsample): Sequential((0): Conv2d(61, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))
-
可以发现伴随着model.con1的通道剪枝,model.bn1, layer1.0.conv1, layer1.0.bn2 等等一系列层都需要进行调整,从而保证结构的正确性。而Torch-Pruning库的设计目标就是将这些依赖层的处理完全自动化,帮助我们快速找到与目标层相互耦合的其他层,从实现正确的结构化剪枝。
-
-
在上述内容里,我们讨论了单层剪枝问题。然而在实践中,我们更希望对整个网络进行剪枝,而非特定的某几层,这就涉及到如何不重复地遍历网络中所有分组的问题。DepGraph提供了接口
DG.get_all_groups
来实现这一目标,下面例子展示了遍历所有分组的一个最简单程序:-
for group in DG.get_all_groups(ignored_layers=[model.conv1], root_module_types=[nn.Conv2d, nn.Linear]):idxs = [2,4,6] # your pruning indicesgroup.prune(idxs=idxs)print(group)
-
get_all_groups接口包含两个参数,第一个参数ignored_layers用于忽略某些不希望被剪枝的层,通常包括最后的分类层、以及报错的层 (解决问题的最好办法是干掉提出问题的层) ,第二个参数root_module_types则指定了每个组的起始层类型。上述例子中我们主要关注卷积和全联接层的剪枝,因此我们传入对应的卷积类和线性层类。值得注意的是,不同层可能出现在同一个分组中,DepGraph会自动去除重复分组。此外,如果想要对某个分组进行修剪,我们需要手动设置通道的索引,因为get_all_groups 只对层进行分组,并不会分辨不同通道的重要性。
-
-
实际上,上述过程中的手动遍历分组、寻找冗余通道依旧过于复杂了,我们能否把事情变得更简单呢?Torch-Pruning基于DepGraph实现了剪枝器模块(High-level Pruners),提供了各种剪枝算法的封装实现。剪枝器提供了模型层面的整体剪枝能力
-
import torch from torchvision.models import resnet18 import torch_pruning as tp model = resnet18(pretrained=True) example_inputs = torch.randn(1, 3, 224, 224) # 1. 选择合适的重要性评估指标,这里使用权值大小 imp = tp.importance.MagnitudeImportance(p=2) # 2. 忽略无需剪枝的层,例如最后的分类层(总不能剪完类别都变少了叭?) ignored_layers = [] for m in model.modules():if isinstance(m, torch.nn.Linear) and m.out_features == 1000:ignored_layers.append(m) # DO NOT prune the final classifier! # 3. 初始化剪枝器 iterative_steps = 5 # 迭代式剪枝,重复5次Pruning-Finetuning的循环完成剪枝。 pruner = tp.pruner.MagnitudePruner(model,example_inputs, # 用于分析依赖的伪输入importance=imp, # 重要性评估指标iterative_steps=iterative_steps, # 迭代剪枝,设为1则一次性完成剪枝ch_sparsity=0.5, # 目标稀疏性,这里我们移除50%的通道 ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}ignored_layers=ignored_layers, # 忽略掉最后的分类层 ) # 4. Pruning-Finetuning的循环 base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs) for i in range(iterative_steps):pruner.step()macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)# finetune your model here# finetune(model)# ...
-
上述代码逻辑非常简单,首先我们选择合适的参数重要性评估(tp.importance),然后创建合适的剪枝器(tp.pruner),设置合适的稀疏度(=剪枝率),最后调用pruner.step()就完成了剪枝。以上代码就是利用Torch-Pruning对任意模型剪枝的基本流程。通过打印剪枝后的模型我们可以看到,在第5次迭代完成后,所有层的通道数都被正确地调整到了一半 (512 => 256):
-
... (layer4): Sequential((0): BasicBlock((conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(downsample): Sequential((0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): BasicBlock((conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(avgpool): AdaptiveAvgPool2d(output_size=(1, 1)(fc): Linear(in_features=256, out_features=1000, bias=True) ) torch.Size([1, 1000])Iter 5/5, Params: 11.69 M => 3.06 MIter 5/5, MACs: 1.82 G => 0.49 G
-
-
交互式剪枝:如果我们希望进一步了解Pruner在剪枝过程裁剪了哪些通道,我们可以使用Torch-Pruning提供的交互式的接口:
-
for i in range(iterative_steps):for group in pruner.step(interactive=True): # Warning: groups must be handled sequentially. Do not keep them as a list.print(group.details()) # 打印具体的group以及将要被剪枝的通道索引 # 此处可以插入自定义代码,例如监控、打印、分析等group.prune() # 交互地调用pruning完成剪枝# group.prune(idxs=[0, 2, 6]) # 也可以传入新的剪枝索引,覆盖pruner的剪枝行为
-
通过调用
pruner.step(interactive=True)
,我们就可以按顺序得到各个分组,以及pruner为我们自动选取的冗余通道。在遍历过程中,我们可以打印、分析剪枝情况,甚至根据自己的需要修改被剪枝的通道。
-
-
Torch-Pruning应用到了Torchvision v0.13.1提供的各种复杂的预训练模型上,涵盖分类、语义分割、检测、实例分割等任务。目前,项目已经支持了81/85=95.3%的模型。包括 Vision Transformers, Yolov7, FasterRCNN, SSD, KeypointRCNN, MaskRCNN, ResNe(X)t, ConvNext, DenseNet, ConvNext, RegNet, FCN, DeepLab。在此基础上,我们也对大语言模型LLaMA、常用的检测模型YOLO v7/v8的结构化剪枝进行了初步探索,具体代码请见Github的benchmark。总体而言,Torch-Pruning可以实现架构通用的结构化剪枝,显著降低结构化剪枝的应用门槛。
-
结构化剪枝是一种重要的模型压缩算法,它通过移除神经网络中冗余的结构来减少参数量,从而降低模型推理的时间、空间代价。在过去几年中,结构化剪枝技术已经被广泛应用于各种神经网络的加速,覆盖了ResNet、VGG、Transformer等流行架构。然而,现有的剪枝技术依旧存在着一个棘手的问题,即算法实现和网络结构的强绑定,这导致我们需要为不同模型分别开发专用且复杂的剪枝程序。
-
那么,这种强绑定从何而来?在一个网络中,每个神经元上通常会存在多个参数连接。如下图 2(a)所示,当我们希望通过剪枝某个神经元(高亮表示)实现加速时,与该神经元相连的多组参数需要被同时移除,这些参数就组成了结构化剪枝的最小单元,通常称为组(Group)。然而,在不同的网络架构中,参数的分组方式通常千差万别。图 2(b)-(d)分别可视化了残差结构、拼接结构、以及降维度结构所致的参数分组情况,这些结构甚至可以互相嵌套,从而产生更加复杂的分组模式。在传统的剪枝流程中,参数的分组通常是手工完成的,这不仅是一个费事且无趣的过程,还要求开发人员对网络结构以及层间依赖关系非常熟悉。因此,参数分组也是结构化剪枝算法落地的一个难题。
-
那么有没有一种通用方法能够自动分析哪些参数属于同一组呢?为了实现这一目标,本文提出了一种名为DepGraph的(非深度)图算法,来对任意网络中的参数依赖关系进行建模。在结构化剪枝中,同一分组内的参数是两两耦合的,当我们希望移除其中之一时,属于该组的参数都需要被同步移除,从而保证结构的正确性。理想情况下,我们希望构建一个二进制的分组矩阵G来记录所有参数对之间耦合关系:如果第 i 层的参数和第 j 层参数相互耦合,我们就用 Gij=1 来进行表示。如此,参数的分组就可以简单建模为一个查询问题:
-
g ( i ) = { j ∣ G i j = 1 } g(i)=\{j|G_{ij}=1\} g(i)={j∣Gij=1}
-
然而,参数之间是否相互依赖并不仅仅由自身决定,还会受到他们之间的中间层影响。实际上,中间层的结构有无穷种可能,这就导致我们难以基于规则直接判断参数的耦合性。在分析参数依赖的过程中,我们发现一个重要的现象,即相邻层之间的依赖关系是可以递推的。举个例子,相邻的层A、B之间存在依赖,同时相邻的层B、C之间也存在依赖,那么我们就可以递推得到A和C之间也存在依赖关系,尽管A、C并不直接连接。这就引出了本文算法的核心,即“利用相邻层的局部依赖关系,递归地推导出我们需要的分组矩阵G”。而这种相邻层间的局部依赖关系我们称之为依赖图(Dependency Graph),记作D。 依赖图是一张稀疏且局部的关系图,因为它仅对直接相连的层进行依赖建模。由此,分组问题可以简化成一个路径搜索问题,当依赖图D中节点i和节点j之间存在一条路径时,我们可以得到 Gij=1 ,即 i 和 j 属于同一个分组。
-
-
当我们把这个简单的想法应用到实际的网络时,我们会发现一个新的问题。结构化剪枝中同一个层可能存在两种剪枝方式,即输入剪枝和输出剪枝。对于一个卷积层而言,我们可以对参数的不同维度进行独立的修剪,从而分别剪枝输入通道或者输出通道。然而,上述的依赖图D却无法对这一现象进行建模。为此,我们提出了一种更细粒度的模型描述符,从逻辑上将每一层 f i f_i fi 拆解成输入 f i − f_i^− fi− 和输出 f i + f_i^+ fi+ 。基于这一描述,一个简单的堆叠网络就可以描述为:
- ( f 1 − , f 1 + ) ↔ ( f 2 − , f 2 + ) . . . ↔ ( f L − , f L + ) (f_1^-,f^+_1)\leftrightarrow (f^-_2,f^+_2)...\leftrightarrow (f^-_L,f^+_L) (f1−,f1+)↔(f2−,f2+)...↔(fL−,fL+)
-
其中符号 ↔ 表示网络连接。还记得依赖图是对什么关系进行建模么?答案是相邻层的局部依赖关系!在新的模型描述方式中,“相邻层”的定义更加广泛,我们把同一层的输入 fi− 和输出fi+ 也视作相邻。通常而言,每个神经网络中都会存在复杂的依赖关系,但是我们依旧从上式中抽象出两类基本依赖关系,即层间依赖(Inter-layer Dependency)和层内依赖(Intra-layer Dependency)。
-
层间依赖:首先我们考虑层间依赖,这种依赖关系由层和层直接的连接导致,是层类型无关的。由于一个层的输出和下一层的输入对应的是同一个中间特征(Feature),这就导致两者需要被同时剪枝。例如在通道剪枝中,“某一层的的输出通道剪枝”和“相邻后续层的输入通道剪枝”是等价的。
-
层内依赖:其次我们对层内依赖进行分析,这种依赖关系与层本身的性质有关。在神经网络中,我们可以把各种层分为两类:第一类层的输入输出可以独立地进行剪枝,分别拥有不同的剪枝布局(pruning scheme),记作 sch(fi+) 或者 sch(fi−) 。例如对于全连接层的2D参数矩阵 w ,我们可以得到 w[k,:] 和 w[:,k] 两种不同的布局。这种情况下,输入 fi− 和输出 fi+ 在依赖图中是相互独立、非耦合的;而另一类层输入输出之间存在耦合,例如逐元素运算、Batch Normalization等。他们的参数(如果有)仅有一种剪枝布局,且同时影响输入输出的维度。实际上,相比于复杂的参数耦合类型,深度网络中的层类型是非常有限的,我们可以预先定义不同层的剪枝布局来确定图中的依赖关系。
-
-
在算法1和算法2中总结了依赖图构建和参数分组的过程,其中参数分组是一个递归的连通分量(Connected Component)搜索问题,可以通过简单深度(DFS)或者宽度(BFS)优先搜索实现。算法2简要描述了这一过程,即我们以某个节点i作为起始分组g,找到依赖图D中与之相连的新节点j,合并入当前组,直到不存在新的联通节点为止。此处我们省略了分组的去重处理。
-
依赖图的一个重要作用是参数自动分组,从而实现任意架构的模型剪枝。实际上,依赖图的自动分组能力还可以帮助设计组级别剪枝(Group-level Pruning)。在结构化剪枝中,属于同一组的参数会被同时移除,这一情况下我们需要保证这些被移除参数是“一致冗余”的,如果这些参数中包含对网络预测至关重要的参数,那么移除这些参数难免会损伤性能。
定义不同层的剪枝布局来确定图中的依赖关系。 -
然而,一个常规训练的网络显然不能满足这一要求。这就需要我们引入稀疏学习方法来对参数进行稀疏化。这里同样存在一个问题,常规的逐层独立的稀疏技术实际上是无法实现这一目标,因为逐层算法并不考虑层间依赖关系,从而导致图 3 (b)中非一致稀疏的情况。为了解决这一问题,我们按照依赖关系将参数进行打包,如图 3 ©所示,进行一致的稀疏训练(虚线框内参数被推向0),从而使得耦合的参数呈现一致的重要性。在具体技术上,我们采用了一个简单的L2正则项,通过赋予参数组的不同正则权重 γ 来进行组稀疏化。