[论文阅读] Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention

之前的工作证明了多头自注意力只要有足够的注意力头数就可以表示任意的卷积层。但是,本文反向表明,用自回归目标训练的自注意力层可以被看作是一个RNN,可以显著加快自回归transformer模型的推理时间。

Transformer

\(x\in\mathbb{R}^{N\times F}\),\(N\)\(F\)维的特征向量。Transformer即一个函数\(T:\mathbb{R}^{N\times F}\rightarrow\mathbb{R}^{N\times F}\),由\(L\)个transformer层\(T_{1}(\cdot),\dots,T_{L}(\cdot)\)组成: \[ T_{l}(x)=f_{l}(A_{l}(x)+x). \] \(A_l(\cdot)\)代表自注意力函数。输入序列\(x\)由三个矩阵\(W_Q\in\mathbb{R}^{F\times D},W_K\in\mathbb{R}^{F\times D},W_v\in\mathbb{R}^{F\times M}\)映射到\(Q,K,V\)\(A_l(x)=V^\prime\) \[ Q=xW_Q,\\ K=xW_K,\\ V=xW_V,\\ A_l(x)=V^\prime={\rm softmax}(\frac{QK^T}{\sqrt{D}})V. \] softmax函数按行应用于\(QK^T\),\(Q,K,V\)分别表示queries、keys和values。

式2表示了一种特定形式的注意力,称为softmax注意力,其中相似性是由\(Q\)\(K\)的点积的指数表示的。给定一个下标\(i\),返回一个矩阵的第\(i\)行作为一个向量,对于任意相似性函数,可以写出一个广义的注意力方程: \[ V_i^\prime=\frac{\sum_{j=1}^N{\rm sim}(Q_i,K_j)V_j}{\sum_{j=1}^N{\rm sim}(Q_i,K_j)} \] 将式3中的相似性函数\({\rm sim}(q,k)\)替代为\(\exp(\frac{q^Tk}{\sqrt{D}})\),则与式2等价。

Linearized Attention

式2中注意力的定义具有一般性,可以用来定义一些其他的注意力,如多项式注意力和RBF核注意力。为了使式3定义一个注意力函数,需要对\(sim(\cdot)\)施加一个非负的约束。\(k(x,y):\mathbb{R}^{2\times F}\rightarrow\mathbb{R}_{+}\)

给定一个特征表示核函数\(\phi(x)\),则可以将式2重写为: \[ V_i^\prime=\frac{\sum_{j=1}^N\phi(Q_i)^T\phi(K_j)V_j}{\sum_{j=1}^N\phi(Q_i)^T\phi(K_j)} \] 根据矩阵乘法的结合律,进一步简化: \[ V_i^\prime=\frac{\phi(Q_i)^T\sum_{j=1}^N\phi(K_j)V_j^T}{\phi(Q_i)^T\sum_{j=1}^N\phi(K_j)} \] 当分子写成向量形式时,式5可以更简化 \[ (\phi(Q)\phi(K)^T)V=\phi(Q)(\phi(K)^TV) \] 特征映射\(\phi(\cdot)\)逐行应用于矩阵\(Q,K\)

对于式2,softmax注意力的计算复杂度式\(\mathcal{O}(N^2)\)的,\(N\)表示序列长度。空间复杂度也是相同的,因为要保存完整的注意力矩阵来计算\(Q,K,V\)的梯度。

对于式5,linear transformer的时间复杂度、空间复杂度都是\(\mathcal{O}(N)\)的,因为我们可以一次计算出\(\sum_{j=1}^N\phi(K_j)V_j^T\)\(\sum_{j=1}^N\phi(K_j)\)并且在每个查询中重复使用。

对于softmax注意力,乘法和加法的总复杂度为\(\mathcal{O}(N^2\max(D,M))\)\(D\)\(Q,K\)的维度,\(M\)\(V\)的维度。

对于线性注意力,首先计算维度为\(C\)的特征图,然后计算新值的加法和乘法的复杂度为\(\mathcal{O}(NCM)\)

先前的分析中并没有考虑到核函数和特征函数的选择。与指数核对应的特征函数是无穷维的,这导致不能精确地线性化softmax注意力。而另一方面,多项式核具有精确的、有限维的特征映射,并且已被证明与指数核或RBF核同样有效。计算一个2次线性化多项式transformer的复杂度为\(\mathcal{O}(ND^2M)\)

对于小序列,使用一个特征映射得到正的相似度 \[ \phi(x)={\rm elu}(x)+1 \] 相较于\({\rm relu}(\cdot)\)\({\rm elu}(\cdot)\)可以避免x为负时将梯度设置为0。这样的特征映射产生的注意力计算复杂度为\(\mathcal{O}(NDM)\)

Causal Masking

利用transformer框架可以通过掩盖注意力高效地训练自回归模型,使得第\(i\)个位置只能受到第\(j\)个位置的影响,当且仅当\(j\leq i\),即一个位置不能受到后续位置的影响。由此将式3改写为: \[ V_i^\prime=\frac{\sum_{j=1}^i{\rm sim}(Q_i,K_j)V_j}{\sum_{j=1}^i{\rm sim}(Q_i,K_j)}. \] 又由之前的推理,将掩蔽注意力线性化如下: \[ V_i^\prime=\frac{\phi(Q_i)^T\sum_{j=1}^i\phi(K_j)V_j^T}{\phi(Q_i)^T\sum_{j=1}^i\phi(K_j)}. \] 引入\(S_i,Z_i\)如下: \[ S_i=\sum_{j=1}^i\phi(K_j)V_j^T, \]

\[ Z_i=\sum_{j=1}^i\phi(K_j), \]

将式9仅一步简化: \[ V_i^\prime=\frac{\phi(Q_i)^TS_i}{\phi(Q_I)^TZ_i} \] \(S_i,Z_i\)是可以由\(S_{i-1},Z_{i-1}\)连续计算得到的,因而使得带有因果掩码的linear transformer的计算复杂度与序列长度称线性关系。

gradient computation

在任何深度学习框架中,式12的朴素实现都需要存储所有的中间值\(S_i\)以计算梯度,这使得内存的消耗量最大增加\(\max(D,M)\)倍,影响了对长序列或者深度模型的适用性。为解决这一问题,导出式9中的分子的梯度作为累加和。这使得我们可以在线性时间和固定的内存空间同时计算causal linear attention的前向和后向传播。

给定分子\(\bar{V_i}\)和标量损失函数对于分子\(\bar{V_i}\)的梯度\(\nabla_{\bar{V_i}}\mathcal{L}\),导出\(\nabla_{\phi(Q_i)\mathcal{L}},\nabla_{\phi(K_i)}\mathcal{L},\nabla_{V_i}\mathcal{L}\)如下: \[ \nabla_{\phi(Q_i)\mathcal{L}}=\nabla_{\bar{V_i}}\mathcal{L}{\Bigg(}\sum_{j=1}^i\phi(K_j)V_j^T{\Bigg)}^T, \]

\[ \nabla_{\phi(K_i)\mathcal{L}}={\Bigg(}\sum_{j=1}^N\phi(Q_j)\Big(\nabla_{\bar{V_i}}\mathcal{L}\Big)^T{\Bigg)}V_i, \]

\[ \nabla_{V_i}\mathcal{L}={\Bigg(}\sum_{j=1}^N\phi(Q_j)\Big(\nabla_{\bar{V_i}}\mathcal{L}\Big)^T{\Bigg)}^T\phi(K_i). \]

式9、式13-15的累加和是在线性时间内、仅需关于序列长度线性比的内存空间内计算得到的。给定一个\(C\)维的特征图,算法的时间复杂度为\(\mathcal{O}(NCM)\),空间复杂度为\(\mathcal{O}(N\max(C,M))\)

Summary

这篇文章实现了线性复杂度的transformer,后续尝试把线性的transformer加到DETR类模型里跑一下,先从original DETR开始改。Facebook有后续的工作,Hydra Attention,但是还没有开源,先挖个坑后面再看。


[论文阅读] Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
http://k0145vin.xyz/2022/10/29/论文阅读-Transformers-are-RNNs-Fast-Autoregressive-Transformers-with-Linear-Attention/
作者
一瓶AD钙
发布于
2022年10月29日
许可协议