深度学习模型轻量化(下)

news/2024/12/2 6:55:39/

深度学习模型轻量化(下)

2.4 蒸馏

2.4.1 蒸馏流程

蒸馏本质是student对teacher的拟合,从teacher中汲取养分,学到知识,不仅仅可以用到模型压缩和加速中。蒸馏常见流程如下图所示

在这里插入图片描述

  1. 老师和学生可以是不同的网络结构,比如BERT蒸馏到BiLSTM网络。但一般相似网络结构,蒸馏效果会更好。

  2. 总体loss为 soft_label_loss + hard_label_loss。soft_label_loss可以用KL散度或MSE拟合

  3. soft label为teacher模型的要拟合的对象。可以是模型预测输出,也可以是embeddings, 或者hidden layer和attention分布。

针对软标签的定义,蒸馏的方案也是百花齐放,下面分享两篇个人认为非常经典的文章。

2.4.2 distillBERT

DistillBERT: A distilled version of BERT: smaller,
faster, cheaper and lighter

DistillBERT由大名鼎鼎的HuggingFace出品。主要创新点为:

  1. Teacher 12层,student 6层,每两层去掉一层。比如student第二层对应teacher第三层

  2. Loss= 5.0 * Lce+2.0 * Lmlm+1.0 * Lcos

·
Lce: soft_label 的KL散度

·
Lmlm: mask LM hard_label
的交叉熵

·
Lcos:hidden state 的余弦相似度

DistilBERT 比 BERT 快 60%,体积比 BERT 小 60%。在glue任务上,保留了 95% 以上的性能。在performance损失很小的情况下,带来了较大的模型压缩和加速效果。

在这里插入图片描述

2.4.3 TinyBERT

TinyBERT: Distilling BERT for Natural Language
Understanding

总体结构

重点来看下 TinyBERT,它是由华为出品,非常值得深入研究。TinyBERT 对 embedding 层,transformer层(包括hidden layer和attention)和 prediction 层均进行了拟合。如下图所示。
在这里插入图片描述
TinyBERT蒸馏过程

其中Embeddings采用MSE, Prediction采用KL散度, Transformer层的hidden layer和attention,均采用MSE。loss如下

在这里插入图片描述

其中m为层数。

效果分析
在这里插入图片描述
表2: glue任务上的performance。在glue任务上,可达到bert-base的96%,几乎无损失。表3: tinyBERT模型大小和推理速度。缩小7.5倍,加速9.4倍。压缩和加速效果十分明显。

消融分析
在这里插入图片描述
表6:分析embedding、prediction、attention、hidden layer软标签作用,其中attention和hidden layer作用最大。这个也很好理解,transformer层本来就是整个BERT中最关键的部分。

在这里插入图片描述

表7:分析老师学生不同层对应方法的效果,uniform为隔层对应,top为全部对应老师顶部几层,bottom为全部对应老师底部几层。Uniform效果明显好很多。这个也很好理解,浅层可以捕捉低阶特征,深层可以捕捉高阶特征。全是低阶或者高阶显然不合适,我们要尽量荤素搭配。

3 框架层加速

3.1 手机端AI能力

目前移动端AI框架也比较多,包括谷歌的tf-lite,腾讯的NCNN,阿里的MNN,百度的PaddleLite,
小米的MACE等。他们都不同程度的进行了模型压缩和加速的支持。特别是端上推理的加速。这个可以参考“手机端AI性能排名“。

3.2 端侧AI框架加速优化方法

个人总结的主要方法如下,可能有遗漏哈,各位看官请轻拍:

  1. 基于基本的C++编译器优化。

a. 打开编译器的优化选项,选择O2等加速选项。

b. 小函数内联,概率大分支优先,避免除法,查表空间换时间,函数参数不超过4个等。

  1. 利用C,而不是C++,C++有不少冗余的东西。

  2. 缓存优化

a. 小块内存反复使用,提升cache命中率,尽量减少内存申请。比如上一层计算完后,接着用作下一层计算。

b. 连续访问,内存连续访问有利于一次同时取数,相近位置cache命中概率更高。比如纵向访问数组时,可以考虑转置后变为横向访问。

c. 对齐访问,比如224224的尺寸,补齐为256224,从而提高缓存命中率。

d. 缓存预取,CPU计算的时候,preload后面的数据到cache中。

  1. 多线程。

a. 为循环分配线程。

b. 动态调度,某个子循环过慢的时候,调度一部分循环到其他线程中。

  1. 稀疏化

a. 稀疏索引和存储方案,采用eigen的sparseMatrix方案。

  1. 内存复用和提前申请

a. 扫描整个网络,计算每层网络内存复用的情况下,最低的内存消耗。推理刚开始的时候就提前申请好。避免推理过程中反复申请和释放内存,避免推理过程中因为内存不足而失败,复用提升内存访问效率和cache命中率。

  1. ARM NEON指令的使用,和ARM的深度融合。NEON可以单指令多取值(SIMD),感兴趣可针对学习,这一块水也很深。

  2. 手工汇编,毕竟机器编译出来的代码还是有不少冗余的。可以针对运行频次特别高的代码进行手工汇编优化。当然如果你汇编功底惊天地泣鬼神的强,也可以全方位手工汇编。

  3. 算子支持:比如支持GPU加速,支持定点化等。有时候需要重新开发端侧的算子。

4 硬件层加速

硬件层加速比较硬核,小编就连半瓢水都达不到了,为了保证整个方案的全面性,还是硬着头皮东施效颦下。目前AI芯片厂家也是百花齐放,谁都想插一脚,不少互联网公司也来赶集,如下图所示。
在这里插入图片描述

AI 芯片目前三种方案。GPU目前被英伟达和AMD牢牢把控。ASIC目前最火,TPU、NPU等属于ASIC范畴。

在这里插入图片描述

5 总结

本文对深度学习模型压缩和加速的几类常用的方法进行了介绍。

参考文献

  1. ALBERT: A Lite BERT for Self-supervised Learning of
    Language Representations

  2. MobileNets: Efficient Convolutional Neural Networks for
    Mobile Vision Applications

  3. Are Sixteen Heads Really Better than One?

  4. DistillBERT: A distilled version of BERT: smaller, faster,
    cheaper and lighter

  5. TinyBERT: Distilling BERT for Natural Language
    Understanding

  6. 手机端AI性能排名


http://www.ppmy.cn/news/614462.html

相关文章

python---websocket的使用

一:简介 推文:WebSocket 是什么原理?为什么可以实现持久连接? 推文:WebSocket:5分钟从入门到精通(很好) WebSocket协议是基于TCP的一种新的协议。WebSocket最初在HTML5规范中被引用为…

php pthread安装编译,php 多线程扩展 pthreads 安装 及 使用

1、扩展的编译安装php(Linux),编辑参数 --enable-maintainer-zts 是必选项:2、下载 php7:http://tw2.php.net/get/php-7.1.2.tar.gz/from/a/mirrorduoxc3、解压并编译phptar -zxf php-7.1.2.tar.gzcd php-7.1.2./configure --prefix/usr/loca…

MAML-Tracker: 目标跟踪分析:CVPR 2020(Oral)

MAML-Tracker: 目标跟踪分析:CVPR 2020(Oral) Tracking by Instance Detection: A Meta-Learning Approach 论文链接:https://arxiv.org/abs/2004.00830 摘要 把跟踪问题看作一类特殊的目标检测问题,称之为实例检测。通过适当的初始化&am…

02Lua入门

前言:语言学起来其实相似点很多,简单整理的知识点目录: 1.使用控制台 2.Lua基础 3.变量 4.运算符 5.控制结构 1.使用控制台 Lua脚本是包含一系列Lua命令的简单脚本(扩展名为.lua的文本文件)。Lua不关注格式&#xff0c…

php无表单上传文件,php – 来自表单的WP邮件附件,无文件管理器上传文件

从表单通过wp_mail函数我正在尝试发送带附件的电子邮件,而不将文件上传到文件管理器.我收到附件的电子邮件.但附件名称不正确,没有文件类型.请帮忙解决这个问题.这是HTML表单有我的PHP代码if (isset($_POST[Submit])) {$attachments $_FILES[Attached][tmp_name];$recipients …

目标跟踪算法

目标跟踪算法 一.互相关运算 给你一张我的正脸照(没有经过美颜处理的),你该如何在人群中找到我呢?一种最直观的方案就是:“谁长得最像就是谁”。但是对于计算机来说,如何衡量“长得像”&#…

Spring Boot中的@RequestMapping注解,如何使用

Spring Boot中的RequestMapping注解 介绍 Spring Boot是一个流行的Java框架,它提供了许多方便的注解和工具,使得Web应用程序的开发变得更加容易。其中,RequestMapping注解是Spring Boot中最常用的注解之一,它可以帮助开发者定义…

JS继承的实现方式

原型链继承://原型链继承:把父类的私有公有的属性和方法,都作为子类公有的属性;//核心:不是把父类私有公有的属性克隆一份一模一样的给子类的公有吧;他是通过__proto__建立和子类之间的原型链,当…