【AAAI 2024】解锁深度表格学习(Deep Tabular Learning)的关键:算术特征交互
近日,阿里云人工智能平台 PAI 与浙江大学吴健、应豪超老师团队合作论文《Arithmetic Feature Interaction is Necessary for Deep Tabular Learning》正式在国际人工智能顶会 AAAI-2024 上发表。本项工作聚焦于深度表格学习中的一个核心问题:在处理结构化表格数据(tabular data)时,深度模型是否拥有有效的归纳偏差(inductive bias)。我们提出算术特征交互(arithmetic feature interaction)对深度表格学习是至关重要的假设,并通过创建合成数据集以及设计实现一种支持上述交互的 AMFormer 架构(一种修改的 Transformer 架构)来验证这一假设。实验结果表明,AMFormer 在合成数据集表现出显著更优的细粒度表格数据建模、训练样本效率和泛化能力,并在真实数据的对比上超过一众基准方法,成为深度表格学习新的 SOTA(state-of-the-art)模型。
背景
结构化表格数据——这些数据往往以表(Table)的形式存储于数据库或数仓中——作为一种在金融、市场营销、医学科学和推荐系统等多个领域广泛使用的重要数据格式,其分析一直是机器学习研究的热点。表格数据(图 1)通常同时包含数值型(numerical)特征和类目型(categorical)特征,并往往伴随有特征缺失、噪声、类别不平衡(class imblanance)等数据质量问题,且缺少时序性、局部性等有效的先验归纳偏差,极大地带来了分析上的挑战。传统的树集成模型(如,XGBoost、LightGBM、CatBoost)因在处理数据质量问题上的鲁棒性,依然是工业界实际建模的主流选择,但其效果很大程度依赖于特征工程产出的原始特征质量。
随着深度学习的流行,研究者试图引入深度学习端到端建模,从而减少在处理表格数据时对特征工程的依赖。相关的研究工作至少可以可以分成四大类:(1)在传统建模方法中叠加深度学习模块(通常是多层感知机 MLP),如 Wide&Deep、DeepFMs;(2)形状函数(shape function)采用深度学习建模的广义加性模型(generalized additive model),如 NAM、NBM、SIAN;(3)树结构启发的深度模型,如 NODE、Net-DNF;(4)基于 Transformer 架构的模型,如 AutoInt、DCAP、FT-Transformer。尽管如此,深度学习在表格数据上相比树模型的提升并不显著且持续,其有效性仍然存在疑问,表格数据因此被视为深度学习尚未征服的最后堡垒。
算术特征交互在深度表格学习的“必要性”
我们认为现有的深度表格学习方法效果不尽如人意的关键症结在于没有找到有效的建模归纳偏差,并进一步提出算术特征交互对深度表格学习是至关重要的假设。本节介绍我们通过创建一个合成数据集,并对比引入算数特征交互前后的模型效果,来验证该假设。
算法架构
本节介绍 AMFormer 架构(图 3),并重点介绍算数特征交互的引入。AMFormer 架构借鉴了经典的 Transformer 框架,并引入了 Arithmetic Block 来增强模型的算术特征交互能力。在 AMFormer 中,我们首先将原始特征转换为具有代表性的嵌入向量,对于数值特征,我们使用一个 1 输入 d 输出的线性层;对于类别特征,则使用一个 d 维的嵌入查询表。之后,这些初始嵌入通过 L 个顺序层进行处理,这些层增强了嵌入向量中的上下文和交互元素。每一层中的算术模块采用了并行的加法和乘法注意力机制,以刻意促进算术特征之间的交互。为了促进梯度流动和增强特征表示,我们保留了残差连接和前馈网络。最终,依据这些丰富的嵌入向量,AMFormer 使用分类或回归头部生成最终输出。
算术模块的关键组件包括并行注意力机制和提示标记。为了补偿需要算术特征交互的特征,我们在 AMFormer 中配置了并行注意力机制,这些机制负责提取有意义的加法和乘法交互候选者。这些交互候选随着会沿着候选维度被串联(concatenate)起来,并通过一个下采样的线性层进行融合,使得 AMFormer 的每一层都能有效捕捉算术特征交互,即特征上的四则算法运算。为了防止由特征冗余引起的过拟合并提升模型在超大规模特征数据集上的伸缩,我们放弃了原始 Transformer 架构中平方复杂度的自注意力机制,而是使用两组提示向量(prompt token vectors)作为加法和乘法查询。这种方法为 AMFormer 提供了有限的特征交互自由度,并且作为一个附带效果,优化了内存占用和训练效率。
以上是 AMFormer 在架构层引入的主要创新,关于模型更详细的实现细节可以参考原文以及我们的开源实现。
进一步实验结果
为了进一步展示 AMFormer 的效果,我们挑选了四个真实数据集进行实验。被挑选数据集覆盖了二分类、多分类以及回归任务,数据集统计如表 1 所示。
我们一共测试了包含传统树模型(XGBoost)、树架构深度学习方法(NODE)、高阶特征交互(DCN-V2、DCAP)以及 Transformer 派生架构(AutoInt、FT-Trans)在内的六个基准算法以及两个 AMFormer 实现(分别选择 AutoInt、FT-Trans 做基础架构,即 AMF-A 和 AMF-F),结果汇总在表 2 中。
在一系列对比实验中,AMFormer 表现更突出。结果显示,基于 MLP 的深度学习方法如 DCN-V2 在表格数据上的性能不尽如人意,而基于 Transformer 架构的模型显示出更大的潜力,但未能始终超过树模型 XGBoost。我们的 AMFormer 在四个不同的数据集上,与所有六个基准模型相比,表现一致更优:在分类任务中,它将 AutoInt 和 FT-transformer 的准确率或 AUC 提升至少 0.5%,最高达到 1.23%(EP)和 4.96%(CO);在回归任务中,它也显著减少了平均平方误差。相比其它深度表格学习方法,AMFormer 具有更好的鲁棒和稳定性,这使得在性能排序中 AMFormer 断层式优于其它基准算法,这些实验结果充分证明了 AMFormer 在深度表格学习中的必要性和优越性。
结论
本工作研究了深度模型在表格数据上的有效归纳偏置。我们提出,算术特征交互对于表格深度学习是必要的,并将这一理念融入 Transformer 架构中,创建了 AMFormer。我们在合成数据和真实世界数据上验证了 AMFormer 的有效性。合成数据的结果展示了其在精细表格数据建模、训练数据效率以及泛化方面的优越能力。此外,对真实世界数据的广泛实验进一步确认了其一致的有效性。因此,我们相信 AMFormer 为深度表格学习设定了强有力的归纳偏置。
进一步阅读:
● 论文标题:
Arithmetic Feature Interaction is Necessary for Deep Tabular Learning
● 论文作者:
程奕、胡仁君、应豪超、施兴、吴健、林伟
● 论文 PDF 链接:
https://arxiv.org/abs/2402.02334
● 代码链接:
评论