表格数据深度学习算法 NODE 技术解析
深度学习在表格数据中的挑战
深度学习已在计算机视觉、自然语言处理等领域引发革命,但表格数据领域仍由经典机器学习算法(如梯度提升)主导。直觉上,神经网络作为通用近似器,理论上应能处理表格数据,但实际效果不及梯度提升树。这可能与决策树的归纳偏置更适合表格数据有关。
可微分决策树的突破
2015 年,Kontschieder 等人提出深度神经决策森林,通过将决策节点的严格二元路由松弛为概率化(使用 Sigmoid 函数),实现了决策树的可微分性。具体而言:
叶节点:替换为 Softmax 层,输出类别分布。
决策节点:使用 Sigmoid 函数计算样本向左/右路由的概率,通过路径概率乘积得到叶节点到达概率,最终预测为所有叶节点的加权平均。
神经遗忘决策树(NODE)
NODE 基于对称生长的遗忘树(Oblivious Tree),每层使用相同特征进行分裂。其核心创新包括:
特征选择:采用
α-entmax
替换 Softmax,实现稀疏特征选择(学习矩阵F
)。阈值松弛:将不可微的 Heaviside 函数替换为可微的
双面α-entmax
,并引入可学习的尺度参数b
。响应张量:通过外积生成路径选择权重,最终输出为响应张量的加权和。
深度 NODE 架构
通过堆叠多个 NODE 层(带残差连接)构建深度模型:
每层输入为前一层的输出与原始特征的拼接。
最终预测为各层输出的平均。
实验与结果
在 Epsilon、Higgs 等 6 个数据集上,NODE 与 CatBoost、XGBoost 和全连接神经网络对比:
默认参数:NODE(单层 2048 棵树,深度 6)表现优于传统方法。
调参后:NODE 在多数任务中保持领先。
实现与工具
官方实现:基于 PyTorch 的模块化代码库。
集成库:支持在 PyTorch Tabular 中直接调用 NODE 及其他表格数据算法。
参考文献
Kontschieder et al., Deep Neural Decision Forests (ICCV 2015).
Peters et al., Sparse Sequence-to-Sequence Models (ACL 2019).
Popov et al., Neural Oblivious Decision Ensembles (arXiv:1909.06312).更多精彩内容 请关注我的个人公众号 公众号(办公 AI 智能小助手)公众号二维码
- 办公AI智能小助手
评论