第十四章 原理篇:DEIT

news/2024/11/9 9:24:09/

参考教程:
https://arxiv.org/pdf/2012.12877.pdf
https://github.com/facebookresearch/deit

文章目录

  • 概述
  • Knowledge Distillation
  • DEIT
    • base model: VIT
      • transformer block
      • class token
      • position embedding
    • Distillation through attention
  • 代码实现
    • DistilledVisionTransformer
      • __init__()
      • forward()
        • embedding
        • logits
    • distillation loss
      • __init__()
      • forward()

概述

在之前的章节中提到过,VIT模型训练的一个问题是对数据的要求比较高,因为基于transformer的模型相对于基于卷积的模型,更加flexible。卷积的模型有着预设好的感受野,而transformer的模型需要自己去学习哪部分更加重要,因此训练上也更困难。

在这种情况下,想独自训练一个效果比较好的transformer模型是很困难的,你很难准备大几百万的数据集用于训练。这也给论文复现带来了难度,你看别人的模型效果好,你想去学习,但是没有资源训练出相当的模型。

DEIT提出了一种基于token的蒸馏方法,使用和训练卷积网络差不多的时间,只用imagenet作为训练集,就实现了非常不错的效果。

总的来说,DEIT做出了以下贡献(这一段直接翻译的论文原文):

  • 证明了不包含卷积层的网络在只是用ImageNet数据的情况下也能取得很有竞争力的表现。
  • 提出了一种基于token的蒸馏方法,并且这个方法的效果明显超过了普通的蒸馏方法。
  • 有趣的是,基于transformer的模型以convnet为老师时表现的比以transformer为老师时要好。
  • 他们的基于imagenet预训练的模型应用于其它下游任务时效果也很不错。

Knowledge Distillation

在这里补充一点知识蒸馏相关的内容。

知识蒸馏简单来说呢,就是把我们想要训练的模型当作“学生”模型,在向我们的hard label,也就是ground truth的结果靠近的同时,也让它向一个“老师”模型(一般是一个效果更好的、体量更大的模型)输出的soft label靠近。

比较简单的方法就是直接让学生模型的输出logits去拟合老师模型的输出logits,复杂一点的会增加层与层之间的拟合。

下面的代码就来自一个比较早期的repohttps://github.com/haitongli/knowledge-distillation-pytorch/tree/master

可以看到KD_loss明显有两部分组成。

T = params.temperature
KD_loss = nn.KLDivLoss()(F.log_softmax(outputs/T, dim=1),F.softmax(teacher_outputs/T, dim=1)) * (alpha * T * T) + F.cross_entropy(outputs, labels) * (1. - alpha)

第一个部分就是我们的软目标损失,使用KLD散度计算输出的logits与老师模型输出的logits的差距,T在这里是一个温度系数,T越大得到的概率分布就越平滑。第二个部分就是我们的硬目标损失,也就是输出与label的交叉熵损失。

DEIT

base model: VIT

首先来重新介绍一下DEIT方法中使用的模型框架,其实也就是复习了一遍VIT。

transformer block

DEIT的工作是在VIT模型的基础上完成的。使用固定大小的RGB图像作为输入,这个图像被拆解成N个大小为16*16的小patch,N的大小一般是14*14。也就是说默认图像的大小是224*224。

每个patch都会被处理成一个指定维度的token。在之前的章节中我们介绍过这里有两个常用做法,再次复述一下。

第一种做法是使用reshape之后,使用全连接层完成维度的变化。

self.proj = Rearrange('b c (h p) (w p ) -> b (h w) (p1 p2 c)', p = patch_size)
self.linear = nn.Linear(patch_size * patch_size * in_c, embed_dim)

第二种做法是直接使用卷积。

self.proj = nn.Conv2d(in_c, embed_dim, kernel_size = patch_size, stride=patch_size)

目前来说第二种方法是更常用的。

然后再给得到的embedding加上一个class_token和一个position_embeddings。就构成了一个完整的输入。

class token

VIT中模仿BERT的做法,在得到的patch embedding上concat了一个可训练的class token。这个class token也会贯穿整个网络,并且最终用于分类。它相当于起到了串联所有patch_embedding的作用,它包含的也是一个整体的信息。

也就是说在整个过程中,transformer一共使用了N+1个token,但是只有第一个class token被用来进行结果的预测。

position embedding

已知transformer中最重要的结构就是MSA,在MSA中会根据你的输入计算三个vector,分别是Query, Key, Value。并使用Q和K的内积计算attention。

我们直接看一下源码,可以看到这个qkv是通过全连接得到的,它完成的是从embed_dim到embed_dim的映射,这个过程是和embed的数量无关的。

self.qkv = nn.Linear(emb_size, emb_size*3)

所以一个在low-resolution的图像上训练的模型,也是很容易用在high-resolution的图像上的。只要使用一样的patch_size就可以。

这时候聪明的你可能会发现一个问题,patch_size大小一样,在high-resolution图像上得到的patch的数量肯定比low-resolution要多呀。那么position_embedding是会受到影响的,position_embedding的大小是和我们的数量以及embed_size都有关系的。

 self.positions = nn.Parameter(torch.randn((img_size // patch_size) **2 + 1, emb_size))

原VIT论文中的做法是这样的

We therefore perform 2D interpolation of the pre-trained position embeddings, according to their location in the original image.

Distillation through attention

作者在论文中对蒸馏的部分进行了比较详细的介绍。

soft distillation
软蒸馏就是上面介绍的,用学生模型的logits向老师模型的logits学习,两者的差距使用KL散度来衡量。

hard distillation
硬蒸馏是将老师模型预测的结果也作为真实的标签,让你的学生模型也去学习这个标签。
L g l o b a l h a r d D i s t i l l = ( 1 − ϵ ) × 1 2 L C E ( ψ ( Z s ) , y ) + ϵ × 1 2 L C E ( ψ ( Z s ) , y t ) L^{hardDistill}_{global} = (1-\epsilon)\times\frac{1}{2}L_{CE}(\psi(Z_s),y) + \epsilon\times\frac{1}{2}L_{CE}(\psi(Z_s),yt) LglobalhardDistill=(1ϵ)×21LCE(ψ(Zs),y)+ϵ×21LCE(ψ(Zs),yt)
这种实现方法也更简单方便。老师模型预测的label和ground truth的label扮演一样的角色。

在这里插入图片描述
Distillation token
上图介绍了DEIT是如何进行token的蒸馏的。它们在原有的patch embedding的基础上(patch and class token)新增了一个额外的token,称为distillation token
distillation token和class token一样,在整个训练过程中和别的token进行交互,并在最后一层输出。
class_token的分类结果向ground_truth靠齐,distillation_token的分类结果向我们的teacher靠齐。

整体的原理还是很简单的,可以看作class_token和distillation_token各学各的,在最后测试的时候,两个token是合在一起使用的。

代码实现

DistilledVisionTransformer

参考的是这里的源码:https://github.com/facebookresearch/deit/blob/main/models.py

我们首先来看一下这个DEIT的类。

class DistilledVisionTransformer(VisionTransformer):

它是直接继承的VisionTransformer的类,并在此基础上进行了一些修改,这个修改也没有很大,比较好理解。

init()

首先,它增加了一个dist_token,这个token和class_token的大小是完全一样的,用一样的代码就可以定义。

self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None

然后它的position_embedding和之前不一样了。在不使用蒸馏的时候position_embedding的长度 = num_patch + 1 (class_token)。现在增加了一个新token,所以它的长度也增加了1,变成了num_patch + 2。

self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))

此外,除了原有的分类头外,现在增加了一个新的蒸馏头,用来预测distillation_token的结果。这个部分代码和之前的分类头也是一样的。

self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
 self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

forward()

在模型的forward中,之前只有一个输出,现在变成了两个。整体的流程是没有什么变化的。

在之前的章节中我们梳理过VIT的流程。

  1. 输入img,获得patch,并转成embedding的形式。
  2. 增加cls embedding和position embedding。
  3. 进入transformer encoder构成的blocks。每个block由两部分组成:
    1. multi-head attention
    2. mlp
  4. 进入mlp分类头,输出结果。

在DEIT中增加了distillation_token,所以流程变为了:

  1. 输入img,获得patch,并转成embedding的形式。
  2. 增加cls embedding和dist embedding和position embedding。
  3. 进入transformer encoder构成的blocks。
  4. cls token进入mlp分类头,dist token进入另一个分类头

第二点主要是输入的维度发生了变化,对整个训练流程是没有影响的。最后一点也不过是分开了两个输出。

embedding

在原版VIT中。

if self.cls_token is not None:x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)x = x + self.pos_embed

在DIET中。

cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanksdist_token = self.dist_token.expand(B, -1, -1)x = torch.cat((cls_tokens, dist_token, x), dim=1)

logits

在原版VIT中。

def forward_head(self, x, pre_logits: bool = False):# 这里的x是self.forward_features的结果。if self.global_pool:x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]x = self.fc_norm(x)return x if pre_logits else self.head(x)

在DEIT中。
如果是训练中使用,两个结果分开输出,因为要分别计算loss。如果是在inference中,则使用两个输出融合的结果。

 def forward(self, x):x, x_dist = self.forward_features(x)x = self.head(x)x_dist = self.head_dist(x_dist)if self.training:return x, x_distelse:# during inference, return the average of both classifier predictionsreturn (x + x_dist) / 2

distillation loss

除了模型代码的改动外,DEIT中使用的loss也和之前不一样。
我们先来看一下loss的这个类。

init()

class DistillationLoss(torch.nn.Module):"""This module wraps a standard criterion and adds an extra knowledge distillation loss bytaking a teacher model prediction and using it as additional supervision."""def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,distillation_type: str, alpha: float, tau: float):super().__init__()self.base_criterion = base_criterionself.teacher_model = teacher_modelassert distillation_type in ['none', 'soft', 'hard']self.distillation_type = distillation_typeself.alpha = alphaself.tau = tau

这里传入的base_criterion是你打算用来计算你的分类损失的loss,也就是你的class_head预测的结果和你的图像类别的ground_truth的loss。

    if mixup_active:# smoothing is handled with mixup label transformcriterion = SoftTargetCrossEntropy()elif args.smoothing:criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)else:criterion = torch.nn.CrossEntropyLoss()if args.bce_loss:criterion = torch.nn.BCEWithLogitsLoss()

第二个参数teacher model是你想要学习的老师模型,因为我们只用这个模型做预测,不用它参与训练,所以要注意使用

teacher_model.eval()

第三个参数distillation_type是让你选择你先用软标签还是硬标签的方法。
第四个参数alpha用于分类损失和蒸馏损失的权重分配。
第四个参数tau就是温度系数,在软标签才会用到。

forward()

def forward(self, inputs, outputs, labels)

损失函数forward的部分的输入有三个,第一个input是我们的原始输入,它会被送入teacher_model中用于计算teacher_model的输出。第二个outputs是我们的学生模型的输出结果,它实际上包括了output(head的输出)和output_kd(dist_head)的输出。第三个labels就是我们的ground truth。

我们的分类损失直接用self.base_criterion进行计算。

base_loss = self.base_criterion(outputs, labels)

蒸馏损失按照你选择的distillation_type可以分为两类:soft和hard。其实还有一个选项是None,这种情况下不使用蒸馏损失。

teacher_outputs = self.teacher_model(inputs)

假如你使用软损失。那么就是用你的dist_head的logits和teacher_model的logits进行比较。在计算中还是使用KL散度。并且这里还会用到我们的温度系数tau。

T = self.tau
distillation_loss = F.kl_div(F.log_softmax(outputs_kd/T,dim=1), F.log_softmax(teacher_outputs/T, dim=1),reduction='sum',log_target=True)*(T*T)/outputs_kd.numel()
# We provide the teacher's targets in log probability because we use log_target=True 

假如你使用的是硬损失。那么就是用你的dist_head的logits和teacher_model输出的标签进行比较。

distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))

最终输出的loss用alpha这个参数平衡了权重。

loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha

http://www.ppmy.cn/news/768961.html

相关文章

接口返回慢 图片加载失败问题

该图片是通过后端接口返回picUrl来给img的src赋值 但是后端接口响应较慢 导致html加载完 data也没赋到值 图片加载裂开 解决办法&#xff1a;img标签添加error事件获取错误信息回调 <img v-if"certificateUrl" :src"certificateUrl" class"certif…

mysql sysdate()不走索引问题解析

1.常规比较 结论: SYSDATE()返回执行时间,而NOW()返回时间常量 2.索引比较 结论: 因为SYSDATE()是非确定性的,索引不能用于评估求值引用它的表达式,所以不走索引

怎么给旧版本ios装旧版本软件

1.准备软件 软件下载地址&#xff1a; itunes版本11&#xff1a;iTunes 11 Windows 64位版官方下载 - EFreeLife (12版本及以下) 爱思助手&#xff1a;什么版本都可以 fideer&#xff08;什么版本都可以&#xff09;&#xff1a;小提琴手|Web 调试代理和故障排除解决方案 …

旧版本ios软件整理

部分软件可能不兼容ios11,请自测。 1&#xff0e;Annotable V1.11.2 ID:820317675 内购自动获得版 2&#xff0e;PriceTag V2.3.1 ID:825133185 内购自动获得版 3&#xff0e;Aloha V2.0.3 ID:824183267&#xff08;中区已下架&#xff09;视频缓存利器&#xff0c;支持第三…

Perfetto导入文件查看

Perfetto导入文件查看 1、Perfetto导入2、查看 -- 这里查看InputEvent相关信息2.1 快捷键2.2 查看 InputReader 线程2.3 查看 InputDispatcher 线程2.4 查看某个应用接受InputEvent2.4 InputEvent相关数据2.4.1 aq:pending 应用需要处理的 Input 事件PendingInputEventQueue2.4…

《UNUX环境高级编程》(7)进程环境

1、引言 2、main函数 main函数的原型/*argc是命令行参数的数目&#xff0c;argv是指向各个指针所构成的数组*/ int main(int argc,char *argv[]);3、进程终止 有八种方式使进程终止。其中5种是正常&#xff0c;它们是&#xff1a; 从main函数返回调用exit调用_exit或_Exit最后…

ensp华为交换机基础命令

交换机基本操作命令 用户视图&#xff1a;基本操作&#xff0c;查看基本信息 系统视图&#xff1a;进阶操作&#xff0c;配置静态路由&#xff0c;动态路由 系统命令 system-view #用户视图模式切换至系统视图模式按q或者Ctrlz退出 display history-command #历史命令查询 …

华为防火墙基础操作

一、交代背景 做过一段时间的售后实施&#xff0c;发现防火墙设备上架真的很简单&#xff0c;但是对于没有入门的人来说也会有一定的不理解&#xff0c;初次上线我到底该做什么呢&#xff1f; 这个也是我当初开始做的时候的问题&#xff0c;当然了&#xff0c;初始都是利用web…