TabPFNv2微调研究
原文章地址On Finetuning Tabular Foundation Models
要点速递
这篇文章主要讲述了对于TabPFNv2的微调策略,其中全量微调最为有效,同时分析了微调为何有效的原因:作者认为微调有效是因为微调过程中让模型更加容易注意到有效的同类样本,从而提升模型准确率。
TabPFNv2 微调策略
评价Protocol
我们在两个成熟的表格数据DL benchmarks上进行测试,我们只使用能够完全输入TabPFNv2的数据集。相较于之前的比较,我们的此次比较数据集更大且加入了更多强有力的深度学习模型。
对于所有的TabPFNv2微调,我们在验证集上使用具有10个学习率的logspace(5e-4, 5e-6)进行测试。对于其他的baseline,我们则使用超参数grids进行调参,并进行100次迭代。作者使用1024个样本进行梯度更新,其余样本作为输入。对于早停策略,作者在验证集上每隔10个梯度更新步进行一次表现效果的计算并在16个无提升的评估后停止微调。作者使用RMSE和分类准确率作为回归与分类的对应指标。作者使用了相对于MLP的相对提升指标。
Full finetuning
完全微调是最常用的方法,但是表格数据领域内并没有一个共识性的工作,虽然大部分工作都是基于TabPFNv1的
Parameter-efficient finetuning(PEFT)
PEFT是LLM领域内非常常见的适应策略,但是当前的表格数据模型大多都只有很小的规模,所以PEFT并不是非常重要,但是部分微调带来的推理偏倚和潜在的隐式正则化可能会对避免过拟合带来好处,作者提出如下可能:
Low Rank Adapters(LoRA)
Last layers
LayerNorn, Head and Embeddings
仅对特征和目标的线性嵌入层、MLP分类头和Afine层的归一化参数,这些参数只在全模型中占极小的一部分,但是很重要。
Numerical Feature Embeddings
该方法在DL中经常使用,但是并没有和TabPFNv2同时使用,作者尝试在TabPFNv2前面加入一个嵌入模块并进行微调。
微调和训练时间的记录
首先作者统计了不同微调策略的时间,全微调时间最短;更大的batchsize带来更短的微调时间和更好的表现。
最优微调策略
总体来说,微调后准确率均有所上升,但是,全微调和PEFT之间的差异是微小的。综合效率与准确率,全微调是最好的baseline。
值得注意的是,作者发现:引入非绑定(非共享)的线性或高级分段线性嵌入,对微调性能的提升作用有限。这表明,要么 TabPFNv2 模型已经学会了特征转换,因此无需更复杂的嵌入方式;要么就需要一种更复杂的嵌入方案(例如在预训练阶段引入)。这为未来的研究指明了一个有趣的探索方向。
作者证明:使用更大的数据集,微调带来的收益也会更大
剖析微调对TabPFNv2的影响
作者在分析过程中基于这样的一种直觉,预训练后的TabPFNv2 在进行预测时的工作方式类似于一些基于检索增强的模型,比如 ModernNCA(Ye et al. 2024)或 TabR(Gorishniy et al. 2024)。这样两种假设与之前的其他模态的上下文学习形成对比:SGD和复杂的电路算法。
特别的,作者推测:TabPFNv2的预测机制中,潜在的检索机制起主要作用,该检索机制通过训练集中样本间的注意力实现。 这样的推断源于不同方法间效果的比较,ModernNCA与TabPFNv2的提升呈现出极强的相关性(0.89 Pearson correlation)。这意味着,TabPFNv2或许本质上是一种高效的检索器。
基于这样的直觉,作者假设微调的作用就是对于TabPFNv2上下文样本权重的一种相关性信号的精制与细化。特别地,作者提出微调提升了query-representation和key-representations之间点积对于真实的标签相关性的反映能力。一个更加精确的相似度度量或许可以让Attention softmax更加高效地识别并提高高信息样本的权重。
为了证明这个假设,我们设计了一个实验来评价attention分数作为目标相关性的代理。对于原始和微调的模型,我们均提取出最后一层的对于每个上下文样本的attention分数,这些attention权重在后面用于计算对应训练样本的平均权重。这个实验的想法非常直接:如果注意力分数准确捕捉了训练样本和测试样本之间的相关性,那么这个加权平均应该接近预测结果。
有趣的是,微调的效果同时影响了attention的分布,微调让大部分数据集的上下文训练样本的attention分布发生熵降,这意味着对于这些样本,模型变得更加“专注”,专注于样本中较少的子集。这样的行为与我们的早期发现相一致,如果潜在的query-key点积提供了哪个训练集样本与测试样本最相似的清晰信号,这个模型就可以将更高的权重分配给那些更小且更相关的子集,避免了将注意力分配给其他无关数据。
随后作者分析了逐样本的attention entropy改变和逐样本的误差改变,发现大致上,attention entropy越大,误差越高
与表格深度学习SoTA的比较
作者比较了以下几个方法:
- MLP: 带有周期性数值特征嵌入的MLP
- MNCA: ModernNCA, 一个SoTA无参数表格深度学习模型
- TabM: 一个最近的SoTA表格深度学习模型
- XGBoost
未微调的TabPFNv2模型能力接近于$MLP^{PLR}$。微调后的则可以达到SoTA水平,以上测试均在1M数据的数据集上测试。
TabReD子集测试
在TabReD benchmark上的测试发现TabPFNv2不管是微调还是原版均不如TabM,这可能是由于TabReD所具有的时间漂移(Temperol Shifting)。
Limitation
数据集的选择
作者选择的数据集均为可以被TabPFNv2接受的数据集大小,所以对于那些大规模数据仍需要进一步验证
不同的特征和目标预处理
对于非基础表格数据(non-foundational tabular DL),我们并没有统一数据集的预处理策略,只用了作者最推荐的方法。