SwitchTab(AAAI, 2024)无源码
SwitchTab利用非对称编码器-解码器框架来解耦数据对中相互和独特的特征,并最终获得更加具有代表性的嵌入。这些嵌入推动了一个更加好的决策边界并提高了下游任务的表现效果。另外,预训练的独特的嵌入可以作为“即插即用”的特征在传统分类模型中使用。最后,SwitchTab具有可以通过解耦的相互和独特特征的可视化来产生可解释性的表示。
SwitchTab的核心在于:一个非对称的编码器-解码器架构,通过可使特征解耦合的定制映射器进行增广。这个过程从将每个样本编码为一个广义的嵌入,到将其映射为相互的和独有的嵌入。该模型的另一个优点就是它的多面性,其使用了自监督方法训练。这个适应性保证了SwitchTab在多个训练领域内表现优异,不管数据是否有监督。

要点速递

原作者无开源代码,模型准确性存疑,但是作者均为Amazon单位,模型最大的创新点在于特征解耦(类似PNPNet,可参考之前推文)。
PNPNet主要学习周期性数据中的周期性内容(傅里叶网络),非周期数据中的非周期内容(切比雪夫网络);SwitchNet则是对于所有数据均采用相同的提取策略提取共性与非共性的内容。

相关工作

特征解耦

在特征提取与潜表示学习中,基于自编码器的模型已被广泛使用,其具有在极少或没有监督的情况下洗得有效表示的能力。过去的工作主要集中于提取出一种表示,这种表示的每个维度可以表示一种语义学上的因子的改变而不受其他因素影响。近期的工作集中于捕捉不同变化因素之间的依赖关系和相互作用,以增强潜在表示的效果。进一步说,对比变分自编码器(cVAE)使用了对比分析策略,将潜在特征分为互相特征和独特特征并增强了独特特征。交换自编码器显式地将图片分为结构和纹理嵌入,并将该嵌入在图像生成中进行交换(这里的交换swapping指的是将图片A和图片B的特征进行交换获得新的图片)。近期的表格数据相关工作也证明了量化样本间关系的益处。关联自编码器(RAE)同时考虑数据特征和关系来产生更具鲁棒性的特征。
作者扩展cVAE与交换自编码器的想法到表格数据领域,假设两个样本数据通过潜在的样本关联具有共有的和独特的信息。独有的信息对于下游任务的决策边界极为重要,而交互信息对于数据重构非常重要。同时,特征解耦可以增强模型的解释性。

Method

ST模型方法
算法并不复杂,直接列出github热心群众avivnur的复现代码

1
2
3
4
def feature_corruption(x, corruption_ratio=0.3):
# We sample a mask of the features to be zeroed out
corruption_mask = torch.bernoulli(torch.full(x.shape, 1-corruption_ratio)).to(x.device)
return x * corruption_mask

特征污染。

1
2
3
4
5
6
7
8
9
10
11
12
class Encoder(nn.Module):
def __init__(self, feature_size, num_heads=2):
super(Encoder, self).__init__()
self.transformer_layers = nn.Sequential(
nn.TransformerEncoderLayer(d_model=feature_size, nhead=num_heads),
nn.TransformerEncoderLayer(d_model=feature_size, nhead=num_heads),
nn.TransformerEncoderLayer(d_model=feature_size, nhead=num_heads)
)

def forward(self, x):
# Since Transformer expects seq_length x batch x features, we assume x is already shaped correctly
return self.transformer_layers(x)

编码器,对应算法中的$f$。核心是三个多头注意力机制的堆叠。

1
2
3
4
5
6
7
8
class Projector(nn.Module):
def __init__(self, feature_size):
super(Projector, self).__init__()
self.linear = nn.Linear(feature_size, feature_size)
self.sigmoid = nn.Sigmoid()

def forward(self, x):
return self.sigmoid(self.linear(x))

映射器,对应算法中的$p_{m}$与$p_{s}$。全连接层

1
2
3
4
5
6
7
8
class Decoder(nn.Module):
def __init__(self, input_feature_size, output_feature_size):
super(Decoder, self).__init__()
self.linear = nn.Linear(input_feature_size, output_feature_size)
self.sigmoid = nn.Sigmoid()

def forward(self, x):
return self.sigmoid(self.linear(x))

解码器,对应$d$。全连接层

1
2
3
4
5
6
7
class Predictor(nn.Module):
def __init__(self, feature_size, num_classes):
super(Predictor, self).__init__()
self.linear = nn.Linear(feature_size, num_classes)

def forward(self, x):
return self.linear(x)

预测器,预训练中并不涉及,在fine-tuning中使用。全连接层。

Conclusion

Alcoholrithm/TabularS3L: A PyTorch Lightning-based library for self- and semi-supervised learning on tabular data.热心网友Alcoholrithm帮助复现了一些自监督与半监督学习的表格数据算法,并进行比较效果,SwitchTab在已有数据集的测试中并不卓越:
10% labeled samples

Model diabetes (acc) cmc (b-acc) abalone (mse)
XGBoost 0.7325 0.4763 5.5739
DAE 0.7208 0.4885 5.6168
VIME 0.7182 0.5087 5.6637
SubTab 0.7312 0.4930 7.2418
SCARF 0.7416 0.4710 5.8888
SwitchTab 0.7156 0.4886 5.9907
TabularBinning 0.7450 0.5061 5.9346

100% labeled samples

Model diabetes (acc) cmc (b-acc) abalone (mse)
XGBoost 0.7234 0.5291 4.8377
DAE 0.7390 0.5500 4.5758
VIME 0.7688 0.5477 4.5804
SubTab 0.7390 0.5432 6.3104
SCARF 0.7442 0.5521 4.4443
SwitchTab 0.7585 0.5411 4.7489
TabularBinning 0.7597 0.5486 4.5805

Cover image icon by Dewi Sari from Flaticon