【深度学习】常见模型-Transformer模型

news/2025/1/30 14:13:21/

Transformer 是一种深度学习模型,首次由 Vaswani 等人在 2017 年提出(论文《Attention is All You Need》),在自然语言处理(NLP)领域取得了革命性成果。它的核心思想是通过 自注意力机制(Self-Attention Mechanism) 和完全基于注意力的架构来捕捉序列数据中的依赖关系。


Transformer 的基本结构

Transformer 模型由两个主要模块组成:

  1. 编码器(Encoder)

    • 输入序列经过嵌入(Embedding)和位置编码(Positional Encoding)后,逐层通过多个编码块。
    • 每个编码块包括两个主要子层:
      1. 多头自注意力层(Multi-Head Self-Attention)。
      2. 前馈全连接网络(Feedforward Neural Network)。
  2. 解码器(Decoder)

    • 解码器也由多层解码块组成,结构类似编码器,但有额外的交叉注意力机制。
    • 解码块主要包含:
      1. 多头自注意力层(Masked Multi-Head Self-Attention)。
      2. 交叉注意力层(Encoder-Decoder Attention)。
      3. 前馈全连接网络。

Transformer 的输入经过编码器进行特征提取,解码器利用编码器输出生成目标序列。


核心组件

1. 自注意力机制(Self-Attention Mechanism)
  • 目标:在序列的每个位置,计算它与其他所有位置的相关性,捕获全局依赖关系。
  • 公式

    \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V
    • Q:查询矩阵(Query)。
    • K:键矩阵(Key)。
    • V:值矩阵(Value)。
    • d_k:键向量的维度(用于缩放防止梯度爆炸)。
2. 多头注意力机制(Multi-Head Attention)
  • 将输入数据分为多个头(head),并分别计算注意力。
  • 优点:能够从不同的子空间捕获特征,提高模型的表达能力。
3. 位置编码(Positional Encoding)
  • 因为 Transformer 不使用 RNN 或 CNN,所以需要显式地表示序列位置。
  • 常用正弦和余弦函数来表示:

    PE(pos, 2i) = \sin\left(\frac{pos}{10000^{2i/d}}\right), \quad PE(pos, 2i+1) = \cos\left(\frac{pos}{10000^{2i/d}}\right)
    • pos:位置索引。
    • i:维度索引。
    • d:嵌入维度。
4. 前馈全连接网络(FFN)
  • 每个编码器或解码器块都包含一个独立的全连接网络:

    FFN(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2
5. 残差连接与层归一化
  • 每个子层后加残差连接(Residual Connection)并归一化(Layer Normalization),以加速训练和稳定梯度。

Transformer 的整体结构

Transformer 使用堆叠的编码器和解码器模块处理输入和输出:

  1. 输入序列(如句子)经过嵌入和位置编码后输入到编码器。
  2. 编码器生成的上下文向量传递到解码器。
  3. 解码器通过交叉注意力结合编码器的上下文向量和解码器中间状态生成输出序列。

代码实现

以下是使用 TensorFlow 和 Keras 构建简单 Transformer 的代码示例:

import tensorflow as tf
from tensorflow.keras.layers import Dense, Embedding, LayerNormalization, Dropout
import numpy as np# 自注意力机制
class MultiHeadAttention(tf.keras.layers.Layer):def __init__(self, d_model, num_heads):super(MultiHeadAttention, self).__init__()self.num_heads = num_headsself.d_model = d_modelassert d_model % self.num_heads == 0self.depth = d_model // self.num_headsself.wq = Dense(d_model)self.wk = Dense(d_model)self.wv = Dense(d_model)self.dense = Dense(d_model)def split_heads(self, x, batch_size):x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))return tf.transpose(x, perm=[0, 2, 1, 3])  # (batch_size, num_heads, seq_len, depth)def call(self, q, k, v, mask):batch_size = tf.shape(q)[0]q = self.wq(q)  # (batch_size, seq_len, d_model)k = self.wk(k)v = self.wv(v)q = self.split_heads(q, batch_size)k = self.split_heads(k, batch_size)v = self.split_heads(v, batch_size)# Scaled dot-product attentionmatmul_qk = tf.matmul(q, k, transpose_b=True)dk = tf.cast(tf.shape(k)[-1], tf.float32)scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)if mask is not None:scaled_attention_logits += (mask * -1e9)attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (batch_size, num_heads, seq_len_q, seq_len_k)output = tf.matmul(attention_weights, v)  # (batch_size, num_heads, seq_len_q, depth_v)output = tf.transpose(output, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)concat_attention = tf.reshape(output, (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)return self.dense(concat_attention)# 示例调用
sample_mha = MultiHeadAttention(d_model=512, num_heads=8)
temp_q = tf.random.uniform((1, 60, 512))  # (batch_size, seq_len, d_model)
temp_k = tf.random.uniform((1, 60, 512))
temp_v = tf.random.uniform((1, 60, 512))
temp_out = sample_mha(temp_q, temp_k, temp_v, None)
print(temp_out.shape)  # (1, 60, 512)


Transformer 的应用

  1. 自然语言处理

    • 机器翻译(Google Translate 使用 Transformer)。
    • 文本摘要(如 BERT、GPT)。
    • 情感分析、问答系统。
  2. 计算机视觉

    • 图像分类(如 Vision Transformer)。
    • 目标检测、图像生成。
  3. 音频处理

    • 语音识别(如 Wav2Vec)。
    • 音乐生成。
  4. 其他领域

    • 推荐系统、时间序列预测、生物信息学。

优点与缺点

优点:
  1. 并行处理能力强,速度快。
  2. 能捕获长距离依赖关系。
  3. 通用性强,适用于多种任务。
缺点:
  1. 计算成本高(尤其是自注意力机制在长序列上的时间复杂度)。
  2. 对内存需求大,训练大型模型需高性能硬件。

Transformer 以其强大的表达能力和灵活性,已经成为深度学习领域的重要基石,为 NLP 和其他领域带来了巨大变革。


http://www.ppmy.cn/news/1567904.html

相关文章

芸众商城小程序会员页面部分图标不显示问题解决办法

我遇到的问题 如下图所示,会员中心这里的图标在小程序端显示异常。但是在网页端又是能够正常显示的。 小程序端截图: 网页端截图: 我的解决方法 检查使用的小程序版本,比如这里使用的是1.2.238版本的小程序,最后…

7-Zip Mark-of-the-Web绕过漏洞复现(CVE-2025-0411)

免责申明: 本文所描述的漏洞及其复现步骤仅供网络安全研究与教育目的使用。任何人不得将本文提供的信息用于非法目的或未经授权的系统测试。作者不对任何由于使用本文信息而导致的直接或间接损害承担责任。如涉及侵权,请及时与我们联系,我们将尽快处理并删除相关内容。 0x0…

漏洞修复:Apache Tomcat 安全漏洞(CVE-2024-50379) | Apache Tomcat 安全漏洞(CVE-2024-52318)

文章目录 引言I Apache Tomcat 安全漏洞(CVE-2024-50379)漏洞描述修复建议升级Tomcat教程II Apache Tomcat 安全漏洞(CVE-2024-52318)漏洞描述修复建议III 安全警告引言 解决方案:升级到最新版Tomcat https://blog.csdn.net/z929118967/article/details/142934649 service in…

Origami Agents:AI驱动的销售研究工具,助力B2B销售团队高效增长

在竞争激烈的B2B市场中,销售团队面临着巨大的挑战——如何高效地发现潜在客户并进行精准的外展活动。Origami Agents通过其创新的AI驱动研究工具,正在彻底改变这一过程。本文将深入探讨Origami Agents的产品特性、技术架构及其快速增长背后的成功因素。 一、一句话定位 Ori…

CF1098F Ж-function

【题意】 给你一个字符串 s s s,每次询问给你 l , r l, r l,r,让你输出 s s s l , r sss_{l,r} sssl,r​中 ∑ i 1 r − l 1 L C P ( s s i , s s 1 ) \sum_{i1}^{r-l1}LCP(ss_i,ss_1) ∑i1r−l1​LCP(ssi​,ss1​)。 【思路】 和前一道题一样&#…

Django-Admin WebView 集成项目技术规范文档 v2.1

Django-Admin WebView 集成项目技术规范文档 v2.1 系统架构规范 1.1 技术栈要求 前端框架:Flutter: 3.27.1 (空安全版本)Dart: 3.3.1 (支持元编程)webview_flutter: ^4.10.0 (带Hybrid Composition支持)后端要求:Django: 4.2.x LTS (安全支持至2026)Python: 3.11.x (启用PEP …

Formality:时序变换(二)(不可读寄存器移除)

相关阅读 Formalityhttps://blog.csdn.net/weixin_45791458/category_12841971.html?spm1001.2014.3001.5482 一、引言 时序变换在Design Compiler的首次综合和增量综合中都可能发生,它们包括:时钟门控(Clock Gating)、寄存器合并(Register Merging)、…

步进电机加减速公式推导

运动控制梯形速度曲线相关算法请参考下面系列文章 PLC运动控制基础系列之梯形速度曲线_三菱运动控制模块梯形加减速-CSDN博客文章浏览阅读3.1k次,点赞3次,收藏7次。本文是关于PLC运动控制的基础教程,重点介绍了梯形速度曲线的概念、计算和应用。讨论了梯形加减速在启动和停…