Non-Stationary Transformer(NeurIPS, 2022)
源代码
Transformer在时序预测上有着强大的能力,这是由于其具有的全局时序依赖建模能力(global-range modeling ability)。但是,其表现可能会在非平稳的真实世界数据中退化,这是因为真实世界的数据之间的联合分布是随时间的推移而发生变化,即非严平稳。过去的研究中主要使用平稳化的方法来减少非平稳数据的影响,但是平稳序列在一定程度上削弱了其固有的非平稳部分,这会使得构造的模型对于真实世界的预测效果变弱(真实世界数据大量存在非平稳)。这样的问题被称为 过平稳化(over-stationarization)。这也让Transformers的预测能力下降。为了解决这个问题,作者便提出了 Non-stationary Transformers 。改模型相较于Transformer模型减少了49.43%的MSE。

Related Work

时序预测的平稳化

为了将非平稳的数据转化为平稳数据,最经典的方法就是 ARIMA 中所使用的差分。这里简单讲述一下差分的做法:
ARIMA相较于AR(Auto Regression)MA(Moving Average)多出求差分的过程,这是因为ARMA并不能预测非平稳的数据,但是有人发现对于数据进行一次差分,即d_{i} = x_{i+1} - x_{i}可以将原本非平稳数据转化为平稳的序列。
而对于深度学习模型,Adaptive Norm使用每个片段的 z-score 进行归一化,DAIN 使用一个非线性神经网络使用训练集中的分布来适应性平稳化;RevIN 使用二阶段样本归一化转化为模型的输入与输出。
作者发现平稳化反而会破坏模型对于特定时间依赖性的建模能力,因此,不似之前的方法,作者使用了一个“去平稳注意力机制”来发掘数据内固有的非平稳部分。

Non-stationary Transfromers

该模型主要包含两个互补的部分:Series StationarizationDe-stationary,前者用于减少模型的非平稳,后者用于重新合并原序列的非平稳信息。

Series Stationarization

RevIN 采用可学习的仿射参数来对每个样本进行归一化与去归一化。这让每个时间序列满足一个相似的分布。作者发现不使用可学习参数也可以有效发挥作用。因此,作者提出一个直接但有效的模型将Transformer作为基模型包裹起来,这被称为 Series Stationarization 。 其包含两个主要的运算:1. 归一化模块来处理由不同均值与标准差导致的非平稳性;2. 去归一化模块用于将模型输出转化回原始的数值。

Normalization module

对于每个输入序列 $\mathbf{x} = [x_{1},x_{2},\dots,x_{S}]^T\in \mathbb{R}^{S\times C}$ 转化为 $\mathbf{x}’=[x_{1}’,x_{2}’,\dots,x_{s}’]^T\in \mathbb{R}^{S\times C}$
$$
\mu_{\mathbf{x}} = \frac{1}{S} \sum_{i=1}^{S} x_i, , \sigma_{\mathbf{x}}^2 = \frac{1}{S} \sum_{i=1}^{S} (x_i - \mu_{\mathbf{x}})^2, , x’_i = \frac{1}{\sigma_{\mathbf{x}}} \odot (x_i - \mu_{\mathbf{x}}),
$$
其中 $\mu_{\mathbf{x}},\sigma_{\mathbf{x}}\in \mathbb{R}^{C\times1}$, 其中 $\frac{1}{\sigma_{\mathbf{x}}}$ 代表逐元素的除法, $\odot$ 代表逐元素乘法。

De-normalization module

在基模型 $\mathcal{H}$ 预测的长度为 O 的未来值基础上,将模型输出 $\mathbf{y}’=[y_{1}’,y_{2}’,\dots,y_{O}’]^T\in \mathbb{R}^{O\times C}$ 转化为 $\hat{\mathbf{y}}=[\hat{y_{1}},\hat{y_{2}},\dots,\hat{y_{O}}]^T$
$$
\mathbf{y}’=\mathcal{H}(\mathbf{x}’), \hat{y}_{i}=\sigma_{\mathbf{x}}\odot y_{i}’+\mu_{\mathbf{x}}
$$

De-stationary Attention

这种常规的逆归一化策略并不能将结果恢复为原始序列,这是因为模型无法捕捉到与非平稳序列相关的时间依赖性。这也就意味着,这种过度平稳化发生于模型的内部。进一步地,这样地数据训练出的模型预测的序列也将是过度平稳化的数据。

Analysis of the plain model

$$
\text{Attn}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{Softmax}\left( \frac{\mathbf{Q}\mathbf{K}^{\top}}{\sqrt{d_k}} \right) \mathbf{V},
$$
正如上文所提到的,过度平稳化是由固有非平稳信息的丢失所导致,因此我们来逼近一下从原始非平稳序列中学习到的注意力。
其中 $\mathbf{Q}, \mathbf{K}, \mathbf{V}\in \mathbb{R}^{S\times d_{k}}$ 为长度为 S 的 queries, keys, values. $\text{Softmax}(\cdot)$ 按照行进行计算,为了简化计算,我们假设前馈神经网络层 $f$ 满足线性性质且 $f$ 是在每个时间点上单独计算,这也就意味着, queries可以写为 $\mathbf{Q}=[q_{1},q_{2},\dots ,q_{S}]^T,q_{i}=f(x_{i})$。既然在每个时间序列变量上使用归一化是一个传统操作,那么我们就可以假设每个变量均有相同的方差,也因此方差 $\sigma_{\mathbf{x}}\in \mathbb{R}^{C\times 1}$为一个标量。在归一化模块后,模型接收到的数据为 $\mathbf{x}’ = (\mathbf{x} - \mathbf{1}\mu_{\mathbf{x}}^{\top}) / \sigma_{\mathbf{x}}$。基于线性假设,我们可以证明最终模型中 $\mathbf{Q}’ = [f(x’_1), …, f(x’_S)]^\top = (\mathbf{Q} - \mathbf{1}\mu_{\mathbf{Q}}^\top) / \sigma_{\mathbf{x}}$,其中 $\mu_{\mathbf{Q}}\in \mathbb{R}^{d_{k}\times 1}$ 为 Q‘ 在时间维度上的均值,K’, V‘ 同样。在没有序列平稳化的时候,输入为 $\mathbf{Q}, \mathbf{K}, \mathbf{V}$。可以计算如下:
$$
\begin{align*}
\mathbf{Q}’\mathbf{K}’^{\top} &= \frac{1}{\sigma_{\mathbf{x}}^2} \left( \mathbf{QK}^{\top} - \mathbf{1}(\mu_{\mathbf{Q}}^{\top}\mathbf{K}^{\top}) - (\mathbf{Q}\mu_{\mathbf{K}})\mathbf{1}^{\top} + \mathbf{1}(\mu_{\mathbf{Q}}^{\top}\mu_{\mathbf{K}})\mathbf{1}^{\top} \right), \\
\text{Softmax}\left( \frac{\mathbf{QK}^{\top}}{\sqrt{d_k}} \right) &= \text{Softmax}\left( \frac{\sigma_{\mathbf{x}}^2 \mathbf{Q}’\mathbf{K}’^{\top} + \mathbf{1}(\mu_{\mathbf{Q}}^{\top}\mathbf{K}^{\top}) + (\mathbf{Q}\mu_{\mathbf{K}})\mathbf{1}^{\top} - \mathbf{1}(\mu_{\mathbf{Q}}^{\top}\mu_{\mathbf{K}})\mathbf{1}^{\top}}{\sqrt{d_k}} \right)
\end{align*}
$$
即采用我们去平稳化的Q’, K’, V’ 来表示未去平稳化的 Q, K, V 并计算attention。其中我们发现 $\mathbf{Q} \mu_{\mathbf{K}}\in \mathbb{R}^{S\times 1}, \mu_{\mathbf{Q}}^T\mu_{\mathbf{K}}\in \mathbb{R}$,这两个参数是不断在每行和每个元素上重复地计算,又考虑 Softmax 与行上相同的操作无关,所以可以省略:
$$
\text{Softmax} \left( \frac{\mathbf{Q} \mathbf{K}^{\top}}{\sqrt{d_k}} \right) = \text{Softmax} \left( \frac{\sigma_{\mathbf{x}}^2 , \mathbf{Q}’ \mathbf{K}’^{\top} + \mathbf{1} \mu_{\mathbf{Q}}^{\top} \mathbf{K}^{\top}}{\sqrt{d_k}} \right)
$$
这个公式揭示了计算中缺失的非平稳的信息 $\sigma_{\mathbf{x}}, \mu_{\mathbf{Q}}, \mathbf{K}$ 所在的位置。正因如此,作者提出以下 De-stationary Attention.

De-stationary Attention

从上式子可以看出,将未平稳化的信息带来计算的关键是逼近正标量 $\tau = \sigma_{\mathbf{x}}^2\in \mathbb{R}^+$ 和漂移向量 $\mathbf{\Delta}=\mathbf{K}\mu_{\mathbf{Q}}\in \mathbb{R}^{S\times 1}$,这也被称为去平稳化因子(_de-stationary factors_)。但是这个因子的计算来自于严格的线性特性,所以不能直接计算,作者决定使用原序列学习这个因子。学习采用的方法是一个简单但是高效的MLP。
$$
\begin{gather*}
\log \tau = \text{MLP}(\sigma_{\mathbf{x}}, \mathbf{x}), \boldsymbol{\Delta} = \text{MLP}(\mu_{\mathbf{x}}, \mathbf{x}), \\
\text{Attn}(\mathbf{Q}’, \mathbf{K}’, \mathbf{V}’, \tau, \boldsymbol{\Delta}) = \text{Softmax} \left( \frac{\tau , \mathbf{Q}’ \mathbf{K}’^{\top} + \mathbf{1} \boldsymbol{\Delta}^{\top}}{\sqrt{d_k}} \right) \mathbf{V}’,
\end{gather*}
$$
这里的去平稳化因子是所有De-stationary Attention层共享。
模型总体架构如下图所示:
architecture


Cover image icon by Dewi Sari from Flaticon