当 Transformer 遇上状态空间模型:结构化状态空间对偶性的启示

近年来,深度学习在自然语言处理领域取得了巨大成功,这主要归功于 Transformer 架构。然而,状态空间模型(SSM),例如 Mamba,最近被证明在中小型规模上可以与 Transformer 媲美甚至超越。本文将深入探讨这两种模型之间的密切关系,并通过对结构化半可分矩阵的不同分解,建立 SSM 和注意力变体之间丰富的理论联系框架。我们的状态空间对偶性(SSD)框架将引领我们设计一种新的架构(Mamba-2),其核心层是对 Mamba 选择性 SSM 的改进,速度提高了 2-8 倍,同时在语言建模方面继续与 Transformer 保持竞争力。

Transformer 的效率瓶颈与状态空间模型的崛起

Transformer,特别是仅解码器模型(例如 GPT 和 Llama),以因果方式处理输入序列,是现代深度学习成功的关键驱动力之一。然而,其核心注意力层存在效率问题,例如在训练期间按序列长度呈二次方增长,以及在自回归生成期间需要大小与序列长度呈线性关系的缓存。为了解决这些问题,许多方法试图近似核心注意力层(Tay et al. 2022),但效果有限。

与此同时,一类替代序列模型——结构化状态空间模型(SSM)——应运而生。它们在训练期间按序列长度呈线性增长,在生成期间具有恒定的状态大小。SSM 在长程任务上表现出色(例如 S4),并且最近在中小型规模的语言建模上与 Transformer 媲美甚至超越(例如 Mamba)。然而,SSM 的发展似乎与社区为改进 Transformer 所做的集体努力脱节,例如从理论上理解它们以及在现代硬件上优化它们。因此,与 Transformer 相比,理解和试验 SSM 更加困难,并且从算法和系统角度来看,高效地训练 SSM 仍然具有挑战性。

结构化状态空间对偶性:连接 SSM 和注意力的桥梁

本文的主要目标是建立结构化 SSM 和注意力变体之间丰富的理论联系。这将使我们能够将最初为 Transformer 开发的算法和系统优化转移到 SSM,从而构建性能优于 Transformer 且序列长度扩展效率更高的基础模型。线性注意力(LA)框架(Katharopoulos et al. 2020)是朝着这个方向迈出的里程碑式的一步,它通过证明二次核化注意力的“对偶形式”与特定线性递归之间的等价性,推导了自回归注意力和线性 RNN 之间的联系。这种对偶性带来了新的能力,例如能够同时进行高效的可并行化训练和高效的自回归推理。

本文秉承同样的精神,提供了多个视角,将线性复杂度的 SSM 与二次复杂度的形式联系起来,以结合 SSM 和注意力的优势。我们的框架将结构化 SSM 和注意力变体联系起来,我们称之为结构化状态空间对偶性(SSD),它是通过结构化矩阵的抽象实现的:具有次二次参数和乘法复杂度的矩阵。我们开发了两个广泛的框架来表示序列模型,一个作为矩阵变换,一个作为张量收缩,每个框架都揭示了对偶性的不同视角。

状态空间模型与半可分矩阵的等价性

本文的核心是状态空间模型和一类称为半可分矩阵的结构化矩阵之间的等价性(第 3 节)。这种联系揭示了 SSM 的新特性和算法。本文的一个中心思想是,计算状态空间模型的不同方法可以重新表述为结构化矩阵上的各种矩阵乘法算法。

半可分矩阵是一种基本的矩阵结构。我们首先定义这些矩阵及其性质。

定义 3.1:如果包含在下三角部分(即对角线或下方)的每个子矩阵的秩最多为 N,则(下三角)矩阵 M 为 N-半可分矩阵。我们称 N 为半可分矩阵的阶数或秩。

半可分矩阵具有许多结构化表示,包括分层半可分 (HSS)、顺序半可分 (SSS) 和 Bruhat 形式。我们将主要使用 SSS 形式。

定义 3.2:如果下三角矩阵 M ∈ R(T,T) 可以写成以下形式,则它具有 N-顺序半可分 (SSS) 表示:

$$
M_{ji} = C_j^\top A_j \cdots A_{i+1} B_i
$$

其中向量 $B_0, …, B_{T-1}, C_0, …, C_{T-1} ∈ R^N$ 且矩阵 $A_0, …, A_{T-1} ∈ R^{(N,N)}$。

半可分矩阵的一个基本结果是它们与具有 SSS 表示的矩阵完全等价。一个方向可以通过简单的构造性证明来推断。

引理 3.3:具有表示 (4) 的 N-SSS 矩阵 M 是 N-半可分矩阵。

命题 3.4:每个 N-半可分矩阵都有一个 N-SSS 表示。

状态空间模型和顺序半可分矩阵之间的联系如下:

定理 3.5:状态大小为 N 的状态空间模型变换 𝑦 = SSM(𝐴, 𝐵, 𝐶)(𝑥) 与通过 N-SS 矩阵(以顺序半可分表示)进行矩阵乘法 𝑦 = SSS(𝐴, 𝐵, 𝐶) · 𝑥 相同。

换句话说,序列变换算子 SSM 与矩阵构造算子 SSS 重合,我们可以互换使用它们。

通过结构化矩阵算法计算状态空间模型

定理 3.5 的重要性在于它使我们能够将高效计算 SSM(和其他序列模型)的问题简化为结构化矩阵乘法的有效算法。

命题 3.6 (Pernet, Signargout, and Villard (2023)):大小为 T 的 N-SS 矩阵可以用 𝑂 (NT) 个参数表示,并且矩阵向量乘法的时间和空间复杂度为 𝑂 (NT)。

定理 3.7:任何状态大小为 N、序列长度为 T 的状态空间模型(定义 2.2)都可以在 𝑂 (TN) 时间内计算(不考虑潜在的预处理)。

结构化掩码注意力:用结构化矩阵推广线性注意力

在本节中,我们将从第一性原理重新审视线性注意力框架。本节的主要结果是基于张量收缩的线性注意力简单证明(命题 4.1),以及我们在定义 4.2 中对结构化掩码注意力的概括抽象。

线性注意力和其他许多高效注意力变体通常是通过改变核心注意力计算 (𝑄𝐾 ⊤)𝑉 = 𝑄 (𝐾 ⊤𝑉 ) 中矩阵结合性的顺序来实现的。但是,当添加掩码时,推导就有点不那么直接了。

命题 4.1 ((Katharopoulos et al. 2020)):自回归核注意力(即具有因果掩码的掩码核注意力)可以通过每次迭代花费恒定时间的递归在 𝑂 (𝑇 ) 时间内计算。

通过使用张量收缩的视角,我们可以立即看到原始线性注意力的关键在于因果掩码的矩阵向量乘法等价于累积和运算符。

然而,我们观察到注意力掩码不必都是 1。线性注意力快速计算所必需的只是 𝐿 是一个结构化矩阵,根据定义,结构化矩阵具有快速的矩阵乘法(第 2.3 节)。特别是,我们可以使用任何具有次二次(理想情况下是线性)矩阵向量乘法的掩码矩阵 𝐿,通过加速瓶颈方程 (15b),它将具有与标准线性注意力相同的复杂度。

定义 4.2:结构化掩码注意力 (SMA)(或简称结构化注意力)定义为查询/键/值 𝑄、𝐾、𝑉 以及任何结构化矩阵 𝐿(即具有次二次矩阵乘法)上的函数,通过 4 路张量收缩:

$$
𝑌 = \text{contract}(TN, SN, SP, TS → TP)(𝑄, 𝐾, 𝑉 , 𝐿).
$$

SMA 二次模式算法是由 (13) 定义的成对收缩序列,它对应于标准(掩码)注意力计算。

SMA 线性模式算法是由 (15) 定义的成对收缩序列,其中步骤 (15b) 通过次二次结构化矩阵乘法进行了优化。

状态空间对偶性

在第 3 节和第 4 节中,我们定义了结构化状态空间模型和结构化注意力,讨论了它们的性质,并表明它们都具有二次算法和线性算法。本节将它们联系在一起。我们的主要结果是表明结构化状态空间模型的一个特例与结构化注意力的一个特例重合,并且线性时间 SSM 算法和二次时间核注意力算法是彼此的对偶形式。

为了总结我们的结果:

  • 结构化状态空间模型(第 3 节)通常通过线性时间递归来定义。然而,通过扩展表征其线性序列到序列变换的矩阵公式,可以推导出二次形式。
  • 注意力变体(第 4 节)是通过二次时间成对交互定义的模型。然而,通过将其视为四路张量收缩并以不同的顺序进行约简,可以推导出线性形式。
  • 每个模型的一个自然特例——更准确地说,在 𝐴 矩阵上具有标量恒等结构的状态空间模型,以及在其 𝐿 掩码上具有 1-半可分结构的结构化掩码注意力——是彼此的对偶,具有完全相同的线性和二次形式。

图 4 总结了这两种表示之间的对偶性。

SSD 模型的硬件高效算法

开发 SSM、注意力和结构化矩阵之间的理论 SSD 框架的好处在于利用这些联系来改进模型和算法。在本节中,我们将展示如何从计算结构化矩阵乘法的各种算法中推导出高效计算 SSD 模型的各种算法。

我们的主要计算结果是一种结合了线性(递归)模式和二次(注意力)模式的 SSD 模型计算算法。该算法的计算效率与 SSM(序列长度的线性缩放)一样高,并且与注意力(主要使用矩阵乘法)一样对硬件友好。

定理 6.1:考虑一个状态扩展因子为 N、头部维度为 P = N 的 SSD 模型。存在一种算法可以在任何输入 𝑋 ∈ R(T,P) 上计算模型,该算法只需要 𝑂 (TN2) 个训练 FLOP、𝑂 (TN) 个推理 FLOP、𝑂 (N2) 个推理内存,并且其工作主要由矩阵乘法主导。

定理 6.1 背后的主要思想是再次将计算状态空间模型的问题视为半可分矩阵乘法,但以一种新的方式利用其结构。我们没有以递归或注意力模式计算整个矩阵,而是对矩阵进行块分解。对角块可以使用对偶注意力模式计算,这可以通过矩阵乘法有效地完成,而非对角块可以通过半可分矩阵的秩结构进行分解,并简化为更小的递归。

Mamba-2 架构

通过连接 SSM 和注意力,SSD 框架使我们能够为两者开发共享的词汇表和技术库。在本节中,我们将讨论使用最初为 Transformer 开发的思想来理解和修改 SSD 层的一些示例。我们将讨论几个设计选择,从而形成 Mamba-2 架构。

块设计

我们首先讨论与内部序列混合层无关的神经网络块的修改(即核心 SSD 层之外)。

并行参数投影:Mamba-1 的动机是以 SSM 为中心的观点,其中选择性 SSM 层被视为从 𝑋 ↦→ 𝑌 的映射。SSM 参数 𝐴、𝐵、𝐶 被视为辅助参数,并且是 SSM 输入 𝑋 的函数。因此,定义 (𝐴, 𝐵, 𝐶) 的线性投影发生在创建 𝑋 的初始线性投影之后。

在 Mamba-2 中,SSD 层被视为从 𝐴、𝑋、𝐵、𝐶 ↦→ 𝑌 的映射。因此,在块的开头使用单个投影并行生成 𝐴、𝑋、𝐵、𝐶 是有意义的。请注意与标准注意力架构的类比,其中 𝑋、𝐵、𝐶 对应于并行创建的 𝑄、𝐾、𝑉 投影。

额外的归一化:在初步实验中,我们发现大型模型容易出现不稳定性。我们能够通过在最终输出投影之前的块中添加一个额外的归一化层(例如 LayerNorm、GroupNorm 或 RMSNorm)来缓解这种情况。

序列变换的多头模式

回想一下,SSM 被定义为序列变换(定义 2.1),其中:

  • 𝐴、𝐵、𝐶 参数具有状态维度 N。
  • 它们定义了一个序列变换 RT → RT,例如可以表示为矩阵 𝑀 ∈ R(T,T)。
  • 此变换在输入序列 𝑋 ∈ R(T,P) 上运行,在 P 轴上独立运行。

可以将其视为定义序列变换的一个头。

定义 7.1(多头模式):多头序列变换由 H 个独立的头组成,总模型维度为 D = d_model。参数可以在头部之间绑定,从而形成头部模式。

状态大小 N 和头部维度 P 分别类似于注意力的 𝑄𝐾 头部维度和 𝑉 头部维度。就像在现代 Transformer 架构中一样,在 Mamba-2 中,我们通常选择这些常量在 64 或 128 左右;当模型维度 D 增加时,我们增加头的数量,同时保持头部维度 N 和 P 固定。为了描述如何做到这一点,我们可以转移和概括多头注意力的思想,为 SSM 或任何一般序列变换定义类似的模式。

多头 SSM (MHS) / 多头注意力 (MHA) 模式:经典的 MHA 模式假设头部维度 P 可以整除模型维度 D。头的数量定义为 H = D/P。然后,通过创建每个参数的 H 个独立副本,创建核心序列变换的 H 个副本。

多收缩 SSM (MCS) / 多查询注意力 (MQA) 模式:多查询注意力 (Shazeer 2019) 是一种巧妙的注意力优化方法,可以显着提高自回归推理的速度,这依赖于缓存 𝐾 和 𝑉 张量。这种技术只是避免给 𝐾 和 𝑉 额外的头部维度,或者换句话说,在 𝑄 的所有头部广播 (𝐾, 𝑉) 的单个头部。

使用状态空间对偶性,我们可以将 MQA 的等效 SSM 版本定义为方程 (18)。在这里,𝑋 和 𝐵(注意力的 𝑉 和 𝐾 的 SSM 类比)在 H 个头部之间共享。我们也称其为多收缩 SSM (MCS) 头部模式,因为控制 SSM 状态收缩的 𝐶 参数每个头部都有独立的副本。

我们也可以类似地定义多键注意力 (MKA) 或多扩展 SSM (MES) 头部模式,其中 𝐵(控制 SSM 扩展)每个头部都是独立的,而 𝐶 和 𝑋 在头部之间共享。

多输入 SSM (MIS) / 多值注意力 (MVA) 模式:虽然 MQA 对于注意力来说是有意义的,因为它有 KV 缓存,但它不是 SSM 的自然选择。相反,在 Mamba 中,𝑋 被视为 SSM 的主要输入,因此 𝐵 和 𝐶 是在输入通道之间共享的参数。我们在方程 (20) 中定义了一种新的多值注意力 (MVA) 或多输入 SSM (MIS) 模式,它可以再次应用于任何序列变换,例如 SSD。

有了这个词汇表,我们可以更精确地描述原始的 Mamba 架构。

命题 7.2:Mamba 架构的选择性 SSM (S6) 层可以被视为具有:

  • 头部维度 𝑃 = 1:每个通道都有独立的 SSM 动态 𝐴。
  • 多输入 SSM (MIS) 或多值注意力 (MVA) 头部结构:𝐵、𝐶 矩阵(对应于注意力中的 𝐾、𝑄)在输入 𝑋 的所有通道(对应于注意力中的 𝑉)之间共享。

从线性注意力中获得的其他 SSD 扩展

我们在这里描述一个对 SSD 的架构修改示例,其动机来自线性注意力。

核注意力近似于 Softmax 注意力:许多线性注意力或核注意力变体都是通过将注意力分数 softmax(𝑄𝐾 ⊤) 视为由以下两部分组成来实现的:

  1. 指数核 𝑍 = exp(𝑄𝐾 ⊤),可以通过 𝑍 = 𝜓 (𝑄)𝜓 (𝐾)⊤ 来近似,其中 𝜓 为核特征映射。
  2. 通过 𝑀 = 𝐺/𝐺11⊤ 对核进行归一化,使行总和为 1,其中除法按元素进行,1 为全 1 向量。

在 Mamba-2 中,我们合并了一个灵活的核特征映射,并将其应用于 𝐵 和 𝐶 分支(对应于注意力中的 𝐾 和 𝑉 分支)。为了简单和对称,也可以选择将特征映射应用于 𝑋 (𝑉) 分支。这在图 6 中由任意非线性表示。默认情况下,我们简单地选择 𝜓 为元素级 Swish / SiLU 函数。

合并归一化(分母)项:为了找到分母项,我们只需要计算 𝑀1。但回想一下,模型的最终输出只是 𝑌 = 𝑀𝑋(方程 (16))。因此,可以通过用一个额外的列 1 扩充 𝑋 来找到归一化项,从而得到一个形状为 (T, P + 1) 的张量。

请注意,在这种情况下,核特征映射 𝜓 必须为正,以便总和为正。

SSM 的系统优化

我们描述了 SSM(尤其是 Mamba-2 架构)的几种系统优化,用于大规模高效训练和推理。特别是,我们专注于用于大规模训练的张量并行和序列并行,以及用于高效微调和推理的可变长度序列。

张量并行

张量并行 (TP) (Shoeybi et al. 2019) 是一种模型并行技术,它将每一层(例如,注意力、MLP)拆分到多个加速器(例如 GPU)上运行。这种技术被广泛用于在 GPU 集群上训练大多数大型模型,其中每个节点通常有 4-8 个具有快速网络(例如 NVLink)的 GPU。TP 最初是为 Transformer 架构开发的,将其应用于其他架构并不合适。

序列并行

对于非常长的序列,我们可能需要沿着序列长度维度将输入和激活拆分到不同的 GPU。主要有两种技术:

  1. 用于残差和归一化操作的序列并行 (SP):该技术首先由 Korthikanti 等人 (2023) 提出,它将 TP 中的全约简分解为约简分散和全收集。注意到残差和归一化操作在同一 TP 组中的所有 GPU 上对相同的输入重复执行,SP 通过执行以下操作沿序列长度维度拆分激活:约简分散、残差和归一化,然后全收集。

由于 Mamba-2 架构使用相同的残差和归一化结构,因此 SP 无需修改即可应用。

  1. 用于标记混合操作(注意力或 SSM)的序列并行,也称为“上下文并行”(CP)。已经为注意力层开发了几种技术(例如,环形注意力 (Liu, Yan, et al. 2024; Liu, Zaharia, and Abbeel 2023)),以及复杂的负载均衡技术 (Brandon et al. 2023)。注意力中序列并行的难点在于我们可以将查询和键拆分为块,但每个查询块都需要与键块交互,从而导致通信带宽与工作器数量呈二次方关系。

对于 SSM,我们可以以一种简单的方式拆分序列:每个工作器获取一个初始状态,根据其输入计算 SSM,返回最终状态,并将该最终状态传递给下一个工作器。通信带宽与工作器数量呈线性关系。这种分解与 SSD 算法(图 5)中用于拆分为块/块的块分解完全相同。我们在图 7(右)中说明了这种上下文并行。

可变长度

虽然预训练通常对批次使用相同的序列长度,但在微调或推理期间,模型可能需要处理不同长度的不同输入序列。处理这种情况的一种简单方法是将批次中的所有序列右填充到最大长度,但如果序列长度差异很大,这可能会很低效。对于 Transformer,已经开发出复杂的技术来避免填充并在 GPU 之间进行负载均衡 (Zeng et al. 2022; Y. Zhai et al. 2023),或者将多个序列打包到同一个批次中并调整注意力掩码 (Ding et al. 2024; Pouransari et al. 2024)。对于 SSM,特别是 Mamba,我们可以通过简单地将整个批次视为一个长序列来处理可变序列长度,并避免在各个序列之间传递状态。这相当于简单地将一个序列末尾的标记 𝑡 的 𝐴𝑡 设置为 0,以防止它将信息传递给属于不同序列的标记 𝑡 + 1。

实验验证

我们通过实验评估了 Mamba-2 在对递归模型具有挑战性的合成召回任务(第 9.1 节)和标准语言建模预训练和下游评估(第 9.2 节)上的性能。我们验证了我们的 SSD 算法比 Mamba-1(第 9.3 节)效率更高,并且在中等序列长度下可与优化的注意力相媲美。最后,我们对 Mamba-2 架构中的各种设计选择进行了消融研究(第 9.4 节)。

合成:关联召回

合成关联召回任务一直很受欢迎,用于测试语言模型在其上下文中查找信息的能力。广义上讲,它们涉及向自回归模型馈送键值关联对,然后在模型看到先前看到的键时提示模型生成正确的完成。多查询关联召回 (MQAR) 任务是此任务的一种特定形式,它要求模型记住多个关联 (Arora, Eyuboglu, Timalsina, et al. 2024)。最初的 Mamba 论文报告了相关合成任务的结果,特别是选择性复制 (Gu and Dao 2023) 和归纳头 (Olsson et al. 2022),它们可以被视为更容易的关联召回任务。MQAR 任务也与“电话簿查找”任务密切相关,该任务已被证明对 SSM 等递归模型具有挑战性,因为它们的状态容量有限 (De et al. 2024; Jelassi et al. 2024)。

我们在 (Arora, Eyuboglu, Zhang, et al. 2024) 中提出的具有挑战性的 MQAR 设置版本上进行了比较,使用了更难的任务、更长的序列和更小的模型。我们的基线包括标准的多头 Softmax 注意力以及结合了卷积、局部注意力和线性注意力变体的 Based 架构。

结果如图 8 所示。虽然 Mamba-1 在此任务上表现不佳,但 Mamba-2 在所有设置下都表现良好。令人惊讶的是,即使在控制状态大小 (N = 16) 的情况下,它也明显优于 Mamba-1。(我们不确定架构的哪个方面是主要因素,这仍然是未来工作中需要探索的问题。)此外,此任务验证了状态大小的重要性:从 N = 16 增加到 N = 64 和 N = 256 始终可以提高 MQAR 的性能,因为更大的状态允许记住更多信息(键值对)。

语言建模

按照 LLM 中的标准协议,我们在标准自回归语言建模上训练和评估 Mamba-2 架构,并与其他架构进行比较。我们比较了预训练指标(困惑度)和零样本评估。模型大小(深度和宽度)遵循 GPT3 规范,从 125m 到 2.7B。我们使用 Pile 数据集 (L. Gao, Biderman, et al. 2020),并遵循 Brown 等人 (2020) 中描述的训练方法。这与 Mamba (Gu and Dao 2023) 中报告的设置相同;训练细节见附录 D。

缩放定律

对于基线,我们与 Mamba 及其 Transformer++ 方法 (Gu and Dao 2023) 进行了比较,后者基于 PaLM 和 LLaMa 架构(例如旋转嵌入、SwiGLU MLP、RMSNorm 而不是 LayerNorm、没有线性偏差和更高的学习率)。由于 Mamba 已经证明它优于标准 Transformer 架构(GPT3 架构)以及最近的次二次架构(H3 (Dao, D. Y. Fu, et al. 2023)、Hyena (Poli et al. 2023)、RWKV-4 (B. Peng, Alcaide, et al. 2023)、RetNet (Y. Sun et al. 2023)),为了清晰起见,我们在图中省略了这些架构(有关比较,请参见 Gu and Dao (2023))。

图 9 显示了在标准 Chinchilla (Hoffmann et al. 2022) 协议下,从 ≈ 125𝑀 到 ≈ 1.3𝐵 参数的模型的缩放定律。

下游评估

表 1 显示了 Mamba-2 在一系列流行的下游零样本评估任务上的性能,并与这些规模下最著名的开源模型进行了比较,最重要的是 Pythia (Biderman et al. 2023),后者使用与我们的模型相同的标记器、数据集和训练长度(300B 标记)进行训练。

混合模型:将 SSD 层与 MLP 和注意力相结合

最近和同时进行的研究 (Dao, D. Y. Fu, et al. 2023; De et al. 2024; Glorioso et al. 2024; Lieber et al. 2024) 表明,同时具有 SSM 层和注意力层的混合架构可以提高模型质量,使其优于 Transformer 或纯 SSM(例如,Mamba)模型,尤其是在上下文学习方面。我们探索了将 SSD 层与注意力和 MLP 相结合的不同方式,以了解每种方式的好处。根据经验,我们发现大约 10% 的层是注意力层时性能最佳。将 SSD 层、注意力层和 MLP 相结合也比纯 Transformer++ 或 Mamba-2 效果更好。

SSD 和注意力:我们发现 SSD 和注意力层是互补的:它们本身(例如,在 Mamba-2 架构与 Transformer++ 中)的性能(以困惑度衡量)几乎相同,但 SSD 和注意力层的混合优于纯 Mamba-2 或 Transformer++ 架构。我们在表 2 中展示了一些结果,这些结果是使用 GPT-2 标记器在 Pile 上训练到 7B 标记的 350M 模型(48 层)(参数数量相同、超参数相同、训练和验证集相同)。仅添加几个注意力层就已经产生了显着的改进,并在质量和效率之间取得了最佳平衡。我们假设 SSM 层可以很好地用作一般的序列到序列映射,而注意力层充当检索机制,可以快速引用序列中的先前标记,而不是强迫模型将其所有上下文压缩到其内存(SSM 状态)中。

具有 SSD、MLP 和注意力的混合模型:我们比较了将 SSD 与(门控)MLP 和注意力层相结合的不同方式,并在 Pile 上训练到 300B 标记的 2.7B 规模(64 层)上进行了评估(参数数量相同、超参数相同、训练和验证集相同、数据顺序相同):

  1. Transformer++:32 个注意力层和 32 个门控 MLP,交错。
  2. Mamba-2:64 个 SSD 层。
  3. Mamba-2-MLP:32 个 SSD 和 32 个门控 MLP 层,交错。
  4. Mamba-2-Attention:58 个 SSD 层和 6 个注意力层(位于索引 9、18、27、36、45、56)。
  5. Mamba-2-MLP-Attention:28 个 SSD 层和 4 个注意力层,与 32 个门控 MLP 层交错。

我们在表 3 中报告了 Pile 上的验证困惑度以及零样本评估。总的来说,Transformer++ 和 Mamba-2 模型的质量大致相同。我们看到仅添加 6 个注意力层就显着改善了纯 Mamba-2 模型(以及 Transformer++)。添加 MLP 层会降低模型质量,但可以 (i) 由于 MLP 层的简单性和硬件效率而加快训练和推理速度;(ii) 通过用专家混合替换 MLP 层,更容易升级到 MoE 模型。

速度基准

我们将 SSD 算法的速度与 Mamba 的扫描实现和 FlashAttention-2 进行了基准测试(图 10)。SSD 由于其重新制定以使用矩阵乘法作为子例程,因此可以利用 GPU 上的专用矩阵乘法 (matmul) 单元,也称为张量核。因此,它比 Mamba 的融合关联扫描快 2-8 倍,后者不利用 matmul 单元。由于其序列长度的线性缩放,SSD 从序列长度 2𝐾 开始就比 FlashAttention-2 快。

但是,我们注意到,在短序列长度(例如 2𝐾)下,整个 Mamba-2 模型的训练效率可能不如 Transformer,因为具有 𝐿 层的 Transformer 将具有 𝐿/2 个注意力层和 𝐿/2 个 MLP 层,而 Mamba-2 模型在参数数量相同的情况下将具有 𝐿 个 SSD 层。通常,MLP 层的硬件效率非常高,因为它们由简单的矩阵乘法和逐点线性组成。如第 9.2.3 节所示,也可以将 𝐿/2 个 SSD 层和 𝐿/2 个 MLP 层组合起来,以加快短序列长度下的训练速度。

架构消融研究

块设计

第 7.1 节介绍了 Mamba-2 块,它对 Mamba-1 块进行了一些小的修改,部分原因是为了与注意力联系起来,也是为了提高 Mamba-2 的可扩展性。表 4 对块的这些架构更改进行了消融研究,这些更改发生在核心 SSM 层之外。

消融研究验证了并行投影创建 (𝐴, 𝐵, 𝐶, 𝑋) 可以节省参数,并且性能略好于 Mamba 的顺序投影。更重要的是,这种修改适用于大型模型的张量并行(第 8 节)。此外,额外的归一化层也略微提高了性能。更重要的是,在更大规模的初步实验中观察到它还有助于训练稳定性。

头部结构

第 7.2 节描述了如何将 𝐵、𝐶、𝑋 投影的维度视为类似于多头注意力和多查询注意力概念的超参数。我们还展示了原始的 Mamba 架构如何类似于多值注意力(命题 7.2),这是从状态空间模型的角度自然发展而来的选择,之前没有进行过消融研究。

表 5 对 Mamba-2 架构的多头结构选择进行了消融研究。引人注目的是,我们发现多值和多查询或多键头部模式之间存在很大差异,尽管它们看起来非常相似。请注意,这不能用总状态大小来解释,所有这些模式的总状态大小都相同(等于 HPN 或头部数量、头部维度和状态维度的乘积)。

我们还比较了 𝐶、𝐵、𝑋(类似于 𝑄、𝐾、𝑉)头部的数量相等的多头模式。我们与标准的多头模式以及一种积极共享的模式进行了比较,在积极共享的模式中,它们都只有一个头部。请注意,在后一种情况下,模型仍然有 H 个不同的序列混合器 𝑀,因为每个头部仍然有不同的 𝐴。当参数匹配时,这些多头模式的性能彼此相似,介于 MVA 和 MQA/MKA 模式之间。

注意力核近似

第 7.3 节指出 SSD 如何与线性注意力文献中的思想相结合,例如各种形式的核近似。我们在表 6 中对先前工作建议的几种变体进行了消融研究。这些变体包括 cosFormer (Qin, Weixuan Sun, et al. 2022)、随机特征注意力 H. Peng et al. 2021 和正随机特征 (Performer) (Choromanski et al. 2021)。

我们还对添加归一化项进行了消融研究,类似于标准注意力中 Softmax 函数的分母。我们发现这会给大多数变体带来不稳定性,但会略微提高 ReLU 激活函数 𝜓 的性能。

表 7 还测试了最近提出的改进线性注意力的建议,这些建议涉及扩展特征维度(Based (Arora, Eyuboglu, Zhang, et al. 2024) 和 ReBased (Aksenov et al. 2024))。这些线性注意力扩展旨在使用二次近似来逼近 exp 核。ReBased 还建议用层归一化替换 QK 激活函数;从以 SSM 为中心的观点来看,我们在应用 SSM 函数之前对 (𝐵, 𝐶) 应用归一化。

我们注意到,这种技术已被独立地提出作为 Softmax 注意力的“QK-Norm”(Team 2024)和 Mamba 的“内部归一化”(Lieber et al. 2024)。

总的来说,表 6 和表 7 发现,我们尝试的核近似方法似乎并没有比简单的逐点非线性激活函数 𝜓 有所改进。因此,我们对 Mamba-2 的默认设置使用 𝜓 (𝑥) = Swish(𝑥) 来遵循 Mamba-1,但我们建议完全删除此激活可能是一个更简单的选择,我们没有对其进行广泛测试。

我们强调,SSD 和普通线性注意力在包含 1-半可分掩码 𝐿 方面有所不同,而文献中的各种线性注意力方法都是为了在没有此项的情况下近似 Softmax 注意力而推导出来的;因此,我们的负面结果可能并不意外。

相关工作和讨论

状态空间对偶性框架将 SSM、结构化矩阵和注意力之间的联系联系起来。我们将更深入地讨论 SSD 与这些概念之间更广泛的关系。利用每个观点的思想,我们还提出了一些未来工作中可以扩展 SSD 框架的方向。

状态空间模型

结构化状态空间模型可以沿着以下轴进行表征:

(i) 它是否是时不变的或时变的。
(ii) 系统的维数。
(iii) 递归转换 𝐴 上的结构。

SSD 可以描述为具有 SISO 维度和标量恒等结构的选择性 SSM。

时变性(选择性):最初的结构化 SSM (S4) 是线性时不变 (LTI) 系统 (Gu 2023; Gu, Goel, and Ré 2022),其动机是连续时间在线记忆 (Gu, Dao, et al. 2020; Gu, Johnson, Goel, et al. 2021; Gu, Johnson, Timalsina, et al. 2023)。已经提出了许多结构化 SSM 的变体 (Dao, D. Y. Fu, et al. 2023; Gu, Gupta, et al. 2022; Gupta, Gu, and Berant 2022; Ma et al. 2023; J. T. Smith, Warrington, and Linderman 2023),包括几种放弃递归并专注于 LTI SSM 的卷积表示的变体 (D. Y. Fu et al. 2023; Y. Li et al. 2023; Poli et al. 2023; Qin, Han, Weixuan Sun, B. He, et al. 2023)。

SSD 是一种时变结构化 SSM,也称为 Mamba (Gu and Dao 2023) 中引入的选择性 SSM。选择性 SSM 与 RNN 的门控机制密切相关,包括经典的 RNN,例如 LSTM (Hochreiter and Schmidhuber 1997) 和 GRU (J. Chung et al. 2014),以及更现代的变体,例如 QRNN (Bradbury et al. 2016)、SRU (Lei 2021; Lei et al. 2017)、RWKV (B. Peng, Alcaide, et al. 2023)、HGRN (Qin, Yang, and Zhong 2023) 和 Griffin (Botev et al. 2024; De et al. 2024)。这些 RNN 在参数化方面有所不同,最重要的是缺乏状态扩展。

维数和状态扩展: SSD 的一个重要特征是它是一个单输入单输出 (SISO) 系统,其中输入通道独立处理,这与之前其谱系中的 SSM(S4、H3、Mamba)相同。这导致更大的有效状态大小 ND,其中 N 是 SSM 状态大小(也称为状态扩展因子),D 是标准模型维度。传统的 RNN 要么具有 N = 1,要么是具有密集 𝐵、𝐶 矩阵的多输入多输出 (MIMO),这两者都会导致状态更小。虽然 MIMO SSM 已被证明在某些领域中效果良好 (Lu et al. 2023; Orvieto et al. 2023; J. T. Smith, Warrington, and Linderman 2023),但 Mamba 表明状态扩展对于信息密集型领域(例如语言)至关重要。SSD 的主要优势之一是允许更大的状态扩展因子,而不会降低模型速度。此后,许多后续工作都采用了状态扩展(第 10.4 节)。

结构:与之前的结构化 SSM 相比,SSD 的主要限制在于状态转换 𝐴𝑡 的表达能力。我们注意到,更通用的 SSM(例如对角线 𝐴𝑡 的情况)具有与 SSD 相同的理论效率,但对硬件不太友好。这是因为对偶二次形式失去了类似注意力的解释,并且变得更难计算。因此,与 Mamba 相比,SSD 仅在对角线 𝐴𝑡 的限制性稍强的形式上有所不同,并以这种表达能力换取了更高的硬件效率(以及易于实现)。

我们假设可以改进我们的结构化矩阵算法,以改进到通用对角线 SSM 的情况。

结构化矩阵

状态空间对偶性的第一个观点是将这些模型视为矩阵序列变换或“矩阵混合器”:可以表示为沿序列维度 T 进行矩阵乘法(通过 T × T 矩阵)的序列变换(定义 2.1)。

之前已经提出了几种这样的矩阵混合器,其中主要的变化轴是矩阵的表示。这些包括 MLP-Mixer (Tolstikhin et al. 2021)(非结构化矩阵)、FNet (Lee-Thorp et al. 2021)(傅里叶变换矩阵)、M2 (Dao, B. Chen, et al. 2022; Dao, Gu, et al. 2019; Dao, Sohoni, et al. 2020; D. Fu et al. 2024)(蝴蝶/帝王蝶矩阵)、Toeplitz 矩阵 (Poli et al. 2023; Qin, Han, Weixuan Sun, B. He, et al. 2023),甚至更奇特的结构 (De Sa et al. 2018; Thomas et al. 2018)。

一个重要的特征是,高效(次二次)矩阵序列变换正是那些具有结构化矩阵混合器的变换。SSD 框架的核心结果是将 SSM 视为具有特定结构的矩阵混合器——半可分矩阵(第 3 节)。然后,线性与二次对偶性采用结构化矩阵乘法与朴素矩阵乘法的形式。

结构矩阵表示通过特定半可分矩阵的块分解导致了我们高效的 SSD 算法(第 6 节)。我们注意到,半可分矩阵在科学计算文献中得到了很好的研究,结合这些思想可能是进一步改进状态空间模型的一个有希望的途径。我们还建议,关注矩阵混合器观点可以为序列模型带来更多富有成效的方向,例如设计有原则的非因果 Mamba 变体,或者找到通过分析其矩阵变换结构来表征和弥合 Softmax 注意力和次二次模型之间差距的方法。

(线性)注意力

与标准(因果)注意力相比,SSD 只有两个主要区别。

首先,SSD 不使用标准注意力的 Softmax 激活 (Bahdanau, Cho, and Bengio 2015; Vaswani et al. 2017),这是注意力具有二次复杂性的原因。当删除 Softmax 时,可以通过线性注意力框架以线性缩放计算序列 (Katharopoulos et al. 2020)。

其次,SSD 将 logits 矩阵乘以一个输入相关的 1-半可分掩码。因此,可以将此掩码视为替换标准注意力中的 Softmax。

此半可分掩码也可以视为提供位置信息。元素 𝑎𝑡 充当 RNN 意义上的“门”,或“选择”机制(参见 Mamba 论文中的讨论),它们的累积乘积 𝑎 𝑗:𝑖 控制位置 𝑖 和 𝑗 之间允许的交互量。位置嵌入(例如正弦 (Vaswani et al. 2017)、AliBi (Press, N. Smith, and Lewis 2022) 和 RoPE (Su et al. 2021))是 Transformer 的重要组成部分,通常被视为启发式方法,SSD 的 1-SS 掩码可以被视为一种更有原则的相对位置嵌入形式。我们注意到 GateLoop (Katsch 2023) 也同时提出了这种观点。

状态空间对偶性的第二个观点是我们更通用的结构化掩码注意力 (SMA) 框架的特例,其中对偶性表现为对简单的 4 路张量收缩的不同收缩顺序。SMA 是线性注意力的强泛化,它也比 SSD 通用得多;其他形式的结构化掩码可能会导致更多具有不同于 SSD 属性的高效注意力变体。

除了导致新模型之外,这些与注意力的联系还可以为理解 SSM 提供其他方向。例如,我们很好奇注意力汇聚现象 (Darcet et al. 2024; Xiao et al. 2024) 是否存在于 Mamba 模型中,以及更广泛地说,可解释性技术是否可以转移到 SSM (Ali, Zimerman, and Wolf 2024)。

最后,已经提出了许多其他线性注意力变体 (Arora, Eyuboglu, Timalsina, et al. 2024; Arora, Eyuboglu, Zhang, et al. 2024; Choromanski et al. 2021; H. Peng et al. 2021; Qin, Han, Weixuan Sun, Dongxu Li, et al. 2022; Qin, Weixuan Sun, et al. 2022; Schlag, Irie, and Schmidhuber 2021; Zhang et al. 2024; Zheng, C. Wang, and Kong 2022)(有关其中几种变体的描述,请参见第 4.1.3 节),我们预计许多技术可以转移到 SSM(例如第 7.3 节)。

我们强调,SSD 不会泛化标准 Softmax 注意力,或对注意力核矩阵的任何其他没有有限特征映射 𝜓 的变换。与通用注意力相比,SSD 的优势在于具有可控的状态扩展因子 N,可以压缩历史记录,而二次注意力的缓存则包含整个历史记录,其大小与序列长度 T ≫ N 成正比。同时进行的研究已经开始研究这些表示的权衡,例如在复制和上下文学习任务上 (Akyürek et al. 2024; Grazzi et al. 2024; Jelassi et al. 2024; Park et al. 2024)。我们注意到,Mamba-2 在其中一些能力上显着改进了 Mamba(例如,如第 9.1 节中 MQAR 结果所示),但还有更多需要了解的地方。

相关模型

最后,我们重点介绍了越来越多的最近和同时进行的工作,这些工作开发了与 Mamba 和 Mamba-2 非常相似的序列模型。

  • RetNet (Y. Sun et al. 2023) 和 TransNormerLLM (Qin, Dong Li, et al. 2023) 使用衰减项而不是累积和来泛化线性注意力,并提出了对偶并行/递归算法以及混合“分块”模式。这些算法可以看作是 𝐴𝑡 时不变(对于所有 𝑡 都是常数)的 SSD 实例;在 SMA 解释中,掩码矩阵 𝐿 将是一个衰减矩阵 𝐿𝑖,𝑗 = 𝛾𝑖 − 𝑗。这些模型在架构上也有各种不同。例如,由于它们是从以注意力为中心的视角推导出来的,因此它们保留了多头注意力 (MHA) 模式;由于 Mamba-2 是从以 SSM 为中心的模式推导出来的,因此它保留了多值注意力 (MVA) 或多扩展 SSM (MES) 模式,我们证明这种模式更好(第 9.4 节)。
  • GateLoop (Katsch 2023) 同时提出了使用输入相关的衰减因子 𝐴𝑡,并开发了与 SSD 中相同的对偶二次形式,他们称之为“代理注意力”形式。
  • 门控线性注意力 (GLA) (Yang et al. 2024) 提出了一种具有数据相关门的线性注意力变体,以及用于计算分块模式和硬件感知实现的高效算法。
  • HGRN (Qin, Yang, and Zhong 2023) 引入了一种具有输入相关门的 RNN,在 HGRN2 (Qin, Yang, Weixuan Sun, et al. 2024) 中通过结合状态扩展对其进行了改进。
  • Griffin (De et al. 2024) 和 RecurrentGemma (Botev et al. 2024) 表明,具有输入相关门控的 RNN 与局部注意力相结合,可以与强大的现代 Transformer 非常有竞争力。Jamba 还表明,将 Mamba 与几层注意力相结合在语言建模方面表现出色 (Lieber et al. 2024)。
  • xLSTM (Beck et al. 2024) 通过采用状态扩展的思想以及其他门控、归一化和稳定技术来改进 xLSTM。
  • RWKV(-4) (B. Peng, Alcaide, et al. 2023) 是一种基于不同线性注意力近似(无注意力 Transformer (S. Zhai et al. 2021))的 RNN。最近,通过采用选择性和状态扩展的思想,将其改进为 RWKV-5/6(Eagle 和 Finch)架构 (B. Peng, Goldstein, et al. 2024)。

结论

我们提出了一个基于结构化矩阵的理论框架,该框架弥合了 SSM 和注意力变体之间的概念差距。该框架为最近的 SSM(例如 Mamba)如何在语言建模方面与 Transformer 表现一样好提供了见解。此外,我们的理论工具通过连接两方面的算法和系统进步,为改进 SSM(以及潜在的 Transformer)提供了新思路。作为演示,该框架指导我们设计了一种新的架构 (Mamba-2),该架构位于 SSM 和结构化注意力的交汇处。

致谢

我们感谢 Angela Wu 就如何以数值稳定的方式有效计算 Δ 的梯度提出的建议。我们感谢 Sukjun Hwang 和 Aakash Lahoti 在 MQAR 实验中提供的帮助。

参考文献

https://arxiv.org/pdf/2405.21060

0 0 投票数
Article Rating
订阅评论
提醒
0 评论
最旧
最新 最多投票
内联反馈
查看所有评论
0
希望看到您的想法,请您发表评论x