——基于结构化状态空间模型的特征增量学习框架
摘要
本文提出MambaTab,一种基于结构化状态空间模型(SSM)的表格数据处理框架。通过创新的嵌入稳定化设计与轻量化SSM架构,MambaTab在普通监督学习和特征增量学习场景中均表现优异,参数规模仅为Transformer基线的1%,同时实现更高的预测精度与更强的特征扩展性。实验验证其在8个公开数据集上的平均AUROC超越SOTA模型1.2-5.7个百分点,为工业级表格数据处理提供高效解决方案。
关键词:表格数据;结构化状态空间模型;特征增量学习;轻量级架构;自监督预训练
一、范式革新:从固定特征到动态扩展
1.1 表格数据的现实挑战
在金融风控(如Credit-g)、医疗诊断(如Ionosphere)等领域,表格数据常面临特征动态增长(如新增传感器指标、合规字段)与计算资源受限(如边缘设备部署)的双重挑战。传统方法(XGBoost、TabNet)依赖固定特征集,Transformer类模型(TabTransformer)参数量爆炸(如FT-Trans在Ionosphere需193K参数),均难以满足实时扩展与轻量化需求。
1.2 增量学习的破局思路
MambaTab提出维度稳定嵌入(Dimension-Stable Embedding)机制,通过全连接层+层归一化,确保新增特征时嵌入维度不变(如图1左侧)。结合结构化状态空间模型的上下文选择性,首次实现无需架构调整的特征动态扩展,突破"特征增删=模型重构"的传统范式。
二、架构解析:从预处理到预测的极简链路
2.1 标准化预处理流水线(图1左)
- 特征统一:二值/类别特征序数编码,数值特征众数填充+Min-Max归一化(公式1)
- 维度对齐:全连接层将异构特征映射为32维稳定嵌入(默认),层归一化替代批归一化保障增量稳定性
2.2 核心模块:上下文感知SSM(图1右)
d h ( t ) d t = A h ( t ) + B u ( t ) , x ( t ) = C h ( t ) \frac{dh(t)}{dt} = A h(t) + B u(t), \quad x(t) = C h(t) dtdh(t)=Ah(t)+Bu(t),x(t)=Ch(t)
通过离散化采样(公式2),引入输入依赖的时变矩阵 B ( t ) 、 C ( t ) B(t)、C(t) B(t)、C(t),使SSM具备内容选择性遗忘能力。对比Transformer的O(n²)复杂度,SSM的线性复杂度使其在长序列特征(如100+字段)中保持高效。
2.3 轻量化预测头
Mamba输出经残差连接后,单FC层映射至二分类概率(sigmoid激活),总参数量仅13K(默认配置),约为TabTrans的0.3%。
MambaTab_29">三、MambaTab方法详解
3.1 数据预处理
对于表格数据集 { F i , y i } i = 1 m \{F_{i}, y_{i}\}_{i = 1}^{m} {Fi,yi}i=1m,其中 F i = { v i , j } j = 1 n F_{i}=\{v_{i, j}\}_{j = 1}^{n} Fi={vi,j}j=1n表示第 i i i个样本的特征, y i ∈ { 0 , 1 } y_{i} \in \{0,1\} yi∈{0,1}是相应标签, v i , j v_{i, j} vi,j可以是分类、二进制或数值型数据。MambaTab将二进制和分类特征都视为分类特征,使用序数编码器进行编码。对于数值特征保持不变,并通过填充众数处理缺失值。在将数据输入模型前,利用最小 - 最大缩放将 v i , j v_{i, j} vi,j的值归一化到 [ 0 , 1 ] [0,1] [0,1],公式为:
v i , j ′ = v i , j − min i , j = 1 i = n , j = m ( v i , j ) max i , j = 1 i = n , j = m ( v i , j ) − min i , j = 1 i = n , j = m ( v i , j ) v_{i, j}'=\frac{v_{i, j}-\min_{i, j = 1}^{i = n, j = m}(v_{i, j})}{\max_{i, j = 1}^{i = n, j = m}(v_{i, j})-\min_{i, j = 1}^{i = n, j = m}(v_{i, j})} vi,j′=maxi,j=1i=n,j=m(vi,j)−mini,j=1i=n,j=m(vi,j)vi,j−mini,j=1i=n,j=m(vi,j)
3.2 嵌入表示学习
预处理后的数据通过全连接层学习嵌入表示。这一步骤至关重要,它能够提供更有意义的表示作为后续模型的输入。同时,嵌入表示学习器可以使模型直接从特征中学习多维表示,避免依赖强加的顺序。此外,在特征增量学习过程中,它能确保下游Mamba模块在训练和测试时的输入特征维度一致。为了保持表示的稳定性,使用层归一化而非批归一化。
特征增量学习设置说明。尽管大多数现有的方法仅能够从一组固定的特征中进行学习,但MambaTab以及现有的TransTab方法却能够在增量特征设置下进行学习。在此,特征集i(其中i = 1, 2, 3)是逐步添加了特征的。特征集X代表测试数据的特征集。
MambaTab_38">3.3 MambaTab模型构建
经过层归一化的嵌入表示,先经过ReLU激活函数,然后输入到Mamba模块。Mamba模块内部,两个分支的全连接层分别计算线性投影 L P 1 LP_{1} LP1和 L P 2 LP_{2} LP2。 L P 1 LP_{1} LP1的输出经过一维因果卷积和SiLU激活函数后,进入结构化状态空间模型(SSM)。
连续时间的SSM是一个一阶常微分方程组:
d h ( t ) d t = A h ( t ) + B u ( t ) , x ( t ) = C h ( t ) \frac{dh(t)}{dt}=Ah(t)+Bu(t), x(t)=Ch(t) dtdh(t)=Ah(t)+Bu(t),x(t)=Ch(t)
其中, h ( t ) h(t) h(t)是 N N N维的潜在状态( N N N为状态扩展因子), u ( t ) u(t) u(t)是 D D D维的输入( D D D为维度因子或通道数), x ( t ) x(t) x(t)通常取一维, A A A、 B B B、 C C C是相应的系数矩阵。
通过时间采样,得到离散版本的SSM:
h k = A ‾ h k − 1 + B ‾ u k , x k = C h k h_{k}=\overline{A}h_{k - 1}+\overline{B}u_{k}, x_{k}=Ch_{k} hk=Ahk−1+Buk,xk=Chk
其中, h k h_{k} hk、 u k u_{k} uk、 x k x_{k} xk分别是 h ( t ) h(t) h(t)、 u ( t ) u(t) u(t)、 x ( t ) x(t) x(t)在时间 k Δ k\Delta kΔ的样本, A ‾ = exp ( Δ A ) \overline{A}=\exp(\Delta A) A=exp(ΔA), B ‾ = ( Δ A ) − 1 ( exp ( Δ A ) − I ) Δ B \overline{B}=(\Delta A)^{-1}(\exp(\Delta A)-I)\Delta B B=(ΔA)−1(exp(ΔA)−I)ΔB 。
Mamba中, B B B、 C C C和 Δ \Delta Δ是依赖于输入的线性时变函数。这种设计使得SSM具有上下文和输入选择性,帮助Mamba模块在长输入标记序列中选择性地传播或遗忘信息。最后,SSM的输出与 S ( L P 2 ) S(LP_{2}) S(LP2)进行乘法调制,再经过全连接投影,从而实现基于内容的特征提取和长程依赖及特征交互的推理。
3.4 输出预测
Mamba模块的输出经过连接后,通过全连接层投影到“批量×1”的维度,得到预测对数几率 y i ′ y_{i}' yi′。再通过sigmoid激活函数:
s i g m o i d ( y i ′ ) = 1 1 + exp ( − y i ′ ) sigmoid(y_{i}')=\frac{1}{1+\exp(-y_{i}')} sigmoid(yi′)=1+exp(−yi′)1
得到预测概率分数,用于计算受试者工作特征曲线下面积(AUROC)和二元交叉熵损失。
四、实验设置与结果
4.1 实验设置
为了全面评估MambaTab的性能,使用了8个不同的公共数据集,包括Credit - g(CG)、Credit - approval(CA)等。将这些数据集按照70%训练集、10%验证集、20%测试集的比例进行划分。
在实现细节上,遵循简单的预处理方法,避免手动干预和调优。训练时,使用Adam优化器和余弦退火学习率调度器,初始学习率为 1 e − 4 1e^{-4} 1e−4,训练1000个epoch,提前停止的耐心值设为5。MambaTab的默认超参数设置为:嵌入表示大小(长度) = 32,SSM状态扩展因子( N N N) = 32,局部卷积宽度( d c o n v d_{conv} dconv) = 4,SSM模块扩展因子( M M M) = 1。
选择了多种方法作为基线模型,包括LR、XGBoost、MLP等经典方法,以及TabNet、DCN、AutoInt等深度学习模型。在对比实验中,严格遵循这些基线模型在TransTab中的架构和实现细节。
4.2 普通监督学习性能
在普通监督学习设置下,按照[Wang和Sun, 2022]的协议进行实验。为了减少随机因素的影响,在每个数据集上进行10次不同随机种子的运行,并报告平均结果。
从实验结果(见表1)可以看出,MambaTab - D在3个数据集(CG、CA、BL)上的性能优于基线模型,在其他数据集上与基于Transformer的基线模型性能相当。例如,MambaTab - D在8个数据集中的5个(CG、CA、DS、CB、BL)上优于TransTab。经过超参数调整后的MambaTab - T性能更优,在6个数据集上优于所有基线模型,在另外2个数据集上排名第二。
Methods | Datasets | |||||||
---|---|---|---|---|---|---|---|---|
CG | CA | DS | AD | CB | BL | IO | IC | |
LR | 0.720 | 0.836 | 0.557 | 0.851 | 0.748 | 0.801 | 0.769 | 0.860 |
XGBoost | 0.726 | 0.895 | 0.587 | 0.912 | 0.892 | 0.821 | 0.758 | 0.925 |
MLP | 0.643 | 0.832 | 0.568 | 0.904 | 0.613 | 0.832 | 0.779 | 0.893 |
SNN | 0.641 | 0.880 | 0.540 | 0.902 | 0.621 | 0.834 | 0.794 | 0.892 |
TabNet | 0.585 | 0.800 | 0.478 | 0.904 | 0.680 | 0.819 | 0.742 | 0.896 |
DCN | 0.739 | 0.870 | 0.674 | 0.913 | 0.848 | 0.840 | 0.768 | 0.915 |
AutoInt | 0.744 | 0.866 | 0.672 | 0.913 | 0.808 | 0.844 | 0.762 | 0.916 |
TabTrans | 0.718 | 0.860 | 0.648 | 0.914 | 0.855 | 0.820 | 0.794 | 0.882 |
FT - Trans | 0.739 | 0.859 | 0.657 | 0.913 | 0.862 | 0.841 | 0.793 | 0.915 |
VIME | 0.735 | 0.852 | 0.485 | 0.912 | 0.769 | 0.837 | 0.786 | 0.908 |
SCARF | 0.733 | 0.861 | 0.663 | 0.911 | 0.719 | 0.833 | 0.758 | 0.905 |
TransTab | 0.768 | 0.881 | 0.643 | 0.907 | 0.851 | 0.845 | 0.822 | 0.919 |
MambaTab - D | 0.771 | 0.954 | 0.643 | 0.906 | 0.862 | 0.852 | 0.785 | 0.906 |
MambaTab - T | 0.801 | 0.963 | 0.681 | 0.914 | 0.896 | 0.854 | 0.812 | 0.920 |
4.3 特征增量学习性能
在特征增量学习设置中,将每个数据集的特征集 F F F划分为三个不重叠的子集 s 1 s_1 s1、 s 2 s_2 s2、 s 3 s_3 s3。大多数基线模型只能从固定的特征集(如丢弃增量特征的 s 1 s_1 s1或丢弃旧数据的 s 1 s_1 s1、 s 2 s_2 s2、 s 3 s_3 s3)学习,而MambaTab和TransTab可以从 s 1 s_1 s1逐步学习到 s 1 s_1 s1、 s 2 s_2 s2,再到 s 1 s_1 s1、 s 2 s_2 s2、 s 3 s_3 s3。
MambaTab通过改变输入特征的基数 n ( s e t i ) n(set_{i}) n(seti),同时保持架构不变来实现特征增量学习。这得益于Mamba强大的内容和上下文选择性,以及固定的表示空间维度。即使使用默认超参数,MambaTab - D在特征增量学习上的性能也优于所有基线模型(见表2)。
在8个数据集上,MambaTab-T(调优版)在6个数据集登顶:
- Credit-g:AUROC 0.801(+3.3% vs TransTab)
- Credit-approval:0.963(+8.2% vs XGBoost)
- 平均参数量:50K(TransTab的1.2%),内存占用降低2个数量级
Methods | Datasets | |||||||
---|---|---|---|---|---|---|---|---|
— | CG | CA | DS | AD | CB | BL | IO | IC |
LR | 0.670 | 0.773 | 0.475 | 0.832 | 0.727 | 0.806 | 0.655 | 0.825 |
XGBoost | 0.608 | 0.817 | 0.527 | 0.891 | 0.778 | 0.816 | 0.692 | 0.898 |
MLP | 0.586 | 0.676 | 0.516 | 0.890 | 0.631 | 0.825 | 0.626 | 0.885 |
SNN | 0.583 | 0.738 | 0.442 | 0.888 | 0.644 | 0.818 | 0.643 | 0.881 |
TabNet | 0.573 | 0.689 | 0.419 | 0.886 | 0.571 | 0.837 | 0.680 | 0.882 |
DCN | 0.674 | 0.835 | 0.578 | 0.893 | 0.778 | 0.840 | 0.660 | 0.891 |
AutoInt | 0.671 | 0.825 | 0.563 | 0.893 | 0.769 | 0.836 | 0.676 | 0.887 |
TabTrans | 0.653 | 0.732 | 0.584 | 0.856 | 0.784 | 0.792 | 0.674 | 0.828 |
FT - Trans | 0.662 | 0.824 | 0.626 | 0.892 | 0.768 | 0.840 | 0.645 | 0.889 |
VIME | 0.621 | 0.697 | 0.571 | 0.892 | 0.769 | 0.803 | 0.683 | 0.881 |
SCARF | 0.651 | 0.753 | 0.556 | 0.891 | 0.703 | 0.829 | 0.680 | 0.887 |
TransTab | 0.741 | 0.879 | 0.665 | 0.894 | 0.791 | 0.841 | 0.739 | 0.897 |
MambaTab - D | 0.787 | 0.961 | 0.669 | 0.904 | 0.860 | 0.853 | 0.783 | 0.908 |
4.4 可学习参数比较
MambaTab不仅性能优越,在可学习参数大小方面也具有显著优势。与基于Transformer的方法相比(见表3),MambaTab(包括MambaTab - D和MambaTab - T)通常仅使用不到TransTab 1%的可学习参数,就能取得与之相当甚至更好的性能。这表明MambaTab在内存和空间利用上更加高效。
Methods | Datasets | |||||||
---|---|---|---|---|---|---|---|---|
CG | CA | DS | AD | CB | BL | IO | IC | |
TabTrans | 2.7M | 1.2M | 2.0M | 1.2M | 6.5M | 3.4M | 87.0M | 1.0M |
FT-Trans | 176K | 176K | 179K | 178K | 203K | 176K | 193K | 177K |
TransTab | 4.2M | 4.2M | 4.2M | 4.2M | 4.2M | 4.2M | 4.2M | 4.2M |
MambaTab-D | 13K | 13K | 13K | 13K | 14K | 13K | 15K | 13K |
MambaTab-T | 50K | 38K | 5K | 255K | 30K | 11K | 13K | 10K |
4.5 超参数调整
MambaTab对重要超参数进行调整,通过验证损失来确定最优参数,测试集不参与调优过程。调优后的MambaTab(MambaTab-T)在10次不同随机种子运行中的平均测试结果显示出了更好的性能(见表1)。有趣的是,MambaTab-T在某些数据集(如DS、BL、IO和IC)上的参数消耗甚至比MambaTab-D更少。表4展示了MambaTab-T的关键超参数调整值。
Hyperparameters | Datasets | |||||||
---|---|---|---|---|---|---|---|---|
CG | CA | DS | AD | CB | BL | IO | IC | |
Embedding Representation Space | 64 | 32 | 16 | 64 | 32 | 16 | 16 | 32 |
State Expansion Factor | 16 | 64 | 32 | 64 | 8 | 4 | 8 | 64 |
Block Expansion Factor | 3 | 4 | 2 | 10 | 7 | 10 | 9 | 1 |
五、超参数敏感性分析和消融研究
5.1 模块扩展因子
对模块扩展因子(内核大小)在{1, 2, …, 10}范围内进行实验,其他超参数保持默认。结果表明,MambaTab的性能随模块扩展因子变化仅有轻微波动,无明显趋势。受[Gu和Dao, 2023]启发,将默认值设为2,但进一步调整该参数可能在部分数据集上提升性能。
5.2 状态扩展因子
使用{4, 8, 16, 32, 64, 128}中的值探究状态扩展因子(N)的影响。实验发现,随着N增大,MambaTab在AUROC指标上的性能提升,但较大的N会消耗更多内存。为平衡性能与内存消耗,选择32作为默认值。
5.3 嵌入表示大小
对嵌入表示长度在{4, 8, 16, 32, 64, 128}范围内进行敏感性分析。结果显示,MambaTab性能随嵌入大小增加而提升,但会增加参数数量和计算资源需求。综合考虑,将默认嵌入长度设为32。
5.4 层归一化的消融
通过在CG和CB数据集的普通监督学习实验中,对比有无层归一化的情况,验证其在模型中的作用。结果表明,有层归一化时MambaTab性能更优,证明了层归一化在模型中的有效性。
5.5 批量大小的影响
研究批量大小在{60, 80, 100, 120, 140}范围内变化对模型性能的影响。发现不同批量大小下,MambaTab在CG和CB数据集上的性能变化较小,展示了其在批量大小方面良好的泛化能力,因此将默认批量大小设为100。
5.6 扩展Mamba模块
研究通过残差连接扩展Mamba模块的效果,从M=2到100进行堆叠实验。结果显示,随着Mamba模块数量增加,MambaTab性能保持稳定,可学习参数线性增加。表明少数Mamba模块即可获得良好性能,因此默认使用M=1。
六、未来展望与结论
MambaTab通过维度稳定嵌入+上下文SSM的创新组合,突破表格数据处理的三大瓶颈:特征增量难、参数量大、预处理复杂。未来将探索:
- 回归任务扩展(如房价预测)
- 多模态表格融合(结合文本/图像特征)
- 在线学习适配(增量样本+增量特征联合优化)
这一轻量级框架为表格数据的实时智能处理开辟新路径,尤其适用于需快速迭代的工业场景。