深度学习中的状态空间模型(SSM)初探

引言

前几天,笔者看了几篇介绍SSM(State Space Model)的文章,才发现原来自己从未认真了解过SSM,于是打算认真去学习一下SSM的相关内容,顺便开了这个新坑,记录一下学习所得。

SSM的概念由来已久,但这里我们特指深度学习中的SSM,一般认为其开篇之作是2021年的S4,不算太老,而SSM最新最火的变体大概是去年的Mamba。当然,当我们谈到SSM时,也可能泛指一切线性RNN模型,这样RWKV、RetNet还有此前我们在《Google新作试图“复活”RNN:RNN能否再次辉煌?》介绍过的LRU都可以归入此类。不少SSM变体致力于成为Transformer的竞争者,尽管笔者并不认为有完全替代的可能性,但SSM本身优雅的数学性质也值得学习一番。

尽管我们说SSM起源于S4,但在S4之前,SSM有一篇非常强大的奠基之作《HiPPO: Recurrent Memory with Optimal Polynomial Projections》(简称HiPPO),所以本文从HiPPO开始说起。

基本形式

先插句题外话,上面提到的SSM代表作HiPPO、S4、Mamba的一作都是Albert Gu,他还有很多篇SSM相关的作品,毫不夸张地说,这些工作筑起了SSM大厦的基础。不论SSM前景如何,这种坚持不懈地钻研同一个课题的精神都值得我们由衷地敬佩。

言归正传。对于事先已经对SSM有所了解的读者,想必知道SSM建模所用的是线性ODE系统:

[latex][ x′(t) = Ax(t) + Bu(t) ][/latex]
[latex][ y(t) = Cx(t) + Du(t) ][/latex]

其中 [latex]( u(t) \in \mathbb{R}^{d_i}, x(t) \in \mathbb{R}^d, y(t) \in \mathbb{R}^{d_o}, A \in \mathbb{R}^{d \times d}, B \in \mathbb{R}^{d \times d_i}, C \in \mathbb{R}^{d_o \times d}, D \in \mathbb{R}^{d_o \times d_i} )[/latex]。当然我们也可以将它离散化,那么就变成一个线性RNN模型,这部分我们在后面的文章再展开。不管离散化与否,其关键词都是“线性”,那么马上就有一个很自然的问题:为什么是线性系统?线性系统够了吗?

我们可以从两个角度回答这个问题:线性系统既足够简单,也足够复杂。简单是指从理论上来说,线性化往往是复杂系统的一个最基本近似,所以线性系统通常都是无法绕开的一个基本点;复杂是指即便如此简单的系统,也可以拟合异常复杂的函数,为了理解这一点,我们只需要考虑一个 [latex]( \mathbb{R}^4 )[/latex] 的简单例子:

[latex][ x′(t) = \begin{pmatrix} 1 & 0 & 0 & 0 \ 0 & -1 & 0 & 0 \ 0 & 0 & 0 & 1 \ 0 & 0 & -1 & 0 \end{pmatrix} x(t) ][/latex]

这个例子的基本解是 [latex]( x(t) = (e^t, e^{-t}, \sin t, \cos t) )[/latex]。这意味着只要 [latex]( d )[/latex] 足够大,该线性系统就可以通过指数函数和三角函数的组合来拟合足够复杂的函数,而我们知道拟合能力很强的傅里叶级数也只不过是三角函数的组合,如果再加上指数函数显然就更强了,因此可以想象线性系统也有足够复杂的拟合能力。

当然,这些解释某种意义上都是“马后炮”。HiPPO给出的结果更加本质:当我们试图用正交基去逼近一个动态更新的函数时,其结果就是如上的线性系统。这意味着,HiPPO不仅告诉我们线性系统可以逼近足够复杂的函数,还告诉我们怎么去逼近,甚至近似程度如何。

有限压缩

接下来,我们只考虑 [latex]( d_i = 1 )[/latex] 的特殊情形,[latex]( d_i > 1 )[/latex] 只不过是 [latex]( d_i = 1 )[/latex] 时的并行推广。此时,[latex]( u(t) )[/latex] 的输出是一个标量,进一步地,作为开头我们先假设 [latex]( t \in [0, 1] )[/latex],HiPPO的目标是:用一个有限维的矢量来储存这一段 [latex]( u(t) )[/latex] 的信息。

看上去这是一个不大可能的需求,因为 [latex]( t \in [0, 1] )[/latex] 意味着 [latex]( u(t) )[/latex] 可能相当于无限个点组成的矢量,压缩到一个有限维的矢量可能严重失真。不过,如果我们对 [latex]( u(t) )[/latex] 做一些假设,并且允许一些损失,那么这个压缩是有可能做到的,并且大多数读者都已经尝试过。比如,当 [latex]( u(t) )[/latex] 在某点 ( n+1 ) 阶可导的,它对应的 ( n ) 阶泰勒展开式往往是 [latex]( u(t) )[/latex] 的良好近似,于是我们可以只储存展开式的 ( n+1 ) 个系数来作为 [latex]( u(t) )[/latex] 的近似表征,这就成功将 [latex]( u(t) )[/latex] 压缩为一个 ( n+1 ) 维矢量。

当然,对于实际遇到的数据来说,“( n+1 ) 阶可导”这种条件可谓极其苛刻,我们通常更愿意使用在平方可积条件下的正交函数基展开,比如傅里叶(Fourier)级数,它的系数计算公式为:

[latex][ c_n = \int_0^1 u(t) e^{-2i\pi nt} dt ][/latex]

这时候取一个足够大的整数 [latex]( N. [/latex],只保留 [latex]( |n| \leq N )[/latex] 的系数,那么就将 [latex]( u(t) )[/latex] 压缩为一个 ( 2N+1 ) 维的矢量了。

接下来,问题难度就要升级了。刚才我们说 [latex]( t \in [0, 1] )[/latex],这是一个静态的区间,而实际中 [latex]( u(t) )[/latex] 代表的是持续采集的信号,所以它是不断有新数据进入的,比如现在我们近似了 [latex]( [0, 1] )[/latex] 区间的数据,马上就有 [latex]( [1, 2] )[/latex] 的数据进来,你需要更新逼近结果来试图记忆整个 ( [0, 2] ) 区间,接下来是 ( [0, 3] )、( [0, 4] ) 等等,这我们称为“在线函数逼近”。而上面的傅里叶系数公式只适用于区间 ( [0, 1] ),因此需要将它进行推广。

为此,我们设 [latex]( t \in [0, T] ),( s \mapsto t \leq T(s) )[/latex] 是 ( [0, 1] ) 到 ( [0, T] ) 的一个映射,那么 [latex]( u(t \leq T(s)) )[/latex] 作为 ( s ) 的函数时,它的定义区间就是 [latex]( [0, 1] )[/latex],于是就可以复用傅里叶系数公式:

[latex][ c_n(T. = \int_0^1 u(t \leq T(s)) e^{-2i\pi ns} ds ][/latex]

这里我们已经给系数加了标记 [latex]( (T. )[/latex],以表明此时的系数会随着 [latex]( T )[/latex] 的变化而变化。

线性初现

能将 [latex]( [0, 1] )[/latex] 映射到 [latex]( [0, T] )[/latex] 的函数有无穷多,而最终结果也因 [latex]( t \leq T(s) )[/latex] 而异,一些比较直观且相对简单的选择如下:

  1. [latex]( t \leq T(s) = sT )[/latex],即将 ( [0, 1] ) 均匀地映射到 ( [0, T] );
  2. 注意 [latex]( t \leq T(s) )[/latex] 并不必须是满射,所以像 [latex]( t \leq T(s) = s + T – 1 )[/latex] 也是允许的,这意味着只保留了最邻近窗口 ( [T – 1, T] ) 的信息,丢掉了更早的部分,更一般地有 [latex]( t \leq T(s) = sw + T – w )[/latex],其中 ( w ) 是一个常数,这意味着 ( T – w ) 前的信息被丢掉了;
  3. 也可以选择非均匀映射,比如 [latex]( t \leq T(s) = T\sqrt{s} )[/latex],它同样是 ( [0, 1] ) 到 ( [0, T] ) 的满射,但 ( s = 1/4 ) 时就映射到 ( T/2 ) 了,这意味着我们虽然关注全局的历史,但同时更侧重于 ( T. 时刻附近的信息。

现在我们以 [latex]( t \leq T(s) = (s + 1)w/2 + T – w )[/latex] 为例,代入傅里叶系数公式得到:

[latex][ c_n(T. = \int_0^1 u(sw + T – w) e^{-2i\pi ns} ds ][/latex]

现在我们两边求关于 ( T. 的导数:

[latex][ \frac{d}{dT} c_n(T. = \frac{1}{w} \left[ u(T) – u(T – w) \right] + \frac{2i\pi n}{w} c_n(T) ][/latex]

其中我们用了分部积分公式。由于我们只保留了 [latex]( |n| \leq N. [/latex] 的系数,所以根据傅里叶级数的公式,可以认为如下是 [latex]( u(sw + T – w) )[/latex] 的一个良好近似:

[latex][ u(sw + T – w) \approx \sum_{k=-N}^N c_k(T. e^{2i\pi ks} ][/latex]

那么 [latex]( u(T – w) = u(sw + T – w) \big|{s=0} \approx \sum{k=-N}^N c_k(T. )[/latex],代入上式得:

[latex][ \frac{d}{dT} c_n(T. \approx \frac{1}{w} \left[ u(T) – \sum_{k=-N}^N c_k(T) \right] + \frac{2i\pi n}{w} c_n(T) ][/latex]

将 ( T. 换成 ( t ),然后所有的 [latex]( c_n(t) )[/latex] 堆在一起记为 [latex]( x(t) = (c_{-N}, c_{-(N-1)}, \ldots, c_0, \ldots, c_{N-1}, c_N) )[/latex],并且不区分 [latex]( \approx )[/latex] 和 [latex]( = )[/latex],那么就可以写出:

[latex][ x′(t) = A x(t) + B u(t) ][/latex]

其中:

[latex][ A_{n,k} = \begin{cases} \frac{2i\pi n – 1}{w}, & k = n \ -\frac{1}{w}, & k \ne n \end{cases}, \quad B_n = \frac{1}{w} ][/latex]

这就出现了如上所示的线性ODE系统。即当我们试图用傅里叶级数去记忆一个实时函数的最邻近窗口内的状态时,结果自然而然地导致了一个线性ODE系统。

一般框架

当然,目前只是选择了一个特殊的 [latex]( t \leq T(s) )[/latex],换一个 [latex]( t \leq T(s) )[/latex] 就不一定有这么简单的结果了。此外,傅里叶级数的结论是在复数范围内的,进一步实数化也可以,但形式会变得复杂起来。所以,我们要将这一过程推广成一个一般化的框架,从而得到更一般、更简单的纯实数结论。

设 [latex]( t \in [a, b] )[/latex],并且有目标函数 [latex]( u(t) )[/latex] 和函数基 [latex]( { g_n(t) }_{n=0}^N. [/latex],我们希望有后者的线性组合来逼近前者,目标是最小化 [latex]( L^2 )[/latex] 距离:

[latex][ \arg\min_{c_1, \ldots, c_N} \int_a^b \left[ u(t) – \sum_{n=0}^N c_n g_n(t) \right]^2 dt ][/latex]

这里我们主要在实数范围内考虑,所以方括号直接平方就行,不用取模。更一般化的目标函数还可以再加个权重函数 [latex]( \rho(t) )[/latex],但我们这里就不考虑了,毕竟HiPPO的主要结论其实也没考虑这个权重函数。

对目标函数展开,得到:

[latex][ \int_a^b u^2(t) dt – 2 \sum_{n=0}^N c_n \int_a^b u(t) g_n(t) dt + \sum_{m=0}^N \sum_{n=0}^N c_m c_n \int_a^b g_m(t) g_n(t) dt ][/latex]

这里我们只考虑标准正交函数基,其定义为 [latex]( \int_a^b g_m(t) g_n(t) dt = \delta_{m,n} ),( \delta_{m,n} )[/latex] 是克罗内克δ函数,此时上式可以简化成:

[latex][ \int_a^b u^2(t) dt – 2 \sum_{n=0}^N c_n \int_a^b u(t) g_n(t) dt + \sum_{n=0}^N c_n^2 ][/latex]

这只是一个关于 [latex]( c_n )[/latex] 的二次函数,它的最小值是有解析解的:

[latex][ c^*_n = \int_a^b u(t) g_n(t) dt ][/latex]

这也被称为 [latex]( u(t) )[/latex] 与 [latex]( g_n(t) )[/latex] 的内积,它是有限维矢量空间的内积到函数空间的并行推广。简单起见,在不至于混淆的情况下,我们默认 [latex]( c_n )[/latex] 就是 [latex]( c^*_n )[/latex]。

接下来的处理跟上一节是一样的,我们要对一般的 [latex]( t \in [0, T] )[/latex] 考虑 [latex]( u(t) )[/latex] 的近似,那么找一个 [latex]( [a, b] )[/latex] 到 [latex]( [0, T] )[/latex] 的映射 [latex]( s \mapsto t \leq T(s) )[/latex],然后计算系数:

[latex][ c_n(T. = \int_a^b u(t \leq T(s)) g_n(s) ds ][/latex]

同样是两边求 [latex]( T. [/latex] 的导数,然后用分部积分法:

[latex][ \frac{d}{dT} c_n(T. = \int_a^b u'(t \leq T(s)) \frac{\partial t \leq T(s)}{\partial T} g_n(s) ds ][/latex]

[latex][ = \int_a^b \left( \frac{\partial t \leq T(s)}{\partial T} / \frac{\partial t \leq T(s)}{\partial s} \right) g_n(s) du(t \leq T(s)) ][/latex]

[latex][ = \left( \frac{\partial t \leq T(s)}{\partial T} / \frac{\partial t \leq T(s)}{\partial s} \right) g_n(s) \bigg|_{s=b}^{s=a} – \int_a^b u(t \leq T(s)) d \left[ \left( \frac{\partial t \leq T(s)}{\partial T} / \frac{\partial t \leq T(s)}{\partial s} \right) g_n(s) \right] ][/latex]

请勒让德

接下来的计算,就依赖于 [latex]( g_n(t) )[/latex] 和 [latex]( t \leq T(s) )[/latex] 的具体形式了。HiPPO的全称是High-order Polynomial Projection Operators,第一个P正是多项式(Polynomial)的首字母,所以HiPPO的关键是选取多项式为基。现在我们请出继傅里叶之后又一位大牛——勒让德(Legendre),接下来我们要选取的函数基正是以他命名的“勒让德多项式”。

勒让德多项式 [latex]( p_n(t) )[/latex] 是关于 [latex]( t )[/latex] 的 [latex]( n )[/latex] 次函数,定义域为 [latex]( [-1, 1] )[/latex],满足:

[latex][ \int_{-1}^1 p_m(t) p_n(t) dt = \frac{2}{2n+1} \delta_{m,n} ][/latex]

所以 [latex]( p_n(t) )[/latex] 之间只是正交,还不是标准(平分积分为1),[latex]( g_n(t) = \sqrt{\frac{2n+1}{2}} p_n(t) )[/latex] 才是标准正交基。

当我们对函数基 [latex]( {1, t, t^2, \ldots, t^n} )[/latex] 执行施密特正交化时,其结果正是勒让德多项式。相比傅里叶基,勒让德多项式的好处是它是纯粹定义在实数空间中的,并且多项式的形式能够有助于简化部分 [latex]( t \leq T(s) )[/latex] 的推导过程,这一点我们后面就可以看到。勒让德多项式有很多不同的定义和性质,这里我们不一一展开,有兴趣的读者自行看维基百科介绍即可。

接下来我们用到两个递归公式来推导一个恒等公式,这两个递归公式是:

[latex][ p′_{n+1}(t) – p′_{n-1}(t) = (2n+1) p_n(t) ][/latex]
[latex][ p′_{n+1}(t) = (n+1) p_n(t) + t p′_n(t) ][/latex]

由第一个公式迭代得到:

[latex][ p′_{n+1}(t) = (2n+1) p_n(t) + (2n−3) p_{n−2}(t) + (2n-7) p_{n-4}(t) + ⋯ = \sum_{k=0}^n (2k+1) χ_{n−k} p_k(t) ][/latex]

其中当 [latex]( k )[/latex] 是偶数时 [latex]( χ_k = 1 )[/latex] 否则 [latex]( χ_k = 0 )[/latex] 。代入第二个公式得到:

[latex][ t p′_n(t) = n p_n(t) + (2n−3) p{n-2}(t) + (2n-7) p_{n-4}(t) + ⋯ ][/latex]

继而有:

[latex][ (t+1) p′_n(t) = n p_n(t) + (2n-1) p_{n−1}(t) + (2n-3) p_{n-2}(t) + ⋯ = – (n+1) p_n(t) + \sum_{k=0}^n (2k+1) p_k(t) ][/latex]

这些就是等会要用到的恒等式。此外,勒让德多项式满足 ( p_n(1) = 1, p_n(-1) = (-1)^n ),这个边界值后面也会用到。

正如 ( n ) 维空间中不止有一组正交基也一样,正交多项式也不止有勒让德多项式一种,比如还有切比雪夫(Chebyshev)多项式,如果算上加权的目标函数(即 [latex]( ρ(t) ≢ 1 )[/latex] ),还有拉盖尔多项式等,这些在原论文中都有提及,但HiPPO的主要结论还是基于勒让德多项式展开的,所以剩余部分这里也不展开讨论了。

邻近窗口

完成准备工作后,我们就可以代入具体的 [latex]( t ≤ T(s) )[/latex] 进行计算了,计算过程跟傅里叶级数的例子大同小异,只不过基函数换成了勒让德多项式构造的标准正交基 [latex]( g_n(t) = \sqrt{\frac{2n+1}{2}} p_n(t) )[/latex]。作为第一个例子,我们同样先考虑只保留最邻近窗口的信息,此时 [latex]( t ≤ T(s) = \frac{(s+1)w}{2} + T – w )[/latex] 将 [latex]( [−1, 1] )[/latex] 映射到 [latex]( [T−w, T] )[/latex],原论文将这种情形称为“LegT(Translated Legendre)”。

直接代入之前得到的公式,马上得到:

[latex][ \frac{d}{dT} c_n(T. = \sqrt{\frac{2(2n+1)}{w}} [u(T) – (−1)^n u(T−w)] – \frac{2}{w} \int_{-1}^1 u\left(\frac{(s+1)w}{2} + T – w\right) g′_n(s) ds ][/latex]

我们首先处理 [latex]( u(T−w) )[/latex] 项,跟傅里叶级数那里同样的思路,我们截断 ( n ≤ N. 作为 [latex]( u\left(\frac{(s+1)w}{2} + T – w\right) )[/latex] 的一个近似:

[latex][ u\left(\frac{(s+1)w}{2} + T – w\right) ≈ \sum_{k=0}^N c_k(T. g_k(s) ][/latex]

从而有 [latex]( u(T−w) ≈ \sum_{k=0}^N c_k(T. g_k(−1) = \sum_{k=0}^N (−1)^k c_k(T) \sqrt{\frac{2k+1}{2}} )[/latex] 。接着,利用之前的递归公式得到:

[latex][ \int_{-1}^1 u\left(\frac{(s+1)w}{2} + T – w\right) g′n(s) ds = \int{-1}^1 u\left(\frac{(s+1)w}{2} + T – w\right) \sqrt{\frac{2n+1}{2}} p′_n(s) ds ][/latex]

[latex][ = \int_{-1}^1 u\left(\frac{(s+1)w}{2} + T – w\right) \sqrt{\frac{2n+1}{2}} \left[\sum_{k=0}^{n-1} (2k+1) \chi_{n-1-k} p_k(s) \right] ds ][/latex]

[latex][ = \int_{-1}^1 u\left(\frac{(s+1)w}{2} + T – w\right) \sqrt{\frac{2n+1}{2}} \left[\sum_{k=0}^{n-1} \sqrt{2(2k+1)} \chi_{n-1-k} g_k(s) \right] ds ][/latex]

[latex][ = \sqrt{\frac{2n+1}{2}} \sum_{k=0}^{n-1} \sqrt{2(2k+1)} \chi_{n-1-k} c_k(T. ][/latex]

将这些结果集成起来,就有:

[latex][ \frac{d}{dT} c_n(T. ≈ \sqrt{\frac{2(2n+1)}{w}} u(T) – \sqrt{\frac{2(2n+1)}{w}} (−1)^n \sum_{k=0}^N (−1)^k c_k(T) \sqrt{\frac{2k+1}{2}} – \frac{2}{w} \sqrt{\frac{2n+1}{2}} \sum_{k=0}^{n-1} \sqrt{2(2k+1)} \chi_{n-1-k} c_k(T) ][/latex]

再次地,将 [latex]( T. [/latex] 换回 [latex]( t )[/latex],并将所有的 [latex]( c_n(t) )[/latex] 堆在一起记为 [latex]( x(t) = (c_0, c_1, ⋯, c_N) )[/latex],那么根据上式可以写出:

[latex][ x′(t) = A x(t) + B u(t) ][/latex]

其中:

[latex][ A_{n,k} = \begin{cases} -\frac{2}{w} \sqrt{\frac{(2n+1)(2k+1)}{2}}, & k < n \ -\frac{2}{w} \sqrt{\frac{(2n+1)(2k+1)}{2}} (−1)^{n−k}, & k ≥ n \end{cases}, \quad B_n = \sqrt{\frac{2(2n+1)}{w}} ][/latex]

我们还可以给每个 [latex]( c_n(T. )[/latex] 都引入一个缩放因子,来使得上述结果更一般化。比如我们设 [latex]( c_n(T) = λ_n \tilde{c}_n(T) )[/latex],代入上式整理得:

[latex][ \frac{d}{dT} \tilde{c}n(T. ≈ \sqrt{\frac{2(2n+1)}{w}} \frac{u(T)}{λ_n} – \sqrt{\frac{2(2n+1)}{w}} \frac{(−1)^n}{λ_n} \sum{k=0}^N \frac{(−1)^k c_k(T)}{λk} \sqrt{\frac{2k+1}{2}} – \frac{2}{w} \sqrt{\frac{2n+1}{2}} \sum{k=0}^{n-1} \frac{λk}{λ_n} \sqrt{2(2k+1)} \chi{n-1-k} \tilde{c}_k(T) ][/latex]

如果取 [latex]( λ_n = \sqrt{2} )[/latex],那么 ( A. 不变,[latex]( B_n = \sqrt{2(2n+1)} )[/latex],这就对齐了原论文的结果。如果取 [latex]( λ_n = \sqrt{\frac{2}{2n+1}} )[/latex],那么就得到了Legendre Memory Units中的结果:

[latex][ x′(t) = A x(t) + B u(t) ][/latex]

其中:

[latex][ A_{n,k} = \begin{cases} 2n+1, & k < n \ (−1)^{n−k} (2n+1), & k ≥ n \end{cases}, \quad B_n = 2n+1 ][/latex]

这些形式在理论上都是等价的,但可能存在不同的数值稳定性。比如一般来说当 [latex]( u(t) )[/latex] 的性态不是特别糟糕时,我们可以预期 [latex]( n )[/latex] 越大,[latex]( |c_n| )[/latex] 的值就相对越小,这样直接用 [latex]( c_n )[/latex] 的话 [latex]( x(t) )[/latex] 矢量的每个分量的尺度就不大对等,这样的系统在实际计算时容易出现数值稳定问题,而取 [latex]( λ_n = \sqrt{\frac{2}{2n+1}} )[/latex] 改用 [latex]( \tilde{c}_n )[/latex] 的话意味着数值小的分量会被适当放大,可能有助于缓解多尺度问题从而使得数值计算更稳定。

整个区间

现在我们继续计算另一个例子:[latex]( t ≤ T(s) = \frac{(s+1)T}{2} )[/latex],它将 [latex]( [−1, 1] )[/latex] 均匀映射到 [latex]( [0, T] )[/latex],这意味着我们没有舍弃任何历史信息,并且平等地对待所有历史,原论文将这种情形称为“LegS(Scaled Legendre)”。

同样地,通过代入之前得到的公式:

[latex][ \frac{d}{dT} c_n(T. = \sqrt{\frac{2(2n+1)}{T}} u(T) – \frac{1}{T} \int_{-1}^1 u\left(\frac{(s+1)T}{2}\right) (s+1) g′_n(s) ds ][/latex]

利用之前的递归公式得到:

[latex][ \int_{-1}^1 u\left(\frac{(s+1)T}{2}\right) (s+1) g′n(s) ds = \int{-1}^1 u\left(\frac{(s+1)T}{2}\right) \left[g_n(s) + (s+1) g′_n(s)\right] ds ][/latex]

[latex][ = c_n(T. + \int_{-1}^1 u\left(\frac{(s+1)T}{2}\right) \sqrt{\frac{2n+1}{2}} p′_n(s) ds ][/latex]

[latex][ = c_n(T. + \int_{-1}^1 u\left(\frac{(s+1)T}{2}\right) \left[-(n+1) g_n(s) + \sum_{k=0}^n \sqrt{(2n+1)(2k+1)} g_k(s)\right] ds ][/latex]

[latex][ = c_n(T. – n c_n(T) + \sum_{k=0}^n \sqrt{(2n+1)(2k+1)} c_k(T) ][/latex]

于是有:

[latex][ \frac{d}{dT} c_n(T. = \sqrt{\frac{2(2n+1)}{T}} u(T) – \frac{1}{T} \left(-n c_n(T) + \sum_{k=0}^n \sqrt{(2n+1)(2k+1)} c_k(T)\right) ][/latex]

将 ( T. 换回 ( t ),将所有的 ( c_n(t) ) 堆在一起记为 [latex]( x(t) = (c_0, c_1, ⋯, c_N) )[/latex],那么根据上式可以写出:

[latex][ x′(t) = A x(t) + B u(t) ][/latex]

其中:

[latex][ A_{n,k} = \begin{cases} \sqrt{(2n+1)(2k+1)}, & k < n \ n+1, & k = n \ 0, & k > n \end{cases}, \quad B_n = \sqrt{2(2n+1)} ][/latex]

引入缩放因子来一般化结果也是可行的:设 [latex]( c_n(T. = λ_n \tilde{c}_n(T) )[/latex],代入上式整理得:

[latex][ \frac{d}{dT} \tilde{c}n(T. = \sqrt{\frac{2(2n+1)}{T}} \frac{u(T)}{λ_n} – \frac{1}{T} \left(-n \tilde{c}_n(T) + \sum{k=0}^n \sqrt{(2n+1)(2k+1)} \frac{λ_k}{λ_n} \tilde{c}_k(T)\right) ][/latex]

取 [latex]( λ_n = \sqrt{\frac{2}{2n+1}} )[/latex],就可以让 ( A. 不变,[latex]( B_n = \sqrt{2(2n+1)} )[/latex],就对齐了原论文的结果。如果取 [latex]( λ_n = \sqrt{\frac{2}{2n+1}} )[/latex],就可以像上一节LegT的结果一样去掉根号:

[latex][ x′(t) = A x(t) + B u(t) ][/latex]

其中:

[latex][ A_{n,k} = \begin{cases} 2(2n+1), & k < n \ n+1, & k = n \ 0, & k > n \end{cases}, \quad B_n = 2(2n+1) ][/latex]

但原论文没有考虑这种情况,原因不详。

延伸思考

回顾Leg-S的整个推导,我们可以发现其中关键一步是将 [latex]( (s+1) g′_n(s) )[/latex] 拆成 [latex]( g_0(s), g_1(s), ⋯, g_n(s) )[/latex] 的线性组合,对于正交多项式来说,[latex]( (s+1) g′_n(s) )[/latex] 是一个 [latex]( n )[/latex] 次多项式,所以这种拆分必然可以精确成立,但如果是傅立叶级数的情况,[latex]( g_n(s) )[/latex] 是指数函数,此时类似的拆分做不到了,至少不能精确地做到,所以可以说选取正交多项式为基的根本目的是简化后面推导。

特别要指出的是,HiPPO是一个自下而上的框架,它并没有一开始就假设系统必须是线性的,而是从正交基逼近的角度反过来推出其系数的动力学满足一个线性ODE系统,这样一来我们就可以确信,只要认可所做的假设,那么线性ODE系统的能力就是足够的,而不用去担心线性系统的能力限制了你的发挥。

当然,HiPPO对于每一个解所做的假设及其物理含义也很清晰,所以对于重用了HiPPO矩阵的SSM,它怎么储存历史、能储存多少历史,从背后的HiPPO假设就一清二楚。比如LegT就是只保留 [latex]( w )[/latex] 大小的最邻近窗口信息,如果你用了LegT的HiPPO矩阵,那么就类似于一个Sliding Window Attention;而LegS理论上可以捕捉全部历史,但这有个分辨率问题,因为 [latex]( x(t) )[/latex] 的维度代表了拟合的阶数,它是一个固定值,用同阶的函数基去拟合另一个函数,肯定是区间越小越准确,区间越大误差也越大,这就好比为了一次性看完一幅大图,那么我们必须站得更远,从而看到的细节越少。

诸如RWKV、LRU等模型,并没有重用HiPPO矩阵,而是改为可训练的矩阵,原则上具有更多的可能性来突破瓶颈,但从前面的分析大致上可以感知到,不同矩阵的线性ODE只是函数基不同,但本质上可能都只是有限阶函数基逼近的系数动力学。既然如此,分辨率与记忆长度就依然不可兼得,想要记忆更长的输入并且保持效果不变,那就只能增加整个模型的体量(即相当于增加hidden_size),这大概是所有线性系统的特性。

文章小结

本文以尽可能简单的方式重复了《HiPPO: Recurrent Memory with Optimal Polynomial Projections》(简称HiPPO)的主要推导。HiPPO通过适当的记忆假设,自下而上地导出了线性ODE系统,并且针对勒让德多项式的情形求出了相应的解析解(HiPPO矩阵),其结果被后来诸多SSM(State Space Model)使用,可谓是SSM的重要奠基之作。

HiPPO框架展现了优雅的数学结构和强大的应用潜力,在处理时间序列数据时提供了一种高效的记忆机制。未来的研究可以进一步探索其在不同领域中的应用和改进。

参考文献: https://spaces.ac.cn/archives/10114

发表评论

人生梦想 - 关注前沿的计算机技术 acejoy.com