Attic速读(ICLR under review, 2025)
作者在TabPFN的基础上作出了一个改进,准确率相较于TabPFN有相对较大的提升。相较于TabPFN对每个样本给出一个token,Attic选择将样本中每个特征均给出一个token,这在一定程度上解决了TabPFN所出现的“特征顺序不变性”的问题,提升了表格深度学习模型表现。
要点速递
- Cell Tokens:即对于每个样本的每个特征均给出一个tokens,该想法并非作者独创,于SAINT(2021, NeurIPS Workshop)中便提及并使用,同时对于每个特征给出tokens(FT-Transformer, 2021, NeurIPS)和对于每个样本(TabPFN, 2023, ICLR)给出tokens的想法也均在之前的论文中提出并实践。但是并没有哪个方法格外优秀,这次作者的尝试证明使用单元格级别的tokens可以更加有效捕捉合成数据的相关性。
- 模型精度:作者尝试选用float16取代bfloat16,但是在预训练中失败了。float16代表更高精度和更小的范围,而bfloat16有精度较低而范围大。作者选择float16的原因很简单,高精度模型的测试确实准确率更高,但是却在实践过程中预训练失败。
- Flash Attention:作者使用Flash Attention以减少Attic使用的计算资源与内存占用。并且使用实验证明Flash Attention的使用并未影响最终准确率。Flash Attention是一种高效实现 Transformer 中 注意力机制(Attention) 的方法,专门设计用于解决标准 attention 在计算资源和内存占用上的瓶颈问题。它首次出现在 2022 年的论文中:“FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness” 。FlashAttention 主要有三个核心优化: Fused kernel(融合操作) 将多步操作(比如计算 softmax、dropout、乘法)合并为一个 kernel,减少内存访问。Block-wise attention 将 attention 分块处理(比如以 64 行一组),每次只在 GPU 中加载小块,从而避免中间结果写入/读取显存。Recomputation-free 通过 clever rearrangement 避免需要在反向传播中重新计算或缓存中间变量,节省显存。
Methodology
作者使用TabPFN数据集生成器和树数据集生成器来合成数据并预训练。微调则是采用真实世界数据进行。
Architecture
作者介绍了Attic,一个基于单元格tokens的ICL-Transformer。与TabPFN不同的核心点在于:Attic使用单元格tokens,其可以表示每个样本的每个特征;而TabPFN则使用样本tokens,即一个token代表所有的特征。
$x_{ij}\in \mathbb{R}$代表样本 i 的特征 j 。Attic将该样本嵌入为$h_{ij}\in \mathbb{R}^d$,该tokens经过L层,每层包括一个样本注意力层、一个特征注意力层、两层MLP,每层前均进行一次层归一化。MLP层有2个线性层,隐藏层节点数为5d,层内有GeLU激活函数,模型的最终层将$y_{Q}$独立并将其映射为类别,类别在所有的ICL-Transformer中均为10类。
样本注意力将样本维度作为序列并将特征维度作为批次维度,特征注意力则相反。样本注意力机制使用一个掩码来保证查询集中的样本不能看到另外的样本。这种注意力机制来自于TabPFN,用于避免测试集中的相互独立。
相较于TabPFN将$x_{i}\in \mathbb{R}^k$转化为$h_{i}\in \mathbb{R}^d$,Attic的内存消耗将增加为k倍,同时,TabPFN将标签转化为浮点向量,而Attic则将其作为一个单词嵌入为tokens。
下图简单展示模型的训练时间,可以作为参考:
Motivation
作者认为:该模型由于TabPFN的原因在于:TabPFN模型并未很好的实现特征先后顺序的无关,而Attic所使用的单元格tokens让模型对所有特征一视同仁,TabPFN则是给每个特征一个固定的位置,进而阻止其学习到特征所处位置与特征的关系。
Experiments
实验参数与结果参考Table.1
实验结果证明Attic准确率在
Mixed-Precision Training
作者使用float16与bfloat16,发现float16训练过程中极易发生模型的奔溃,所以并未使用精度更高的float16,而是使用bfloat16。
Regression
作者后续进行了回归任务的训练,Attic在回归任务上,准确率接近于分类,虽然并未超越Attic分类的效果,但是其依旧为最佳的回归模型。