模型速览

TabICL: In-Context Learning for Tabular Data是对于TabPFN的一次拓展,其名字取自于In-Context Learning(ICL),主要是使用ICL对于数据量大于10K的表格数据进行预测。其可以处理500K的数据量的同时,速度是TabPFNv2的十倍
为了适配任意大小的表格数据,TabICL单元格(cell) 作为基本处理单位: 每一列被看作是一组单元格值的集合,用于捕捉该特征的分布与语义; 每一行则由多个相互依赖的特征值构成。TabICL 采用两阶段架构,以实现高效的表格数据的 In-Context Learning(ICL,上下文学习)。

第一阶段:

将每一行(不包含目标标签)编码为稠密向量(dense vector embeddings),
每一个嵌入向量都被设计为能捕捉整个表格的信息。
该阶段的本质是压缩列维度,从而显著降低后续 ICL 的计算复杂度与内存开销。

第二阶段:

将这些紧凑但信息丰富的嵌入向量与其对应的标签结合,执行 ICL。
因此,TabICL 的核心在于第一阶段的嵌入策略,它需要将行数据转换为具有语义丰富性的向量表示。

要点速递

  1. Distribution-aware Column-wise Embedding(具分布感知能力的逐列嵌入):通过学习变量的数据分布作为列的嵌入,是一种数据特征挖掘的新思路。
  2. Hypernetwork(超网络):通过一个网络优化另一个网络的参数。可以与学习数据特征的算法结合,用以优化网络参数。
  3. ISAB(Inducing self-attention block):多头注意力机制的加速策略,可以有效降低注意力机制的时间复杂度(O(n^2) ->O(nk))
  4. RoPE(rotary positional embedding):在transformer中引入位置特征的策略,作者在这里引入用于区分分布相似的不同列。此方法在Attention is All You Need中亦有提及。
  5. Tree-based generation:基于树模型的数据合成。创建一个XGBoost,将父节点作为输入,高斯噪音作为标签训练,这样可以构造一种非常复杂的非线性相关。用于创造非线性相关的数据。
  6. curriculum learning for large-scale pretraining(大规模预训练的课程学习):LLMs中常用的一种预训练方法,仿照人类学习的过程,先学习简单的内容,再学习困难的内容,保证模型收敛。这里作者使用它作为预训练的策略。
  7. Hierarchical Class-extension Strategy:多类别分类的一种分类策略。通过创建一个树,然后在每个中间节点创造一个分类头(作者在文中使用的是2层MLP),这样就将一个大问题分解为多个小的分类问题。
  8. Permutation invariant的解决策略:列的先后对于结果无影响,也就是说列如果交换,预测结果不应该发生改变,但是该论文和TabPFN均不能在架构(均为有位置编码的Transformer架构)上实现这个目标,所以他们采用使用多种不同的排列循序训练并使用集成学习集成不同排序下的预测结果。

模型架构

Distribution-aware Column-wise Embedding(具分布感知能力的逐列嵌入)

对于标量单元格$c_{j} \in \mathbb{R}$,我们将其嵌入为d维的向量,不像以往的嵌入方式(对每列分配一个独特的嵌入方式),我们对所有列共享嵌入方式$\mathrm{TF_{col}}$:
$$
\begin{aligned}
W, B&=\mathrm{TF_{col}}(c_{j})&\in \mathbb{R}^{n\times d}\\
e_{j}&=W\odot c_{j}+B &\in \mathbb{R}^{n\times d}
\end{aligned}
$$
这是一种全新的思路,因为这里的bias和weight是通过transformer来计算的,作者将其称为一个类似”超网络”(Hypernetwork)的思路,就是可以生成其他网络参数的网络。而且,TF网络可以生成每个cell(也就是每个单元格)独特的权重,虽然所有列共享嵌入方式,但其实嵌入的权重是每个单元都不同。为了保证数据的先后对于结果无影响(作者将其描述为permutation-invariant),也就是我们这个TF只用于学习数据的分布,而不学习数据的顺序,$\mathrm{TF_{col}}$内部架构如下:
$$
\begin{aligned}
U &= \mathrm{Lin}(c) &\in \mathbb{R}^{n\times d}\\
M &= \mathrm{MAB_{1}}(V_{I}, U_{train}, U_{train}) &\in \mathbb{R}^{k\times d}\\
V &= \mathrm{MAB_{2}}(U, M, M) &\in \mathbb{R}^{n\times d}\\
W,B &=\mathrm{Lin}(V) &\in \mathbb{R}^{n\times d}
\end{aligned}
$$
MAB为多头注意力模块,Lin为全连接层,两个多头注意力模块构成的是ISAB(induced self-attention block)
ISAB是一种用于处理集合结构数据的模块,用于提升自注意力机制在大规模输入中的效率和建模能力,其中k远小于n,这样就可以减少运算量(可从O(n^2) ->O(nk)),M称为诱导点(inducing points)
参数选择:d=128, k=128, 注意力头数=4,ISAB=3
作者为了进一步了解学习的嵌入方式,将最后一层ISAB的M的对第一个维度求和,得到每一个列的单个向量,然后使用PCA的方法降维,就会发现有相同峰度和偏度的数据会偏向于聚在一起,这意味着TF可能通过嵌入反应了数据的特征。该方法与通过语义嵌入、通过特征识别向量嵌入不同

Context-aware Row-wise Interaction(具内容感知能力的逐行融合)

在获取所有特征嵌入后$E =[e_{1},\dots,e_{m}]\in \mathbb{R}^{n\times m\times d}$,一个三层8头的transformer$\mathrm{TF_{row}}$对E进行特征间交互。为了将嵌入合成为一个向量,四个可学习的[CLS]被添加到E的每行开头,[CLS]将会在最终输出的时候堆叠在一起,这样的4个[CLS]大小为$4 \times 128=512$。[CLS]是对于特定样本的一种更加细节且富有信息量的表示。
这个模块主要是为了防止表示崩塌(Representation collapse)的发生,表示崩塌发生于所有的列内数据分布相同或相似,$\mathrm{TF_{col}}$失去了对列进行区分的作用。作者使用了一种RoPE(rotary positional embedding)的方式来打破相似分布的特征的对称性(TabPFNv2中使用的是随机的特征识别向量并按组来表示特征)。RoPE主要用于自然语言处理中通过旋转query和key向量来表示位置信息,旋转角度如下确定:p:在序列中的位置;i:维度索引;$\theta_{i}=\frac{p}{\mathrm{base}^{2i/d}}$;d:嵌入维度;base:频率缩放参数,transformer中使用的是10000,该文章使用的是100000,频率缩放参数越大代表容许的长度越长
此外,为了保证列与列之间可交换顺序,TabICL和TabPFN使用了相同策略:对于列进行重新排列组合并进行预测后集成学习。

Dataset-wise In-Context Learning(逐数据集的语境学习)

在将所有样本转化为嵌入$H \in \mathbb{R}^{n\times 4d}$后,训练标签通过独热编码的方式被映射到与H相同的空间内。X与y相加得到最终的训练嵌入$H_{train}$。然后我们通过一个12层4头的Transformer$\mathrm{TF_{icl}}$处理$H_{train}$与$H_{test}$。其中$H_{train}$中的嵌入可以参考其他的,但是$H_{test}$中的只能参考$H_{train}$中以避免发生数据泄露。最后,一个两层的MLP将$H_{test}$的输出转化为类别概率。

时间复杂度分析

$\mathrm{TF_{col}}$复杂度为$\mathcal{O}(nkm)$(作者省去了d)
$\mathrm{TF_{row}}$复杂度为$\mathcal{O}(m^2n)$(也省略了d)
$\mathrm{TF_{icl}}$复杂度为$\mathcal{O}(n^2)$(也省略了d)
与TabPFN的时间复杂度$\mathcal{O}(m^2n+n^2m)$相比TabICL($\mathcal{O}(m^2n+n^2)$)对于大n和适中的m训练时间更加短。

预训练与推理

利用合成数据的训练

合成数据的训练过程依旧类似于TabPFN,但是作者额外加入了基于树模型的SCMs(structural causal models)构造过程中子代与父代间更多的函数$f$(在SCMs模型内部,变量被描述为节点,子节点由父节点计算而来$c = f(Pa(c))+\epsilon$,$Pa(c)$是其父节点的值)

Tree-based generation

$f$使用XGBoost回归模型定义。XGBoost输入父节点的值,用高斯噪声作为目标训练,获取的预测结果接下来作为子节点的值。作者使用了70%的SCMs和30%tree-based SCMs。

Diversifying activation function

作者在已有的4种激活函数的基础上加入了另外15种激活函数来增强数据中非线性相关的多样性。

大规模预训练的课程学习(curriculum learning for large-scale pretraining)

课程学习(Curriculum Learning, CL)最早由Bengio等人在2009年提出,它模仿人类学习的方式。训练过程中不是将所有训练数据随机喂给模型,而是按照某种 “难度”顺序 ,从“容易”样本逐步过渡到“困难”样本。这种学习方法主要用于LLMs的预训练,作者在这里通过逐渐增加合成数据集的大小,同时调整用于梯度累积的微批量大小($N_{\mathcal{B}}$: micro batch size)来适应内存限制,训练可分为3个阶段:

  1. $N_{\mathcal{B}}$=4,batch size = 1024, 160k步训练
  2. $N_{\mathcal{B}}$=1,batch size $\sim$ log-uniform$[1\mathrm{K}, 40\mathrm{K}]$ , 2k步训练
  3. $N_{\mathcal{B}}$=1,batch size $\sim$ log-uniform$[40\mathrm{K}, 60\mathrm{K}]$ , 50步训练,只更新$\mathrm{TF}_{icl}$

Hierarchical Class-extension Strategy

一种多类别分类任务的解决方法。
当类别大于10种时,就会递归地构造一个多层分类树,每个中间节点可以区分10个类别(中间节点的分类是进行2层的MLP进行分类),所以当要完成k分类任务时,需要的树深度为r=$\lceil \log_{10}k \rceil$。这样在推理过程中,只需要计算从树根到类别对应叶节点的概率的累乘,就可以得到预测的概率。
这样一个树的构建发生于dataset-wise ICL。
所有的子任务都是基于相同的行嵌入H和$\mathrm{TF_{icl}}$。这些共享的参数提高了TabICL的效率。


Cover image icon by Dewi Sari from Flaticon