STUNT论文速读(ICLR, 2023)
STUNT: FEW-SHOT TABULAR LEARNING WITH SELF-GENERATED TASKS FROM UNLABELED TABLES(ICLR, 2023).有源代码。
在这偏文章中,作者提出了一种简单但是高效的小样本半监督表格学习框架:Self-generated Tasks from UNlabeled Tables(STUNT)。我们的主要想法是:通过将多个选中的列作为标签自己产生多种小样本任务。然后使用元学习模式通过构建的任务来学习泛化的知识。另外,通过使用STUNT从无标签数据中产生伪验证集,我们引入无监督的验证模式来进行超参数研究。
重点速递
这篇文章是一个相对细分的领域:表格数据中的半监督小样本学习(semi-supervised few-shot tabular data learning)。
ProtoNet:原型网络,通过计算一个类别的原型(可以是一个类别的均值)。然后计算原型到样本的距离来分类(k-means?)。
小样本学习架构:小样本学习一般记作N-way K-shot代表有N个类别,每个类别内有K个样本的情况,例如5-way 1-shot就代表有5个类,每个类有1个样本。同时,小样本学习batch与sample之间有一个特殊的层次episode,一个batch内有很多(1024)个episode,1 episode代表一次小样本学习的全过程:即训练,然后测试。与传统深度学习不同之处在于,episode内的训练集称为支持集$\mathcal{S}$(support set),测试集被称为查询集$\mathcal{Q}$(query set)。这意味着每个episode内部也有”梯度更新”(这里更应该称为参数更新,因为其实这篇文章中的元学习就不涉及梯度更新,而只是计算一个新的原型),而不仅仅是batch级别的梯度更新,所以小样本学习很重要的一个点就是理清episode和batch上的梯度(参数)更新究竟针对什么参数。
其实我们可以举出一个最简单的小样本学习架构:主干网络为深度神经网络,带有一个分类头(一般为MLP)。我们episode内部训练是只更新分类头,而batch上的训练则固定分类头,只改变深度神经网络。这一点很类似迁移学习:我们通过大量的任务让主干网络学会提取特征,然后使用分类头基于提取的特征在具体任务上进行分类。
这篇文章所使用的ProtoNet是episode内更新支持集的原型向量,batch训练中更新嵌入网络。
pseudo-validation:伪验证集,验证集用于和训练集一起使用,避免模型在训练集上过拟合。这里作者提出自行构建的验证集准确率和训练集准确率正相关、同分布且无交集,所以可以用于指示模型训练的拟合状态和超参数选择效果。
Pseudo-labels:伪标签,伪标签的构建依赖于k-means,但是k-means和原型网络的原理太类似了,所以需要对数据进行污染(从相同边缘分布内采样)后学习,避免过快学习。
Introduction
小样本学习(few-shot learning)在图像与语言领域均体现出了强大的能力,其可以使用无标注的数据学习一个可泛化且可迁移的表示。但是表格数据却并没有表现出类似的能力,自监督学习最佳模型VIME(2020)在小样本工作中的准确率还不如kNN分类器。作者认为这可能是因为训练的自监督任务和使用的小样本工作之间有较大差异,而这正是出于表格数据的异质性。
所以作者考虑,是否能够通过元学习(meta-learning)来减少与未出现的小样本任务之间的差距。作者的灵感来自于最近的无监督元学习研究。作者提出了一种简单但是高效的小样本半监督表格学习框架:Self-generated Tasks from UNlabeled Tables(STUNT)。作者的主要方法是通过将表格的列特征作为标签,从无标注的数据集中产生一系列任务,例如:血糖可以作为糖尿病的替代标签。特别的,我们使用K-means算法在随机选取的列特征的子集中进行聚类,并产生伪标签。另外,为了避免产生显然的任务(特征标签可以直接从输入特征进行推断),我们随机的将选中的列特征使用对应列的实际边缘分布进行采用并替代。然后我们使用了一个元学习模式,例如:原型神经网络(Prototypical Network)通过自生成任务来学习泛化的知识。
同时作者发现使用元学习的主要阻碍是缺失一个标注的验证集,元学习训练对于超参数非常敏感且非常容易过拟合。为了解决这个问题,作者提出一种基于STUNT的无监督的验证模式。我们发现提出的技术对于超参数搜索非常高效,伪验证集和测试集的准确率有很强的相关性。
Related Work
无监督元学习
元学习,例如:通过从任务分布中提取有效信息来学习如何学习,已经成为了一个著名的范例来让系统可以快速适应新的任务。最近,多项工作提到了无监督元学习模式来自生成标签用于减少数据标注的费用。为了产生这样的任务,CACTUs在自监督学习的表示上使用聚类算法,UMTRA将增广产生的数据作为相同伪类的方法(Ye也为无监督元学习提供了有效的策略:即使用足够的episodic sampling(episode是一个小样本学习的任务训练过程), 难样本混合的支持集(support是小样本学习中的训练集)和任务特化的映射头),Meta-GMVAE使用变分自编码器VAE进行聚类。但是,除去在图像领域的有效使用,当这些方法在表格数据中使用时就变得无效。这是因为这些方法预设了过于高质量的表示与数据增广,或者精妙的生成式模型(但其实这些模型对于表格数据来说太复杂了)。作者测试了多种方法,发现只有CACTUs有效。Meta-GMVAE甚至表现不如baseline。
模型架构
问题定义
这里介绍了小样本半监督学习的问题定义。首先,我们的目标是训练一个神经网络分类器$f_{\theta}=\mathcal{X} \to \mathcal{Y}$,该分类器参数为$\theta$,$\mathcal{X} \subseteq \mathbb{R}^d$,$\mathcal{Y}={ 0,1}^C$代表样本呢空间有C个类。对应的,假设我们有一个有标注的数据集$\mathcal{D}_{l}={ \mathbf{x}_{l,i},\mathbf{y}_{l,i} }_{i=1}^{N_{l}} \subseteq \mathcal{X} \times \mathcal{Y}$以及一个未标注的数据集$\mathcal{D}_{u}={\mathbf{x}_{u,i}}_{i=1}^{N_{u}}\subseteq\mathcal{X}$用于训练分类器$f_{\theta}$。注意:所有数据点都是从一个确定的数据生成分布$p(\mathbf{x},\mathbf{y})$中采样获得,样本间为独立同分布。同时我们也假设给定标签的数据集的基数很小,例如每个类只有一个样本的集合,但是我们有足够的无标注数据$N_{u}\gg N_{l}$
STUNT
为了获得一个好的分类器,我们使用无监督元学习方法:(i) 从$\mathcal{D}_{u}$中自生成多个任务${\mathcal{T}_{1},\mathcal{T}_{2},\dots}$,每个任务包括少量具有假标签的样本;(ii) 元学习$f_{\theta}$来使用不同任务间的泛化;(iii) 使用有真实标签的数据集$\mathcal{D}_{l}$调整分类器。
从无标签数据集中产生任务
表格数据的任何列均可以作为标签,这是因为表格数据本质上是异质性的。但是特别的一点是,有一个列和原始的标签相关性更加强,于是基于该列的任务将会和基于原始标签的任务非常相似,例如:我们需要从年龄和BMI来预测糖尿病的任务和通过BMI和年龄预测血糖的任务是类似的。在这个直觉的前提下,我们使用k-means在随机选定的数据子集上来创建伪标签,用于提高获取高度相关列的概率和多样性。
正式地说,为了产生一个单独的任务$\mathcal{T}_\text{STUNT}$,我们以掩码概率$p$(p满足均匀分布$U(r_{1}, r_{2})$)内采样。其中$r_{1}$, $r_{2}$为0到1之间的超参数,并产生随机二元掩码$\mathbf{m}:= [m_{1}, \dots,m_{d}]^\mathrm{T}\in{0,1}^d$,其中$\sum_{i}m_{i}=\lfloor dp \rfloor$。然后对于给定的无标签数据集,我们通过掩码所有来获取列,$\mathrm{sq}(\mathbf{x}\odot \mathbf{m})\in \mathbb{R}^{\lfloor dp \rfloor}$,其中sq代表移除为0的列。我们通过k-means获取伪标签$\tilde{\mathbf{y}}_{u,i}$
$$
\min _{\mathbf{C} \in \mathbb{R}^{\lfloor d p\rfloor \times k}} \frac{1}{N} \sum_{i=1}^{N} \min _{\tilde{\mathbf{y}}_{u, i} \in{0,1}^{k}}\left|\mathbf{s q}\left(\mathbf{x}_{u, i} \odot \mathbf{m}\right)-\mathbf{C} \tilde{\mathbf{y}}_{u, i}\right|_{2}^{2} \quad
\quad \tilde{\mathbf{y}}_{u, i}^{\top} \mathbf{1}_{k}=1
$$
以上就是一个k-means的优化目标,其中C代表质心矩阵。既然数据标签$\tilde{y}_{u}$来自于数据本身,那分类器可以轻松从给出的数据预测出标签,为了避免这样的情况,我们通过$\tilde{\mathbf{x}}_{u}:=\mathbf{m} \odot \hat{\mathbf{x}}_{u}+(1-\mathbf{m}) \odot \mathbf{x}_{u}$来干扰选定的列特征,其中每个元素的$\hat{\mathbf{x}}_{u}$是从每个列特征的真实边缘分布采样获得。最终产生的任务如下$\mathcal{T}_\text{STUNT}:={\tilde{\mathbf{x}}_{u,i}, \tilde{\mathbf{y}}_{u,i}}_{i=1}^{N_{u}}$
元学习
基于生成的任务,我们通过使用原型学习(ProtoNet)来进行元学习:在网络嵌入空间的前提下使用一个无参数的分类器。特别地,ProtoNet学习嵌入空间,该嵌入空间中,分类可以通过计算每个样本到每个类的原型距离来进行。原型向量可以设定为每个类别样本的嵌入向量之平均。为什么使用原型网络的原因可以总结为以下三点:1. 在原型网络中,训练和测试的类别数目可以不同,让我们可以搜索高效的质心数k,而不是使用一个固定的值;2. 这个方法是模型、数据无偏的,也就意味着其可以在表格领域内直接使用而不需要修改;3. 除了其简单,而且其也超越了许多先进的元学习模式。
对于一个给定的任务$\mathcal{T}_\text{STUNT}$,我们取两个非交的集合$\mathcal{S}$与$\mathcal{Q}$,这两个集合分别被用于构建分类器,并训练构建的分类器。我们在参数化嵌入$z_{\theta}: \mathcal{X}\to \mathbb{R}^D$的基础上构建原型网络分类器$f_{\theta}$,这里我们使用的原型向量为每个伪类$\mathbf{p}_{\tilde{c}}:=\frac{1}{|S_{\tilde{c}}|}\sum_{(\tilde{\mathbf{x}}_{u}, \tilde{\mathbf{y}}_{u})\in S_{\tilde{c}}}z_{\theta}(\tilde{\mathbf{x}}_{u})$,其中$S_{\tilde{c}}$包含了伪类$\tilde{c}$中的样本。
$$
f_{\theta}(y=\tilde{c} \mid \mathbf{x} ; \mathcal{S})=\frac{\exp \left(-\left|z_{\theta}(\mathbf{x})-\mathbf{p}_{\tilde{c}}\right|_{2}\right)}{\sum_{\tilde{c}^{\prime}} \exp \left(-\left|z_{\theta}(\mathbf{x})-\mathbf{p}_{\tilde{c}^{\prime}}\right|_{2}\right)}
$$
然后我们计算集合$\mathcal{Q}$上$f_{\theta}$交叉熵损失。然后我将在不同的任务上训练元学习网络来降低损失(就是在$\mathcal{S}$(support集合)上计算原型向量但是不更新$\theta$,在$\mathcal{Q}$(Query集合)上计算损失函数并更新)。
使用有标签数据进行适应
在使用自生成任务元学习参数$\theta$后,我们使用有标签数据集$\mathcal{D}_{l}$使用ProtoNet来构建小样本分类器。$f_{\theta}(\cdot;\mathcal{D}_{l})$,其中原型向量$\mathbf{p}_{c}$通过标签c的样本计算。
使用STUNT的伪验证集
这样一种提出的无监督学习缺失验证集来进行超参数的调整。为了解决这样的问题,我们引入了一个无监督验证模式:我们在无标签数据集上使用STUNT产生伪验证集。我们使用所有的列特征,而不是选择其中一部分数据,同时我们也使用所有的原数据,而不是使用干扰后的数据。
公式上说,我们从无标签数据中取出一部分数据$\mathcal{D}_{u}^{val}\subset \mathcal{D}_{u}$,然后通过运行k-means聚类产生伪标签,然后对于一个给定的验证工作$\mathcal{T}_\text{STUNT}^{val}={ \mathbf{x}_{u,i}^{val}, \mathbf{y}_{u,i}^{val} }_{i}$,我们取出两个无交的集合$\mathcal{S}^{val}$和$\mathcal{Q}^{val}$来评估ProtoNet分类器$f_{\theta}(\cdot;\mathcal{S}^{val})$在伪验证集中的表现。