Paper Reading: Adaptive Neural Trees |
您所在的位置:网站首页 › 树形结构模型图片 › Paper Reading: Adaptive Neural Trees |
目录研究动机文章贡献自适应神经树模型拓扑与操作概率模型与推理优化实验结果模型性能消融实验可解释性细化阶段的影响自适应模型复杂度优点和创新点
Paper Reading 是从个人角度进行的一些总结分享,受到个人关注点的侧重和实力所限,可能有理解不到位的地方。具体的细节还需要以原文的内容为准,博客中的图表若未另外说明则均来自原文。
论文概况
详细
标题
《Adaptive Neural Trees》
作者
Ryutaro Tanno, Kai Arulkumaran, Daniel C. Alexander, Antonio Criminisi, Aditya Nori
发表会议
International Conference on Machine Learning (ICML)
发表年份
2019
会议等级
CCF-A
论文代码
https://github.com/rtanno21609/AdaptiveNeuralTrees
作者单位: University College London, UK Imperial College London, UK Microsoft Research, Cambridge, UK 研究动机神经网络和决策树都是强大且实用的机器学习模型,但是两种方法通常具有相互排斥的优点和局限性。神经网络是通过非线性变换的组合来学习数据的分层表示,该方法对特征工程的需求很小。同时神经网络是用随机优化器训练的,允许训练扩展到大型数据集。但是神经网络的架构通常需要手工设计,并根据任务或数据集进行固定,且有些任务下需要巨大的计算开销。决策树的特点是学习如何分割输入空间,令每个子集中都能用线性模型来解决问题。决策树的架构是基于训练数据进行优化的,在数据稀缺的情况下特别有优势。但是应用决策树时通常需要手工设计的数据特征,且损失函数是不可微的,所以限制了基于梯度下降的优化和复杂分割函数的使用。 文章贡献本文设计了自适应神经树(ANT)将 NN 和 DT 的优点结合起来,ANT 将树结构中的路由决策和根到叶的计算路径表示为 NN,从而实现了分层表示学习。ANT 以树形拓扑作为一个强结构先验,通过该结构令特征以分层方式共享和分离。同时提出了一种基于反向传播的训练算法,基于一系列决策来生长 ANT 的结构。总而言之,ANT同时具备了表示学习、架构学习、轻量级推理的能力。通过SARCOS、MNIST 和 CIFAR-10 数据集的实验,证明了本文方法具有较好的性能,具有多种良好的特性。
ANT(Adaptive Neural Trees)是一种树状结构模型,是基于三个可微操作的基本模块构成的,分别是 Routers、Transformers 和 Solvers,在图中分别用白色圆圈、黑色圆圈和灰色圆圈标出。 Routers:路由器,以样本作为输入,并确定将样本发送到左分支或右分支。例如可以将定义为一个小的 CNN,对该 CNN 的输出求平均后从伯努利分布中采样来决策进入哪个分支,左分支为 1,右分支为 0。 Transformers:变压器,ANT 的每条边都由一个或多个 Transformers 组成。每个 Transformers 都是一个非线性函数,用于将样本进行非线性变换后继续向下传递。例如它可以是一个卷积层 + ReLU,并且可以在一条边上堆叠多个 Transformers 来实现对特征的深度转换。 Solvers:求解器,它对转换后的输入数据进行决策,对于分类任务可以将其定义为线性分类器。
ANT 由树的根结点到叶节点构成的一个层级混合专家系统(hierarchical mixture of experts, HMEs)产生输出,每个 HMEs 都被定义为一个神经网络。ANT 的每个输入 x 根据 Routers 的决策来遍历树,并经历一系列 Transformers 的变换,直到到达一个叶节点,用对应的 Solvers 预测标签 y。假设树中有 L 个叶节点,参数为 Θ = (Θ, ψ, φ),则模型对样本输出的条件概率如下:
混合系数 π 量化了 x 被分配到叶节点 l 的概率,由从根节点到叶节点 l 的唯一路径 Pl上 所有 Routers 的决策概率的乘积给出,公式如下所示。式中的 l→j 为一个二值关系,且仅当叶子 l 位于节点 j 的左子树时为 true,xψj 为 x 在节点 j 处的特征表示。
ANT 的训练分为两个阶段,在生长阶段 ANT 基于局部优化学习模型架构,在细化阶段 ANT 基于全局优化进一步调整模型参数。
接着通过梯度下降最小化 NLL 来局部优化新增模块的参数,同时固定前一部分的参数。最后选择具有最低验证 NLL 的模型,如果它改善了之前的最低 NLL 则保留,这个过程逐级重复直到收敛。
本文使用 SARCOS 多元回归数据集、MNIST 和 CIFAR-10 数据集进行评估,同时进行消融实验,所有的模型都是在 PyTorch 中实现。 模型性能将 ANT 的性能与一系列 DT 和 NN 模型进行比较,ANT 在 SARCOS 上实现了最低的误差,并且在 MNIST 和 CIFAR-10 上表现良好。在 SARCOS 数据集中,全路径的 ANT-SARCOS 的 MSE 优于所有其他方法,单路径时 GBT 的性能略好于单个 ANT 同时需要更少的参数。
本文比较了在禁用 Transformers 或 Routers 的情况下,ANT 的不同结构的预测误差如下表所示。禁用 Transformers 时模型相当于 HMEs,禁用 Routers 相当于使用标准 CNN。在所有三个数据集上,任何一种消融都会导致不同模块配置的更高误差,证明 ANT 中特征学习和分层划分的结合是合理的。
ANT 算法的生长过程能够发现有用的层次结构,在没有对 Routers 施加任何正则化的情况下,学习到的层次结构通常会在 MNIST 和 CIFAR-10 数据集中显示某些类别的强专门化路径。
全局细化阶段改善了泛化误差,下图显示了 CIFAR-10 上各种 ANT 的泛化误差,垂直虚线表示模型进入细化阶段的代数。几种设置都获得更高的测试精度,并且全局优化使 Routers 的决策概率两极分化,导致一定的剪枝效果。
在 CIFAR-10 上设置大小为 50、250、500、2.5k、5k、25k 和完整训练集的数据集子集上训练 ANT、All-CNN 和线性分类器的三个变体,选择 All-CNN 作为基线,利用 5k 个样本的验证集的性能来选择最优模型。下图展示了实验的分类性能,随着数据集越来越小,ANT 和 All-CNN/线性分类器的测试精度之间的差距增加。
个人认为,本文有如下一些优点和创新点可供参考学习: 本文用树形结构来自适应地构建模型,决策路径构成了一个 NN,思路非常具有创新性; ANT 的结构包括子空间的划分、特征构造和决策模型的部分,模型结构设计有参考价值; 本文的方法在不同的问题上可通过不同的方式实现,模型的迁移性强。 |
今日新闻 |
推荐新闻 |
CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3 |