Transformer解析——(四)Decoder

news/2025/2/22 19:28:21/

 本系列已完结,全部文章地址为:

Transformer解析——(一)概述-CSDN博客

Transformer解析——(二)Attention注意力机制-CSDN博客

Transformer解析——(三)Encoder-CSDN博客

Transformer解析——(四)Decoder-CSDN博客

Transformer解析——(五)代码解析及拓展-CSDN博客

Decoder与Encoder的结构非常类似,下面具体解析。

1 整体结构

与Encoder相比,Decoder增加了一个交叉注意力的模块,其他模块的结构与Encoder类似。

2 Decoder的训练和预测过程

注意,Decoder的训练和预测差别很大。

在训练时,将目标序列(比如在英译中任务里,目标序列就是中文)通过Input Embedding和Positional Encoding后,并行输入到自注意力模块中,注意使用Look Ahead Mask避免训练第i个词时使用了i以后的词。自注意力的输出作为Q,Encoder输出作为K和V,一并输入到交叉注意力模块中,随后经过ResNet等模块,得到最终的输出。

在预测时,交叉注意力模块的K和V依然来自于Encoder,Decoder的输入先从开始标志<BEGIN>开始,输出第一个词后,将<BEGIN>,第一个词合在一起重新输入到Decoder,序贯地输出每一个词,直到输出<END>结束标志位停止。

至于这么设计的原因,可见"Look Ahead Mask"小节。

3 交叉注意力

交叉注意力将Encoder的输出传入Attention作为Key和Value,为什么这么设计呢?因为Decoder在生成目标序列时不止要考虑已生成的序列,还要考虑源序列的信息。正是交叉注意力模块连接了Encoder和Decoder

通过使用Encoder输出的 K 和 V,Decoder可以知道源序列中哪些部分与当前正在生成的目标位置最相关。例如,在英译汉任务"I want to go to school"中,假设现在已经输出了"我想去",在输出下一个词时,如果只考虑已输出的“我想去”,肯定是没法输出学校的。通过将"我想去"与原始序列"I want to go to school"计算Attention,最终识别出和"school"的关注度更高,输出"学校"。交叉注意力机制保证了在翻译每一个中文词时,都需要参考英文原文中的相关内容,以确保翻译的准确性。

4 Look Ahead Mask

在预测时,模型是从左到右的顺序依次生成每个位置的输出的,通过第一个token预测第二个token,通过第一个和第二个token预测第三个token,以此类推,不能提前获取未来位置的信息。因此在训练时,也要遵守同样的行为,否则训练和预测的行为逻辑不一致会影响效果。

就像军事演习时如果每次都提前告知敌人方位,那么真正走上战场就抓瞎了。

在训练时,虽然我们是可以提前获取目标序列的,但是也要人为控制模型在输出某个位置的token时不可以参考后文,防止模型作弊提前得到未来的信息。

在代码实现上,Look Ahead Mask表现为一个上三角矩阵(实际上是方阵),该矩阵右上部分都是1,对角线及左下部分都是0。1代表需要掩盖的位置,0代表不需要掩盖的位置。该矩阵乘一个无穷小的数字,如果矩阵中元素是1,则乘无穷小的数字后将变为无穷小;如果矩阵元素是0,则乘无穷小的数字后将变为0。

该矩阵将叠加到计算好的注意力矩阵,因此在1的位置注意力将是无穷小,在0的位置注意力将不变,从而掩盖了元素是1的位置。与注意力矩阵一样,m行n列元素代表第m个token对第n个token的注意力,比如我们看第3行,前3列都是0,表示它只能利用前3个token(包含自身,因此对角线是0)的信息,第4个token及以后位置都是1。

训练时不需要串行地将目标序列逐个输入到Decoder,而是可以一次性输入全部target,通过Look Ahead Mask控制Decoder的多头注意力不使用未来信息。

以汉译英为例,串行是指先用“<BEGIN>”预测“I”,更新权重,然后用“<BEGIN> I”预测“have”,以此类推。Transformer是直接将目标输出“<BEGIN> I have a cat <END>”全部输入到Decoder中,并行训练参数。

注意,预测时Transformer并不是并行的,必须等上一个token输出完,再拿着已生成的token预测下一个词。

5 线性输出

将输出线性变换,将词向量维度升格为词汇表维度,便于从词汇表维度中通过概率选词。

6 Temperature温度

Temperature控制了选词的创新性。Temperature即Creativity,温度越大,概率会更加平均,选择次高概率词的可能性更大,多样性大;温度越低,概率最高的词概率将更高,选择次高概率词的可能性更小,多样性小。

具体计算上,在softmax运算前先对样本除以了温度。

原始的softmax公式为

softmax(z_i)=e^{z_i}/\sum_j{e^{z_j}}

其中z为原始分数,通过softmax将各输出的分数之和固定为1,将分数转化为概率。

修改后的softmax公式为

softmax(z_i)=e^{z_i/T}/\sum_j{e^{z_j}/T}

若T等于1,则等价于原始的softmax;若T>1,则“强者更弱,弱者更强”,概率更平滑;若T<1,则"强者更强,弱者更弱",概率更尖锐,更偏向于概率高的结果。

下面简要说明T可以控制概率分布平滑还是尖锐。

可以考虑两个样本1和2,原始分数分别为a和b。a>b,因此样本1与样本2的概率之比为(e^a)/(e^b)=e^(a-b),记为d1。概率之比越大,表示分布越尖锐。比如概率之比是2,则概率分别是66%和33%;概率之比是9,则概率分别是90%和10%。

当除以T后,两个样本的概率之比为e^[(a-b)/T],记为d2。

若T>1,d2<d1,即两样本概率之比缩小,因此概率分布更平滑;

若T=1,d2=d1,即两样本概率之比不变;

若T<1,d2>d1,即两样本概率之比变大,因此概率分布更尖锐。


http://www.ppmy.cn/news/1574235.html

相关文章

Linux 内核网络设备驱动编程:私有协议支持

一、struct net_device的通用性与私有协议的使用 struct net_device是Linux内核中用于描述网络设备的核心数据结构,它不仅限于TCP/IP协议,还可以用于支持各种类型的网络协议,包括私有协议。其原因如下: 协议无关性:struct net_device的设计是通用的,它本身并不依赖于任何…

【进阶】Java设计模式详解

java注解 什么是注解&#xff1f; java中注解(Annotation)&#xff0c;又称java标注&#xff0c;是一种特殊的注释。 可以添加在包&#xff0c;类&#xff0c;成员变量&#xff0c;方法&#xff0c;参数等内容上面&#xff0c;注解会随同代码被编译到字节码文件中&#xff0…

《重构-》

一、代码坏的味道 神秘命名 ​​​​​代码应该直观明了。要深思熟虑如何给函数、模块、变量和类命名&#xff0c;使它们能清晰地表明 自己的功能和用法。 重复代码 一旦有重复代码存在&#xff0c;阅读这些重复的代码时你就必须加倍仔细&#xff0c;留意其间细微的差异。如果…

EasyExcel 自定义头信息导出

需求&#xff1a;需要在导出 excel时&#xff0c;合并单元格自定义头信息(动态生成)&#xff0c;然后才是字段列表头即导出数据。 EasyExcel - 使用table去写入&#xff1a;https://easyexcel.opensource.alibaba.com/docs/current/quickstart/write#%E4%BD%BF%E7%94%A8table%E…

TikTok账户安全指南:如何取消两步验证?

TikTok账户安全指南&#xff1a;如何取消两步验证&#xff1f; 在这个数字化的时代&#xff0c;保护我们的在线账户安全变得尤为重要。TikTok&#xff0c;作为全球流行的社交媒体平台&#xff0c;其账户安全更是不容忽视。两步验证作为一种增强账户安全性的措施&#xff0c;虽…

高斯积分的证明

内容来源 B站视频BV1LC4y1P7gM 高斯积分 ∫ 0 ∞ e − x 2 d x \int^\infty_0e^{-x^2}\mathcal{d}x ∫0∞​e−x2dx 添加新元 设 f ( t ) [ ∫ 0 t e − x 2 d x ] 2 f(t)\left[\int^t_0e^{-x^2}\mathcal{d}x\right]^2 f(t)[∫0t​e−x2dx]2 现目标 求 lim ⁡ t → ∞ f …

Ubuntu 下 nginx-1.24.0 源码分析 - ngx_test_full_name

ngx_test_full_name 声明在 src\core\ngx_file.c static ngx_int_t ngx_test_full_name(ngx_str_t *name); 定义在 src\core\ngx_file.c static ngx_int_t ngx_test_full_name(ngx_str_t *name) { #if (NGX_WIN32)u_char c0, c1;c0 name->data[0];if (name->len <…

SPRING10_SPRING的生命周期流程图

经过前面使用三大后置处理器BeanPostProcessor、BeanFactoryPostProcessor、InitializingBean对创建Bean流程中的干扰,梳理出SPRING的生命周期流程图如下