MoE专家模块Demo

server/2024/9/23 5:14:58/

文章目录

  • 前言
  • 一、MoE原理与设计原则
  • 二、构建完整transformers与MoE集成模块
  • 三、专家模块定义
  • 四、路由门控模块
  • 五、稀疏MoE集成模块
  • 六、完整MoE的Demo

前言

随着MoE越来越火热,MoE本质就是将Transformer中的FFN层替换成了MoE-layer,其中每个MoE-Layer由一个gate和若干个experts组成。这里gate和每个expert都可以理解成是nn.linear形式的神经网络。既然如此,本篇文章将结合transformers结构构建一个MoE的demo供大家学习。该源码可直接使用。

一、MoE原理与设计原则

expert:术业有专攻。假设我的输入数据是“我爱吃炸鸡”,在原始的Transformer中,我们把这5个token送去一个FFN层做处理。但是现在我们发现这句话从结构上可以拆成“主语-我”,“谓语-爱吃”,“宾语-炸鸡”,秉持着术业有专攻的原则,我把原来的1个FFN拆分成若干个expert,分别用来单独解析“主语”,“谓语”,“宾语”,这样可能会达到更好的效果。

gate:那么我怎么知道要把哪个token送去哪个expert呢?很简单,我再训练一个gate神经网络,让它帮我判断就好了。

当然,这里并不是说expert就是用来解析主谓宾,只是举一个例子说明:不同token代表的含义不一样,因此我们可以用不同expert来对它们做解析。除了训练上也许能达到更好的效果外,MoE还能帮助我们在扩大模型规模的同时保证计算量是非线性增加的(因为每个token只用过topK个expert,不用过全量expert),这也是我们说MoE-layer是稀疏层的原因。

最后需要注意的是,在之前的表述中,我们说expert是从FFN层转变而来的,这很容易让人错理解成expert就是对FFN的平均切分,实际上你可以任意指定每个expert的大小,每个expert甚至可以>=原来单个FFN层,这并不会改变MoE的核心思想:token只发去部分expert时的计算量会小于它发去所有expert的计算量。

引用:https://mp.weixin.qq.com/s/76a-7fDJumv6iB08L2BUKg

二、构建完整transformers与MoE集成模块

这个模块定义了一个名为Block的PyTorch模块,代表了一个混合专家Transformer块,包括多头自注意力和计算MoE部分(SparseMoE)。其源码如下:

class Block(nn.Module):"""Mixture of Experts Transformer block: communication followed by computation (multi-head self attention + SparseMoE) """def __init__(self, n_embed, n_head, num_experts, top_k):super().__init__()self.sa = nn.MultiheadAttention(n_embed, n_head)  # 我直接调用官网的attention方法self.smoe = SparseMoE(n_embed, num_experts, top_k)self.ln1 = nn.LayerNorm(n_embed)self.ln2 = nn.LayerNorm(n_embed)def forward(self, x):qkv = self.ln1(x)x = x + self.sa(qkv,qkv,qkv)[0]  # 并不包含FFN结构x = x + self.smoe(self.ln2(x))return x

这段代码的解读:
__init__方法:在初始化方法中,定义了模块的结构。模块包含了一个nn.MultiheadAttention实例sa用于多头自注意力计算,我直接使用pytorch官网的方法模块,一个SparseMoE实例smoe用于稀疏专家计算,以及两个nn.LayerNorm实例ln1和ln2用于层归一化。

forward方法:在前向传播方法中,首先对输入张量x进行层归一化处理,然后通过多头自注意力模块对处理后的张量进行注意力计算,并将注意力计算结果与原始输入张量相加。接着,将相加后的张量通过稀疏专家模块进行计算,再次与原始输入张量相加。最后返回处理后的张量。

总体来说,这段代码实现了一个混合专家Transformer块,结合了多头自注意力和稀疏专家计算。在前向传播过程中,通过多头自注意力和稀疏专家计算两部分对输入张量进行处理,并保持了张量的维度一致。

注: x = x + self.smoe(self.ln2(x))使用了类似残差方法

三、专家模块定义

专家模块定义了一个名为Expert的PyTorch模块,代表了一个专家模块,用于对输入进行线性变换和非线性变换。

class Expert(nn.Module):def __init__(self, n_embd):super().__init__()dropout=0.1self.net = nn.Sequential(nn.Linear(n_embd, 4 * n_embd),nn.ReLU(),nn.Linear(4 * n_embd, n_embd),nn.Dropout(dropout),)def forward(self, x):return self.net(x)

四、路由门控模块

门控网络,也称为路由,确定哪个专家网络接收来自多头注意力的 token 的输出。举个例子解释路由的机制,假设有 4 个专家,token 需要被路由到前 2 个专家中。首先需要通过线性层将 token 输入到门控网络中。该层将对应于(Batch size,Tokens,n_embed)的输入张量从(2,4,32)维度,投影到对应于(Batch size、Tokens,num_expert)的新形状:(2、4,4)。其中 n_embed 是输入的通道维度,num_experts 是专家网络的计数。

class NoisyTopkRouter(nn.Module):def __init__(self, n_embed, num_experts, top_k):super(NoisyTopkRouter, self).__init__()self.top_k = top_kself.topkroute_linear = nn.Linear(n_embed, num_experts)# add noiseself.noise_linear = nn.Linear(n_embed, num_experts)def forward(self, mh_output):# mh_ouput is the output tensor from multihead self attention blocklogits = self.topkroute_linear(mh_output)# Noise logitsnoise_logits = self.noise_linear(mh_output)# Adding scaled unit gaussian noise to the logitsnoise = torch.randn_like(logits) * F.softplus(noise_logits)noisy_logits = logits + noisetop_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)  # 在4个专家中选择最好的专家zeros = torch.full_like(noisy_logits, float('-inf'))sparse_logits = zeros.scatter(-1, indices, top_k_logits)router_output = F.softmax(sparse_logits, dim=-1)return router_output, indices

五、稀疏MoE集成模块

通过路由获得indices在对每个专家循环获得对应mask,挑选有用的信息不断叠加而获得最终结果,我已改成伪代码如下说明:

# 1. 输入进入router得到两个输出
gating_output, indices = router(x)
# 2.初始化全零矩阵,后续叠加为最终结果
final_output = zeros_like(x)
# 3.展平,即把每个batch拼接到一起,这里对输入x和router后的结果都进行了展平
flat_x = flatten(x)
flat_gating_output = flatten(gating_output)
# 4. 对每个专家进行操作
for each expert in experts:# 5. 查看当前专家对哪些tokens在前top_kexpert_mask = check_top_k_indices(indices, expert)# 6. 获取当前专家作用的token输入expert_input = select_expert_input(flat_x, expert_mask)# 7. 将token输入经过专家处理得到输出expert_output = expert(expert_input)# 8. 计算当前专家对作用的token的权重分数gating_scores = calculate_gating_scores(flat_gating_output, expert_mask)# 9. 将expert输出乘上权重分数weighted_output = expert_output * gating_scores# 10. 将结果叠加到最终输出中final_output += weighted_output
return final_output

总结解释:
这段代码实现了一个稀疏专家(SparseMoE)模块,结合了一个路由器(router)和多个专家(experts)。
输入首先通过路由器得到两个输出:一个是门控输出(gating_output),一个是索引(indices)。
然后初始化一个全零矩阵作为最终输出。
对每个专家进行操作,根据索引找出每个专家对哪些token起作用,然后将这些token输入到对应的专家中进行处理,根据门控输出的权重分配,将专家处理后的输出加权叠加到最终输出中。
最终返回加权叠加后的最终输出。

其源码如下:

class SparseMoE(nn.Module):def __init__(self, n_embed, num_experts, top_k):super(SparseMoE, self).__init__()self.router = NoisyTopkRouter(n_embed, num_experts, top_k)self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])self.top_k = top_kdef forward(self, x):# 1. 输入进入router得到两个输出gating_output, indices = self.router(x)  # x.shape=[2,6,64] gating_output.shape=[2,6,4]# 2.初始化全零矩阵,后续叠加为最终结果final_output = torch.zeros_like(x)  # final_output.shape=[2,6,64]# 3.展平,即把每个batch拼接到一起,这里对输入x和router后的结果都进行了展平flat_x = x.view(-1, x.size(-1))  # flat_x.shape=[12,64]flat_gating_output = gating_output.view(-1, gating_output.size(-1))  # flat_gating_output.shape=[12,4]# 以每个专家为单位进行操作,即把当前专家处理的所有token都进行加权for i, expert in enumerate(self.experts):# 4. 对当前的专家(例如专家0)来说,查看其对所有tokens中哪些在前top2expert_mask = (indices == i).any(dim=-1)  # expert_mask.shape=[2,6]# 5. 展平操作flat_mask = expert_mask.view(-1)  # flat_mask=12# 如果当前专家是任意一个token的前top2if flat_mask.any():# 6. 得到该专家对哪几个token起作用后,选取token的维度表示expert_input = flat_x[flat_mask]  # 假设第0个专家选择了2个batch的7个维度为[7,64]# 7. 将token输入expert得到输出expert_output = expert(expert_input)  # 这个专家选择通道走了mlp结构,[7,64]# 8. 计算当前专家对于有作用的token的权重分数gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)  # 使用flat_mask选择第i个专家的权重# 9. 将expert输出乘上权重分数weighted_output = expert_output * gating_scores# 10. 循环进行做种的结果叠加,也就是越重要被专家选择越多就叠加越多final_output[expert_mask] += weighted_output.squeeze(1)  # 变成[7,64]return final_output

六、完整MoE的Demo

最后,给出即插即用的MoE完整Demo代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass Expert(nn.Module):def __init__(self, n_embd):super().__init__()dropout=0.1self.net = nn.Sequential(nn.Linear(n_embd, 4 * n_embd),nn.ReLU(),nn.Linear(4 * n_embd, n_embd),nn.Dropout(dropout),)def forward(self, x):return self.net(x)class NoisyTopkRouter(nn.Module):def __init__(self, n_embed, num_experts, top_k):super(NoisyTopkRouter, self).__init__()self.top_k = top_kself.topkroute_linear = nn.Linear(n_embed, num_experts)# add noiseself.noise_linear = nn.Linear(n_embed, num_experts)def forward(self, mh_output):# mh_ouput is the output tensor from multihead self attention blocklogits = self.topkroute_linear(mh_output)# Noise logitsnoise_logits = self.noise_linear(mh_output)# Adding scaled unit gaussian noise to the logitsnoise = torch.randn_like(logits) * F.softplus(noise_logits)noisy_logits = logits + noisetop_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)  # 在4个专家中选择最好的专家zeros = torch.full_like(noisy_logits, float('-inf'))sparse_logits = zeros.scatter(-1, indices, top_k_logits)router_output = F.softmax(sparse_logits, dim=-1)return router_output, indicesclass SparseMoE(nn.Module):def __init__(self, n_embed, num_experts, top_k):super(SparseMoE, self).__init__()self.router = NoisyTopkRouter(n_embed, num_experts, top_k)self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])self.top_k = top_kdef forward(self, x):# 1. 输入进入router得到两个输出gating_output, indices = self.router(x)  # x.shape=[2,6,64] gating_output.shape=[2,6,4]# 2.初始化全零矩阵,后续叠加为最终结果final_output = torch.zeros_like(x)  # final_output.shape=[2,6,64]# 3.展平,即把每个batch拼接到一起,这里对输入x和router后的结果都进行了展平flat_x = x.view(-1, x.size(-1))  # flat_x.shape=[12,64]flat_gating_output = gating_output.view(-1, gating_output.size(-1))  # flat_gating_output.shape=[12,4]# 以每个专家为单位进行操作,即把当前专家处理的所有token都进行加权for i, expert in enumerate(self.experts):# 4. 对当前的专家(例如专家0)来说,查看其对所有tokens中哪些在前top2expert_mask = (indices == i).any(dim=-1)  # expert_mask.shape=[2,6]# 5. 展平操作flat_mask = expert_mask.view(-1)  # flat_mask=12# 如果当前专家是任意一个token的前top2if flat_mask.any():# 6. 得到该专家对哪几个token起作用后,选取token的维度表示expert_input = flat_x[flat_mask]  # 假设第0个专家选择了2个batch的7个维度为[7,64]# 7. 将token输入expert得到输出expert_output = expert(expert_input)  # 这个专家选择通道走了mlp结构,[7,64]# 8. 计算当前专家对于有作用的token的权重分数gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)  # 使用flat_mask选择第i个专家的权重# 9. 将expert输出乘上权重分数weighted_output = expert_output * gating_scores# 10. 循环进行做种的结果叠加,也就是越重要被专家选择越多就叠加越多final_output[expert_mask] += weighted_output.squeeze(1)  # 变成[7,64]return final_outputclass Block(nn.Module):"""Mixture of Experts Transformer block: communication followed by computation (multi-head self attention + SparseMoE) """def __init__(self, n_embed, n_head, num_experts, top_k):super().__init__()self.sa = nn.MultiheadAttention(n_embed, n_head)  # 我直接调用官网的attention方法self.smoe = SparseMoE(n_embed, num_experts, top_k)self.ln1 = nn.LayerNorm(n_embed)self.ln2 = nn.LayerNorm(n_embed)def forward(self, x):qkv = self.ln1(x)x = x + self.sa(qkv,qkv,qkv)[0]  # 并不包含FFN结构x = x + self.smoe(self.ln2(x))return xif __name__ == '__main__':# 假设分词64个embed与4个头n_embed, n_head = 64, 4dropout = 0.1# 假设4个专家与2个top_knum_experts, top_k = 4, 2# 假设数据是2个batch、6个分词,后面代码注释都已这些假设具体化inputs = torch.rand(2, 6, n_embed)M = Block(n_embed, n_head, num_experts, top_k)y = M(inputs)print(y.shape)

http://www.ppmy.cn/server/96171.html

相关文章

fastjson-小于1.2.47绕过

参考视频&#xff1a;fastjson反序列化漏洞3-<1.2.47绕过_哔哩哔哩_bilibili 分析版本 fastjson1.2.24 JDK 8u141 分析流程 分析fastjson1.2.25更新的源码&#xff0c;用JsonBcel链跟进 先看修改的地方 fastjson1.2.24 if (key JSON.DEFAULT_TYPE_KEY && !…

网络安全之sql靶场(11-23)

sql靶场&#xff08;11-23&#xff09; 目录 第十一关&#xff08;post注入&#xff09; 第十二关 第十三关 第十四关 第十五关 第十六关 第十七关 第十八关 第十九关 第二十关 第二十一关 第二十二关 第二十三关 第十一关&#xff08;post注入&#xff09; 查看…

flink 1.17 测试

1、下面配置配置错误的话会导致flink任务无法连接resourcemanger job会提交失败&#xff0c;报错内容&#xff1a; flink on yarn Connecting to ResourceManager at /0.0.0.0:8030 解决方案&#xff1a; <property> <name>yarn.application.classpath&…

搭建PXE实现服务器自动部署

PXE&#xff08;Preboot Execution Environment&#xff09;是一种计算机启动技术&#xff0c;它允许计算机从网络上的服务器而不是从本地硬盘或光盘等存储介质上启动。这种技术主要应用在无盘工作站、网络安装操作系统、远程维护等方面。 环境&#xff1a;一台rhel7.9作为PXE…

简单的docker学习 第8章 docker常用服务安装

第8章 常用服务安装 本章主要学习最常用的&#xff0c;也是安装起来稍有些麻烦的 MySQL 与 Redis 两种服务器的Docker 安装。至于其它服务器的 Docker 安装&#xff0c;大家可自行查找资料。只要 MySQL 与 Redis这两类服务器学会了安装&#xff0c;其它服务器的安装基本也不会…

datawind可视化查询-计数count(xxx)函数

飞书官方文档:https://www.volcengine.com/docs/4726/47275 我用到的场景:统计某个埋点的数量 格式:count(xxx),即对 xxx 计数 示例: 字段A 1 1 3 4 计算count(字段A),得到聚合结果 4。 若想去重计数,可使用count(distinct 字段A),则得到结果 3。 功能详解 函数名…

pxe+kickstart自动化安装

目录 一&#xff1a;实验环境 一台红帽7主机 开启主机图形 init 5 开图形 配置网络可用 关闭vmware dhcp功能 安装httpd服务 1、安装可视化图形&#xff1a; 2、关闭vmware dhcp功能&#xff1a; 3、安装httpd服务 安装httpd 开启httpd 二&#xff1a;实验过程 …

从零开始写一个微信小程序

从零开始写一个微信小程序可以分为几个步骤。以下是一个详细的指南,帮助你从头到尾完成一个简单的微信小程序。 ### 一、准备工作 1. **注册微信小程序账号**: - 前往[微信公众平台](https://mp.weixin.qq.com/)注册一个小程序账号。 - 进行企业认证(个人账号需要申…