深度学习 - 45.MMOE Gate 简单实现 By Keras

news/2025/2/5 15:11:33/

目录

一.引言

二.MMoE 模型分析

三.MMoE 逻辑实现

• Input

• Expert Output

• Gate Output

• Weighted Sum

• Sigmoid Output

• 完整代码

四.总结


一.引言

上一篇文章介绍了 MMoE 借鉴 MoE 的思路,为每一类输出构建一个 Gate 并最终加权多个 Expert 的输出,提高了相关性不高的多任务问题,下面根据思路简易实现下 MMoE 逻辑。

 

二.MMoE 模型分析

上图为 BaseLine、MoE 与 MMoE,下面我们专门了解 MMoE 的实现细节,其中 bs 为 BatchSize,Hidden_units 代表隐层输出维度,E0、E1、E2 代表多个 Expert,Tower A、Tower B 代表 N 个输出。

这里 Gate 与 Expert 都可以理解为浅层模型。 下面针对 bs 里的一个样本,分析下 MMoE 执行过程。

• Input 

输入一个 embedding_size 长度的向量,这里也可以是 Filed 个变量,进入对应 Embedding 层 Lookup 再做 pooling,为了简单,这里直接取 1 x F,F 为 Field Size。

• Expert Output

输入向量分别经过 E1、E2、... 的 K 个 Expert 计算,得到 Expert 个输出,每个输出维度为 1 x Hidden。

• Gate output

输入向量分别经过 G1、G2、... 的 K 个 Gate 计算,得到 Expert 个输出,维度为 1 x expert,每一维代表对应 Expert Output 的权重,注意这里 Gate 最终输出需经过一层 softmax。

• Weighted Sum

将当前样本输出的 expert 分别与对应 Expert Output 加权求和作为每个任务的 Tower 的输入向量,合并后维度为 1 x hidden_units。

• Sigmoid Output

每个 Tower 最终为 sigmoid 输出的二分类深度模型,输入为 Expert x Hidden 的加权求和,维度仍然为 1 x Hidden。

Tips:

实际 BatchSize 执行时,将上述 1 x ... 换为 bs x ... 即对应 Batch 的逻辑。

三.MMoE 逻辑实现

参数设置为:

    num_field = 4hidden_units = 8num_expert = 3num_output = 2

4 个 Field 域、8 维输出、3 个 Expert、2 个任务。

• Input

这里 N 就是上面的 bs 即 Batch Size,F 为 Field Size,这里实现逻辑比较简单,实践场景下可以先将 Field Lookup 得到 Embedding 再 pooling 输入到后面的 Expert 和 Gate:

    # 1.构造 Input => N x Fnum_samples = 10inputs = np.array([np.ones(shape=num_field) for i in range(num_samples)])print("Input Shape:", inputs.shape)
Input Shape: (10, 4)

• Expert Output

    # 2.构建专家 kernel [(filed * hidden) x expert]expert_kernels = np.random.random(size=(num_field, hidden_units, num_expert))print("Expert Shape:", expert_kernels.shape)# 3.获取 Expert 输出 => [N x F] * [F x Hidden x Expert] = N x Hidden x Expertoutputs_by_expert = tf.tensordot(inputs, expert_kernels, axes=1)print("Output By Expert:", outputs_by_expert.shape)
Expert Shape: (4, 8, 3)
Output By Expert: (10, 8, 3)

• Gate Output

Gate:在 Dense 输出基础上,增加 softmax 逻辑。

class Gate(Layer):def __init__(self, expert_num, **kwargs):self.expert_num = expert_numself.gate = Nonesuper(Gate, self).__init__(**kwargs)def build(self, input_shape):self.gate = Dense(self.expert_num, activation='relu', kernel_initializer=glorot_normal_initializer)super(Gate, self).build(input_shape)def call(self, _inputs, **kwargs):weight = self.gate(_inputs)_output = tf.nn.softmax(weight)return _output

 根据 num_output 任务输出数,决定构造 Gate 的数量,这是 MMoE 与 MoE 的区别之一。

    # 4.构建 Gategate = [Gate(num_expert) for i in range(num_output)]# 5.获取每个任务的 Gate 输出权重 output x N x expertoutputs_by_gate = np.array([gate[i](inputs) for i in range(num_output)])print("Output By Gate:", outputs_by_gate.shape)
Output By Gate: (2, 10, 3)

 

• Weighted Sum

    # 6.获取最终输出 N x outputpart_input = []for output_by_gate in outputs_by_gate:# N x Expert => N x 1 x Expertexpand_output_by_gate = tf.expand_dims(output_by_gate, axis=1)# N x 1 x Expert => N x Hidden x Expertrepeat_gate_weight = K.repeat_elements(expand_output_by_gate, hidden_units, axis=1)# N x Hidden x Expertweighted_expert_output = tf.cast(outputs_by_expert, dtype='float32') * repeat_gate_weight# N x Hiddenweighted_expert_sum = tf.reduce_sum(weighted_expert_output, axis=2)part_input.append(weighted_expert_sum)print("Part Input:", np.array(part_input).shape)

原始输入样本个数 N=10,输出 hidden_size=8,任务有2个,所以每个任务获得 BS x hidden_size 即 10 x 8 的 batch 样本。 

Part Input: (2, 10, 8)

 

• Sigmoid Output

这里两个任务对应两个 Tower,之前介绍了多输出模型:TF x Keras 之多输出模型

任务架构基于 Shared-bottom Multi-task Model,实现了同时预测年龄、收入、性别的多分类问题,有兴趣的同学可以把 Shared-bottom 的架构切换为多个 Expert 再加入 Gate 即可实现基础的 MMoE,这里就不再展开了。

 

• 完整代码

import numpy as np
import tensorflow as tf
from tensorflow.python.keras.layers import *
from tensorflow.keras.layers import Layer
from tensorflow.python.ops.init_ops import glorot_normal_initializer
from tensorflow.keras import backend as Kclass Gate(Layer):def __init__(self, expert_num, **kwargs):self.expert_num = expert_numself.gate = Nonesuper(Gate, self).__init__(**kwargs)def build(self, input_shape):self.gate = Dense(self.expert_num, activation='relu', kernel_initializer=glorot_normal_initializer)super(Gate, self).build(input_shape)def call(self, _inputs, **kwargs):weight = self.gate(_inputs)_output = tf.nn.softmax(weight)return _outputdef MMOE(num_field, hidden_units, num_expert, num_output):# 1.构造 Input => N x Fnum_samples = 10inputs = np.array([np.ones(shape=num_field) for i in range(num_samples)])print("Input Shape:", inputs.shape)# 2.构建专家 kernel [(filed * hidden) x expert]expert_kernels = np.random.random(size=(num_field, hidden_units, num_expert))print("Expert Shape:", expert_kernels.shape)# 3.获取 Expert 输出 => [N x F] * [F x Hidden x Expert] = N x Hidden x Expertoutputs_by_expert = tf.tensordot(inputs, expert_kernels, axes=1)print("Output By Expert:", outputs_by_expert.shape)# 4.构建 Gategate = [Gate(num_expert) for i in range(num_output)]# 5.获取每个任务的 Gate 输出权重 output x N x expertoutputs_by_gate = np.array([gate[i](inputs) for i in range(num_output)])print("Output By Gate:", outputs_by_gate.shape)# 6.获取最终输出 N x outputpart_input = []for output_by_gate in outputs_by_gate:# N x Expert => N x 1 x Expertexpand_output_by_gate = tf.expand_dims(output_by_gate, axis=1)# N x 1 x Expert => N x Hidden x Expertrepeat_gate_weight = K.repeat_elements(expand_output_by_gate, hidden_units, axis=1)# N x Hidden x Expertweighted_expert_output = tf.cast(outputs_by_expert, dtype='float32') * repeat_gate_weight# N x Hiddenweighted_expert_sum = tf.reduce_sum(weighted_expert_output, axis=2)part_input.append(weighted_expert_sum)print("Part Input:", np.array(part_input).shape)if __name__ == '__main__':num_field, hidden_units, num_expert, num_output = 4, 8, 3, 2MMOE(num_field, hidden_units, num_expert, num_output)

四.总结

MMoE 几个 Expert 最终输出维度相同,Gate 输出维度与 Expert 数量相同,通过分析 Gate 的输出概率可以看出不同 Expert 对不同 Output 的测出,也可以控制 loss_weights 显式的指定某个 Output 占据主导地位。


http://www.ppmy.cn/news/52043.html

相关文章

alibaba arthas的新人上手教程

背景 Arthas 是Alibaba开源的Java诊断工具。 github开源地址:GitHub - alibaba/arthas: Alibaba Java Diagnostic Tool Arthas/Alibaba Java诊断利器Arthas 上手教程 1.下载arthas,并测试运行demo curl -O https://arthas.aliyun.com/arthas-boot.j…

Docker的三种网络模式

Docker的三种网络模式 Docker支持三种网络模式:Host模式、Bridge模式和None模式。它们各自适用于不同的场景和需求: Host模式:将容器加入到主机的网络栈中,使容器直接使用主机的网络接口和IP地址。Host模式适用于需要容器与主机…

GLPT团队程序设计天梯赛 2023正式赛

2023.4.22 13&#xff1a;30-16&#xff1a;30 162分 团队1556分 L1-1 最好的文档 5 15990/21484(74.43%) 在一行中输出 Good code is its own best documentation.。 #include<bits/stdc.h> using namespace std; signed main(){cout<<"Good code is its …

国考省考行测:词句理解,词的对象指代,就近原则,主语一致法,语意语境分析上下文找出指代含义

国考省考行测&#xff1a;词句理解&#xff0c;词的对象指代&#xff0c;就近原则&#xff0c;主语一致法&#xff0c;语意语境分析上下文找出指代含义 2022找工作是学历、能力和运气的超强结合体! 公务员特招重点就是专业技能&#xff0c;附带行测和申论&#xff0c;而常规国…

接口测试入门必会知识总结(学习笔记)

目录 什么是接口&#xff1f; 内部接口 外部接口 接口的本质 什么是接口测试&#xff1f; 反向测试 为什么说接口测试如此重要&#xff1f; 越接近底层的 Bug&#xff0c;影响用户范围越广 目前流行的测试模型 接口测试的优越性 不同协议形式的测试 接口测试工作场景…

6个好用的企业管理软件推荐

企业管理软件的范围很广&#xff0c;财务、人力、客户关系管理、ERP、客户体验管理等等。国内来看&#xff0c;有些企业管理软件产品能覆盖企业数字化所有部分&#xff0c;在每个领域&#xff0c;也有很突出的头部厂商&#xff0c;产品功能和服务都大幅领先于竞对&#xff0c;我…

大数据题目测试(一)

目录 一、环境要求 二、提交结果要求 三、数据描述 四、功能要求 1.数据准备 2.使用 Spark&#xff0c;加载 HDFS 文件系统 meituan_waimai_meishi.csv 文件&#xff0c;并分别使用 RDD和 Spark SQL 完成以下分析&#xff08;不用考虑数据去重&#xff09;。 (1)配置环境…

[k8s] 八股文

最近在基于k8s做集群管理系统&#xff0c;心里有一些问题&#xff0c;在这里记录一下。 k8s中的服务发现&#xff1f; 这一点项目中有用到。任务模块通过gprc调用用户模块的服务时&#xff0c;使用了svc-user-v0.shuwd:8002&#xff0c;k8s是怎么将该流量转发到对应的pod上的…