写点什么

表格数据深度学习算法 NODE 技术解析

作者:qife
  • 2025-08-08
    福建
  • 本文字数:791 字

    阅读完需:约 3 分钟

深度学习在表格数据中的挑战

深度学习已在计算机视觉、自然语言处理等领域引发革命,但表格数据领域仍由经典机器学习算法(如梯度提升)主导。直觉上,神经网络作为通用近似器,理论上应能处理表格数据,但实际效果不及梯度提升树。这可能与决策树的归纳偏置更适合表格数据有关。

可微分决策树的突破

2015 年,Kontschieder 等人提出深度神经决策森林,通过将决策节点的严格二元路由松弛为概率化(使用 Sigmoid 函数),实现了决策树的可微分性。具体而言:


  • 叶节点:替换为 Softmax 层,输出类别分布。

  • 决策节点:使用 Sigmoid 函数计算样本向左/右路由的概率,通过路径概率乘积得到叶节点到达概率,最终预测为所有叶节点的加权平均。

神经遗忘决策树(NODE)

NODE 基于对称生长的遗忘树(Oblivious Tree),每层使用相同特征进行分裂。其核心创新包括:


  1. 特征选择:采用α-entmax替换 Softmax,实现稀疏特征选择(学习矩阵F)。

  2. 阈值松弛:将不可微的 Heaviside 函数替换为可微的双面α-entmax,并引入可学习的尺度参数b

  3. 响应张量:通过外积生成路径选择权重,最终输出为响应张量的加权和。

深度 NODE 架构

通过堆叠多个 NODE 层(带残差连接)构建深度模型:


  • 每层输入为前一层的输出与原始特征的拼接。

  • 最终预测为各层输出的平均。

实验与结果

在 Epsilon、Higgs 等 6 个数据集上,NODE 与 CatBoost、XGBoost 和全连接神经网络对比:


  • 默认参数:NODE(单层 2048 棵树,深度 6)表现优于传统方法。

  • 调参后:NODE 在多数任务中保持领先。

实现与工具

  • 官方实现:基于 PyTorch 的模块化代码库。

  • 集成库:支持在 PyTorch Tabular 中直接调用 NODE 及其他表格数据算法。

参考文献

  1. Kontschieder et al., Deep Neural Decision Forests (ICCV 2015).

  2. Peters et al., Sparse Sequence-to-Sequence Models (ACL 2019).

  3. Popov et al., Neural Oblivious Decision Ensembles (arXiv:1909.06312).更多精彩内容 请关注我的个人公众号 公众号(办公 AI 智能小助手)公众号二维码

  4. 办公AI智能小助手
用户头像

qife

关注

还未添加个人签名 2021-05-19 加入

还未添加个人简介

评论

发布
暂无评论
表格数据深度学习算法NODE技术解析_机器学习_qife_InfoQ写作社区