TabTransformer是基于Transformer架构构建的表格学习深度模型。它主要将类别变量转化为上下文嵌入,同时发现这样产生的嵌入对于缺失数据具有良好的鲁棒性与可解释性。最后作者提出一种半监督预训练方法。这篇文章发表于2020年,当时的SOTA还是MLP,深度学习还远远比不上树模型,TabTransformer有隐隐赶上树模型的趋势。

重点速递

TabTransformer主要在表格数据领域中使用了特殊的预训练流程。
MLM:对于数据进行遮掩,然后预测遮掩部分数据来进行预训练。
RTD:将部分数据赋予随机值,然后设计二分类分类器来识别随机数据进行预训练。

模型架构

TabTransformer包括有一个列嵌入层、N个Tranformer层和一个MLP层。
$(\boldsymbol{x},y)$表示一个特征-标签对,其中$\boldsymbol{x}={\boldsymbol{x}_{cat},\boldsymbol{x}_{cont}}$前者代表类别变量$\boldsymbol{x}_\text{cat}={x_{1},x_{2},\dots,x_{m}}$,后者代表数值变量(共c个)$\boldsymbol{x}_\text{cont}\in \mathbb{R}^{c}$。作者将所有的类别特征使用列嵌入嵌入到d维空间,$e_{\phi_{i}}(x_{i})\in \mathbb{R}^d$,对于所有特征,有$\boldsymbol{E}_{\phi}(\boldsymbol{x}_\text{cat})={e_{\phi_{1}}(x_{1}),\dots,e_{\phi_{m}}(x_{m})}$。然后$\boldsymbol{E}$将被输入到Transformer层内得到输出${\boldsymbol{h}_{1},\dots,\boldsymbol{h}_{m}}$其中每个$\boldsymbol{h}$均为d维。最终将Transformer的输出和连续变量堆叠起来,得到一个$d \times m +c$维的向量。然后将该向量输入到MLP内用于预测标签$y$并计算损失函数。

Transformer

架构与传统的Transformer一样

Column embedding

对于类别特征$i$。有一个嵌入查询表$\boldsymbol{e}_{\phi_{i}}(\cdot)$,对于第$i$个特征,有$d_{i}$个类别,那么查询表就有$(d_{i}+1)$个嵌入,多出的一个嵌入代表缺失值。对于嵌入类别$x_{i}=j\in[0,1,2,\dots,d_{i}]$,嵌入向量为$\boldsymbol{e}_{\phi_{i}}(j)=[\boldsymbol{c}_{\phi_{i}},\boldsymbol{w}_{\phi_{ij}}]$,其中$\boldsymbol{c}_{\phi_{i}}\in \mathbb{R}^l, \boldsymbol{w}_{\phi_{ij}}\in \mathbb{R}^{d-l}$。$\boldsymbol{c}_{\phi_{i}}$属于是列专属,$\boldsymbol{w}_{\phi_{ij}}$是单元格专属,用于区分同列的不同单元格。作者当然也使用了无列专属嵌入的方法和列专属与单元格专属直接相加的嵌入的方法,但是效果最好的还是直接堆叠(平均最好)

Pretraining the Embeddings

作者同时引入了无标签的样本进行预训练。
作者使用两种不同的预训练流程:MLM(masked language modeling)RTD(replaced token detection) 。给出一个输入$\boldsymbol{x}_\text{cat}={x_{1},x_{2},\dots,x_{m}}$,MLM随机选择$k%$的特征遮掩,Transformer训练来降低预测原特征的交叉熵。RTD随机将特征使用一个该特征的随机值替代,这里是预测某个特征是否为替代特征,并计算损失函数。RTD在原始论文中是用一个子集来代替,但这是因为NLP中如果直接使用随机值,因为token有成千上万,一个替代的正态分布的特征太容易被发现了。而本文中,特征数量有限制且每列都有一个二分类器,所以直接用该特征的随机一个token。

Experiments

作者在这篇文章的实验环节有许多有趣的创新点:

半监督学习

已有数据集均有标签,所以为了构造无监督的数据,作者选择了数据中部分数据赋予标签,剩余数据作为无监督的数据进行测试。
实验证明:直接使用Transformer效果优于MLP,使用预训练的TabTransformer效果更优于MLP

嵌入表示的效果测试

作者为了研究不同层中的上下文嵌入的效果,使用t-SNE对于每个样本进行降维,发现使用TabTransformer时语义学上相似的类会更加接近,而在MLP中并没有出现这种聚集现象。随后作者直接对每层的输出的嵌入使用线性回归测试预测效果,这个测试的目的就是:使用最简单的分类方法查看嵌入的质量。

鲁棒性测试

作者污染、删除数据,然后查看不同模型对于污染数据的预测效果。

不同学习方法的比较

对于有监督学习,比较对象有:MLP, GBDT, Sparse MLP, Logistic Regression, TabNet, VIB
对于无监督学习,比较对象有:ER MLP(Entropy Regularization combined with MLP), PL MLP\TabTransformer\GBDT(Pseudo Labeling), MLP(DAE)


Cover image icon by Dewi Sari from Flaticon