部署神经网络时计算图的优化方法
部署神经网络时,各路框架基本都会把神经网络的计算建模为一个(有向无环的)计算图,之后再对这个计算图进行优化,包括硬件相关的优化和硬件无关的优化。本文介绍几种部署神经网络时计算图的优化方法,帮助读者在部署神经网络时理解部署工具都干了些什么。
算子融合
最关键的优化计算图的方式就是算子融合了,算子融合指的是将多个神经网络算子(例如卷积、池化、归一化等)组合在一起,以提高计算效率和性能。
输入卷积层与归一化融合
卷积神经网络中,输入的图像往往要做一个Normalization,比如ImageNet上训练的神经网络经常需要进行下面这个操作:
std = [0.229, 0.224, 0.225]
mean = [0.485, 0.456, 0.406]
x = (x/255 - mean) / std
而YoloV5这样的模型则更简单,mean是0,std是1:
x = x / 255
在第一层卷积时,我们会将卷积核在图像上进行滑动,用卷积核的参数乘以对应位置的像素(然后加上偏置),即
y = wx+b
我们把Normalization的过程代入上面这个式子并化简:
y = w(x/255-mean)/std + b
# 化简后
y = (w/255/std)x + (b-w*mean/255/std)
这个过程中,我们发现式子可以看作一个新的y=wx+b
,新的w是w/255/std
,新的b则是b-w*mean/255/std
。所以,我们可以提前对卷积核参数进行变化,从而将两个算子合并为一个算子。
这是一个最简单的例子,其它的算子融合和上述操作的原理是类似的。
矩阵乘法操作 + 激活函数 融合
神经网络里最常见的还有矩阵乘法操作 + 激活函数的组合,比如卷积层后面紧跟着一个ReLU,或者Transformer的FFN中Linear跟着一个GeLU等。 激活函数可以在计算矩阵乘法的同时计算,下面是用c++写的一个伪代码的例子:
// 矩阵乘法 + ReLU的融合
void MatMul(input, weight, bias, output, num_of_elements) {for (int i = 0; i < num_of_elements; i++){for (int j = 0; j < num_of_elements; j++){int accumulator = 0;for (int k = 0; k < num_of_elements; k++){accumulator += input[i][k] * weight[k][j];}// reluoutput[i][j] = accumulator + bias[j] > 0 ? accumulator + bias[j]: 0;}}
}
矩阵乘法就是不断计算向量的点积,即一系列数字相乘后求和,而激活函数基本是对数字进行一个非线性的变换,所以非线性变换的过程可以求和得到结果之后马上进行,而不是进行完所有的矩阵乘法之后,再读一遍矩阵进行非线性的变换。
线性变换 + BatchNorm 融合
线性变换+BN也是一个常见的组合,比如YoloV5中就有很多的CBS(Conv+BN+SiLU)的组合(这里把Conv就是局部的线性变换)。
这种融合可以通过简化合并公式的方式进行:
Y = W x + B Y ′ = γ Y − m e a n v a r + β Y ′ = γ W x + B − m e a n v a r + β Y ′ = γ v a r W x + γ v a r ( B − m e a n ) + β Y=Wx+B \quad Y^\prime=\gamma\frac{Y-mean}{var}+\beta \\ Y^\prime = \gamma\frac{Wx+B-mean}{var}+\beta \\ Y^\prime = \frac{\gamma}{var}Wx + \frac{\gamma}{var}(B-mean) + \beta Y=Wx+BY′=γvarY−mean+βY′=γvarWx+B−mean+βY′=varγWx+varγ(B−mean)+β
常量折叠
这个是在代码编译领域也常用的优化方法,举个最简单的例子,当我们写python的时候,我们想让某个程序暂停两个小时,则一般会写time.sleep(2*60*60)
,这个时候假如让程序在运行时再计算2*60*60
就会略显繁琐,所以编译器会提前把2*60*60
用7200
代替。
更复杂的例子涉及到多个变量,比如a=5;b=a+3;c=b+a
,这个时候优化的方式就是提前计算出b,c的值。
公共子表达式消除
公共表达式是传统编程语言编译器常用的优化的一种,在程序中计算表达式时,有时会出现公共的子表达式,重复计算这些子表达式会增加计算开销。
比如下面这个例子:
temp = b * c
a = b * c + g
d = b * c + e
计算a和d的时候会重复计算b*c,假如我们只计算一次temp = b * c
,然后计算a=temp+g, d=temp+e
,就能提高效率。
到了神经网络领域,可能会出现下图左边这个情况,此时需要合并成右边这种情况,从而简化计算图。
死代码消除
这也是传统编程语言编译器中的一种优化方式,比如下面这个代码,return
之后的代码是unreachable的,所以编译后应该完全消除这部分。
def test(flag):print("Flag is False.")returnprint("This code is unreachable.")
到了神经网络上,在计算图上其实能比较直观地发现没有用的节点或者不可达的节点。比如有一个节点孤立在计算图外面;或者某个节点有输入没输出且不是输出节点;或者某个节点有输出没输入且不是输入节点。
总结
神经网络编译器和传统编程语言编译器非常相似,其许多优化技术都是从编程语言编译器中沿用而来,但是神经网络编译器也有它的特点,有新的例如算子融合的优化方法可以用。这些优化方式能够对神经网络的部署起到关键的作用。
总结
神经网络编译器和传统编程语言编译器非常相似,其许多优化技术都是从编程语言编译器中沿用而来,但是神经网络编译器也有它的特点,有新的例如算子融合的优化方法可以用。这些优化方式能够对神经网络的部署起到关键的作用。