Tiny Transformer:从零开始构建简化版Transformer模型

news/2024/10/4 6:57:46/
引言

        自然语言处理(NLP)与计算机视觉(CV)有显著差异,各自任务的独特性决定了它们适用的模型架构。在CV中,卷积神经网络(CNN)长期占据主导地位,而在NLP领域,循环神经网络(RNN)和长短期记忆网络(LSTM)曾是主流。然而,这些传统模型在处理长序列时效率较低,难以捕捉长期依赖关系。

        针对这些问题,Vaswani等人在2017年提出了一种全新的、完全基于注意力机制的模型——Transformer。该模型解决了RNN串行计算的效率问题,并通过自注意力机制有效处理了长序列的长期依赖问题。本文将带领大家一步步构建一个简化版的Transformer模型,称之为Tiny Transformer,帮助大家深入理解其工作原理。

1. 注意力机制

        Transformer的核心是注意力机制,它通过计算Query、Key和Value之间的相关性,动态地为不同位置分配注意力权重。我们将通过多头注意力机制(Multi-Head Attention)来扩展这种计算,以便模型能同时关注多个不同的相关性。

1.1 什么是Attention?

        Attention机制通过计算Query(查询向量)与Key(键向量)之间的相似度来为Value(值向量)加权求和。它的本质是根据当前输入的每个词与其他词的相关性动态调整注意力分布。

        例如,给定一个句子,我们可以通过Attention机制来计算每个词对其他词的关注程度。Attention公式如下:

1.2 Multi-Head Attention

        多头注意力机制扩展了单头注意力的概念,通过并行化多个注意力头来捕获序列中不同层次的相关性。每个注意力头对输入进行独立的Attention计算,然后将所有头的输出拼接起来,形成最终的输出。

import torch.nn as nn
import torch
import torch.nn.functional as Fclass MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads):super(MultiHeadAttention, self).__init__()self.num_heads = num_headsself.d_model = d_modelself.head_dim = d_model // num_headsassert self.head_dim * num_heads == d_model, "d_model must be divisible by num_heads"self.qkv = nn.Linear(d_model, 3 * d_model)self.fc_out = nn.Linear(d_model, d_model)def forward(self, x):B, T, C = x.shapeqkv = self.qkv(x).reshape(B, T, 3, self.num_heads, self.head_dim)q, k, v = qkv.permute(2, 0, 3, 1, 4)attn_scores = (q @ k.transpose(-2, -1)) * (1.0 / torch.sqrt(k.size(-1)))attn_weights = F.softmax(attn_scores, dim=-1)attn_output = (attn_weights @ v).transpose(1, 2).reshape(B, T, C)return self.fc_out(attn_output)
2. 编码器和解码器

        Transformer的结构包括编码器(Encoder)和解码器(Decoder),二者均由多层的注意力机制和前馈神经网络(Feed-Forward Neural Network, FFN)组成。

2.1 编码器

        编码器的主要任务是对输入序列进行编码,并生成上下文表示供解码器使用。每个编码器层包括一个自注意力层和一个前馈网络。

class EncoderLayer(nn.Module):def __init__(self, d_model, num_heads, ff_hidden_dim, dropout=0.1):super(EncoderLayer, self).__init__()self.mha = MultiHeadAttention(d_model, num_heads)self.ffn = nn.Sequential(nn.Linear(d_model, ff_hidden_dim),nn.ReLU(),nn.Linear(ff_hidden_dim, d_model))self.layernorm1 = nn.LayerNorm(d_model)self.layernorm2 = nn.LayerNorm(d_model)self.dropout = nn.Dropout(dropout)def forward(self, x):attn_output = self.mha(x)x = self.layernorm1(x + self.dropout(attn_output))ffn_output = self.ffn(x)return self.layernorm2(x + self.dropout(ffn_output))
2.2 解码器

        解码器的结构与编码器类似,但它包含了一个额外的“交叉注意力”层,用于将编码器的输出作为上下文信息输入,结合解码器自身的输入进行生成。

class DecoderLayer(nn.Module):def __init__(self, d_model, num_heads, ff_hidden_dim, dropout=0.1):super(DecoderLayer, self).__init__()self.mha1 = MultiHeadAttention(d_model, num_heads)self.mha2 = MultiHeadAttention(d_model, num_heads)self.ffn = nn.Sequential(nn.Linear(d_model, ff_hidden_dim),nn.ReLU(),nn.Linear(ff_hidden_dim, d_model))self.layernorm1 = nn.LayerNorm(d_model)self.layernorm2 = nn.LayerNorm(d_model)self.layernorm3 = nn.LayerNorm(d_model)self.dropout = nn.Dropout(dropout)def forward(self, x, enc_out):attn_output1 = self.mha1(x)x = self.layernorm1(x + self.dropout(attn_output1))attn_output2 = self.mha2(x, enc_out, enc_out)x = self.layernorm2(x + self.dropout(attn_output2))ffn_output = self.ffn(x)return self.layernorm3(x + self.dropout(ffn_output))
3. 位置编码

        Transformer由于完全摒弃了递归结构,不能自然捕捉输入序列中的位置信息。因此,位置编码(Positional Encoding)被引入,用于为每个词添加位置信息。位置编码通过正弦和余弦函数为不同位置生成独特的表示。

import math
import torchclass PositionalEncoding(nn.Module):def __init__(self, d_model, max_len=5000):super(PositionalEncoding, self).__init__()pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0)self.register_buffer('pe', pe)def forward(self, x):return x + self.pe[:, :x.size(1)]
4. 完整的Transformer模型

        有了上面各个模块后,我们可以将它们组合成一个完整的Transformer模型。该模型包括一个嵌入层、多个编码器层、解码器层以及一个线性层用于生成输出。

class Transformer(nn.Module):def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_encoder_layers, num_decoder_layers, ff_hidden_dim, dropout):super(Transformer, self).__init__()self.src_embedding = nn.Embedding(src_vocab_size, d_model)self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)self.positional_encoding = PositionalEncoding(d_model)self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, ff_hidden_dim, dropout) for _ in range(num_encoder_layers)])self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, ff_hidden_dim, dropout) for _ in range(num_decoder_layers)])self.fc_out = nn.Linear(d_model, tgt_vocab_size)def forward(self, src, tgt):src = self.positional_encoding(self.src_embedding(src))tgt = self.positional_encoding(self.tgt_embedding(tgt))for layer in self.encoder_layers:src = layer(src)for layer in self.decoder_layers:tgt = layer(tgt, src)return self.fc_out(tgt)
结语

        本文通过逐步实现简化版的Transformer,展示了Transformer模型的核心组成部分——多头注意力、编码器-解码器架构、位置编码等。通过这些模块,Transformer能够高效处理序列数据,实现并行计算,广泛应用于自然语言处理、机器翻译等任务。

        Transformer的灵活性和强大的性能使其成为现代深度学习的基石。在掌握了这些基本模块后,大家可以进一步研究更复杂的模型,如BERT、GPT等预训练模型,以更好地理解和应用Transformer在实际任务中的强大能力。

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!


http://www.ppmy.cn/news/1534278.html

相关文章

Nagle 算法:优化 TCP 网络中小数据包的传输

1. 前言 在网络通信中,TCP(传输控制协议)是最常用的协议之一,广泛应用于各种网络应用,如网页浏览、文件传输和在线游戏等。然而,随着互联网的普及,小数据包的频繁传输成为一个不容忽视的问题。…

mac配置python出现DataDirError: Valid PROJ data directory not found错误的解决

最近在利用python下载SWOT数据时出现以下的问题: import xarray as xr import s3fs import cartopy.crs as ccrs from matplotlib import pyplot as plt import earthaccess from earthaccess import Auth, DataCollections, DataGranules, Store import os os.env…

C++学习,信号处理

C信号处理,依赖于操作系统提供的API。信号处理主要用于响应外部事件,如中断信号(如SIGINT, SIGTERM等),这些信号可以由操作系统、其他程序或用户生成。 在Unix-like系统(如Linux和macOS)中&…

three.js 通过着色器实现热力图效果

three.js 通过着色器实现热力图效果 在线预览 https://threehub.cn/#/codeMirror?navigationThreeJS&classifyshader&idheatmapShader 在 https://threehub.cn 中还有很多案例 <!doctype html> <html lang"en"> <head> <meta charse…

RabbitMQ 界面管理说明

1.RabbitMQ界面访问端口和后端代码连接端口不一样 界面端口是15672 http://localhost:15672/ 后端端口是 5672 默认账户密码登录 guest 2.总览图 3.RabbitMq数据存储位置 4.队列 4.客户端消费者连接状态 5.队列运行状态 6.整体运行状态

基于大数据技术的音乐数据分析及可视化系统

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码 精品专栏&#xff1a;Java精选实战项目…

“衣依”服装销售平台:Spring Boot框架的设计与实现

3系统分析 3.1可行性分析 通过对本“衣依”服装销售平台实行的目的初步调查和分析&#xff0c;提出可行性方案并对其一一进行论证。我们在这里主要从技术可行性、经济可行性、操作可行性等方面进行分析。 3.1.1技术可行性 本“衣依”服装销售平台采用JAVA作为开发语言&#xff…

Android Glide(一):源码分析,内存缓存和磁盘缓存的分析,实现流程以及生命周期

目录 一、Android Glide是什么&#xff0c;如何使用&#xff1f; Android Glide是一个由Google维护的快速高效的Android图像加载库&#xff0c;它旨在简化在Android应用程序中加载和显示图像的过程&#xff0c;包括内存缓存、磁盘缓存和网络加载&#xff0c;以确保图像加载的快…