《思考的快与慢:用蒸馏推理器扩展推理计算能力》

🧠 引言:计算资源与推理能力的博弈

在人工智能的发展历程中,我们一直面临着一个核心问题:如何在有限的计算资源下获得最佳的模型性能?就像人类思考有快慢之分,大型语言模型(LLMs)的推理能力也存在着类似的权衡。当我们需要解决复杂问题时,是选择一个庞大但速度较慢的模型,还是选择一个轻量级但能快速生成多个答案的模型?这个问题引发了一场关于AI架构设计的深刻思考。

近期研究表明,通过在测试阶段增加计算资源,大型语言模型的性能可以得到显著提升。一种常见策略是生成多个思维链(Chain-of-Thought, CoT)轨迹,并通过各种选择机制聚合它们的输出。这引发了一个根本性问题:复杂度较低的模型能否利用其卓越的生成吞吐量,在固定计算预算下超越同等规模的Transformer模型?

为了解答这个问题并克服缺乏强大次二次推理器的困境,研究团队从预训练的Transformer模型中蒸馏出了纯Mamba模型和混合Mamba模型。这些模型仅在80亿个token上训练,就在数学推理数据集上展现出强大的性能和良好的扩展性,同时在处理大批量和长序列时比传统模型快得多。尽管在零样本性能上因蒸馏而有所损失,但在固定时间预算下,纯Mamba和混合Mamba模型都能将其覆盖率和准确率扩展到超过其Transformer教师模型,为扩展推理计算开辟了新方向。

🔍 研究背景:推理计算的扩展挑战

大型语言模型的推理能力近年来取得了显著进步,这在很大程度上是由计算资源的扩展驱动的。提升”推理”性能的一个关键技术是在产生最终答案之前使用中间推理步骤,即所谓的思维链(CoT)。基于此,许多测试时计算技术通常涉及生成多个CoT并选择最佳的一个。即使是简单的策略,如多数投票,也能出奇地有效。

然而,这些测试时计算技术给LLM系统带来了重大挑战。生成长CoT序列或大批量完成任务会对内存和计算资源提出巨大需求。特别是Transformer模型,由于其线性内存扩展和在生成过程中受内存限制的特性,在处理这类工作负载时尤其困难。这引发了一个重要问题:我们应该如何优化模型架构以最好地扩展测试时计算?特别是,具有更快更高效生成能力的替代架构能否在固定计算预算下超越当前的LLM?

🧩 研究方法:蒸馏学生推理器

为了探索次二次架构的推理能力,研究团队从预训练的Transformer中蒸馏知识到混合和纯Mamba模型中。他们开发了蒸馏特定推理技能到这些架构的方法,并针对多个思维链(CoT)完成任务对模型进行基准测试,提供了在固定计算和内存约束下性能的全面分析。

研究团队的方法推进了现有模型建立的帕累托前沿,在效率和推理能力之间取得了更好的权衡。他们的蒸馏纯Mamba和混合次二次推理器能够在大多数时间预算内,在MATH和GSM8K数学推理任务的覆盖率和准确率上超越其Transformer教师,使用2.5倍更少的推理时间达到相同的质量。

🔬 蒸馏方法详解:从Transformer到Mamba的知识转移

蒸馏到Llamba

研究团队修改了Bick等人(2024)引入的MOHAWK蒸馏程序。MOHAWK由三个阶段组成:

  1. 矩阵定向:将Mamba-2模型的SSM矩阵混合器与教师的自注意力矩阵对齐,通过最小化两个矩阵之间的距离。
  2. 隐藏状态对齐:匹配学生和教师层的隐藏状态输出。
  3. 权重转移和知识蒸馏:转移剩余的未优化参数,如MLP、嵌入和规范化,并使用学生和教师logits上的蒸馏损失微调完整的端到端学生模型。

他们的纯Mamba蒸馏模型Llamba-1B和Llamba-4B分别从各自的教师模型中使用调整后的MOHAWK蒸馏方法蒸馏而来,每个模型仅使用总共80亿个token。

蒸馏到MambaInLlama

对于混合模型,研究团队修改了Wang等人(2025)提出的协议,以蒸馏一些特定能力。在蒸馏过程中,Q. K、V和O的线性投影使用相应的C、B、X和O的线性投影进行初始化。新层中唯一额外学习的参数是采样率Δ和动态A。这些新参数将通过离散化函数控制构建的Mamba。

与Wang等人(2025)不同,他们在单轮中用Mamba层替换注意力层,并微调整个模型。对于蒸馏,他们采用token级KL散度。学生模型的完整概率分布p(·; θ)被训练以与教师模型的完整分布p(·; θT. 对齐,通过在位置t处所有可能的下一个token上最小化KL散度。

蒸馏后改进性能

研究表明,通过在蒸馏后进行一些监督微调(SFT),可以提高蒸馏模型的准确率和覆盖率。从蒸馏的MambainLlama-1B和3B开始,他们使用来自OpenMathInstruct-2的80亿个token对模型进行了两个epoch的微调。蒸馏模型在覆盖率和准确率方面都取得了令人印象深刻的性能,甚至超过了原始的Llama模型。

📊 扩展推理时间计算:多样化思维链的力量

研究团队通过使用蒸馏模型生成多个CoT来解决一组数学问题,从而扩展了测试时计算。系统提示包含有关如何正确格式化响应的说明。模型输出被解析以提取最终解决方案,然后与真实情况进行比较。这种方法使他们能够评估模型在多次尝试中生成正确解决方案的性能。

他们使用两个主要指标评估模型:覆盖率和准确率。覆盖率通常被称为pass@k指标,其中k表示每个问题的样本数。这个指标估计k个样本中至少存在一个正确解决方案的概率。对于准确率,他们使用多种聚合策略,包括多数投票和加权Best-of-N. 选择的答案是具有最高奖励模型分数总和的答案)。

🚀 实验结果:速度与准确性的突破

推理时间结果

实验结果表明,蒸馏模型的速度比各自的Llama 1B和3B基线快3.7倍和4.2倍。此外,MambaInLlama和Llamba更节省内存,因此可以运行更大的批次。在一项实验中,这些模型可以容纳512的批次,而Llama-3B则返回内存不足错误。研究团队还注意到MambaInLlama模型比Llamba稍快,推测这是因为MambaInLlama的SSM状态大小为16,而Llamba使用更大的SSM状态大小64。

推理任务结果

在MATH和GSM8K数据集上的实验表明,蒸馏模型能够像教师一样覆盖。当观察覆盖率随生成数量k增加的扩展时,蒸馏模型与其教师的覆盖率非常接近,只观察到轻微的退化。当将覆盖率绘制为时间预算的函数时,研究团队发现他们的蒸馏模型在快速生成正确答案方面表现出色。通过在相同时间预算内生成更多完成,MATH和GSM8K的整体帕累托前沿在很大程度上由蒸馏模型主导,其中纯Mamba和混合Mamba推理器能够在几乎一半的时间内达到与各自教师相同程度的覆盖率。

在固定时间预算下,蒸馏模型也能实现有竞争力的准确率。与覆盖率类似,蒸馏模型更轻、更快的批量推理,允许更多生成,在几个完成规模上产生了更好的准确率/时间帕累托前沿。

有趣的是,虽然比较类似大小的模型表明较大的时间预算由教师模型主导,但研究团队观察到较大的蒸馏模型可以提供比较小的基线更好的准确率,同时仍然更快。例如,虽然Llama-1B在较大时间预算下提供比MambaInLlama-1B更好的准确率,但MambaInLlama-3B接替MambaInLlama-1B的位置,为推理时间提供比Llama-1B更好的准确率。

更大的学生比更小的教师更快更好

次二次模型引起越来越多兴趣的核心驱动力是它们的计算效率。这种特性使较大尺寸(3B规模)的蒸馏模型能够比较小的Transformer(1B. 更快地生成样本。MambaInLlama-3B和Llamba-4B模型在覆盖率和准确率上优于较慢的Llama-1B基线,同时更快。这些源于底层架构的推理加速允许在时间受限的环境中使用更大、更有能力的模型,尽管它们的参数数量增加。

较小模型具有出色的覆盖率

当关注覆盖率时,研究团队观察到帕累托前沿的大部分被1B模型占据。将这些结果与多数投票准确率结果进行对比,其中1B和3B模型之间的差距更为显著。一种解释是,虽然较小的模型有能力生成正确的答案,但在有限数量的样本内生成正确答案的概率随着模型大小的增加而增加。这一发现对于选择涉及形式语言的任务的模型大小有影响,在这些任务中答案容易验证,如编码和数学证明。在这些应用中,覆盖率最重要,与较大模型相比,较小模型可能因其更好的时间/覆盖率效率而受到青睐。

🔮 结论与未来展望

研究团队调查了较低复杂度的模型是否能利用其卓越的生成吞吐量,在固定计算预算下超越类似规模的Transformer。他们专注于可以扩展测试时计算以提高性能的推理任务。通过广泛的实验,他们在1B和3B规模上蒸馏了纯Mamba和混合Mamba模型,并在数学推理基准上评估了它们的推理能力,其中次二次模型快速生成多个完成的能力使它们能够在增加推理计算时利用其扩展特性。当固定内存和/或计算时,他们的模型在大多数时间预算上实现了比其Transformer教师对应物更好的覆盖率和准确率。

这些发现突显了Mamba和其他注意力替代方案作为Transformer的强大替代品的潜力,特别是对于从可扩展推理计算中受益的任务。研究团队希望这项工作能激发未来在预训练次二次推理器和进一步探索其推理扩展特性方面的工作。需要更多研究来确定跨架构蒸馏推理能力的最佳方式,因为性能仍然对数据和蒸馏技术高度敏感。此外,由于蒸馏模型展示了卓越的覆盖率,开发更好的奖励模型以更好地识别正确答案可以缩小准确率差距。最后,对会话和主观任务的推理计算扩展的进一步研究将开启一个场景,在这个场景中,更轻的次二次模型可以在性能和速度方面取得更大的提升。

📚 参考文献

  1. Bick, A. , Wang, J., Paliotta, D., Dao, T., & Gu, A. (2024). MOHAWK: A General-Purpose State Space Model Distillation Framework. arXiv preprint.
  2. Wang, J. , Paliotta, D., Pagliardini, M., Li, K. Y., Bick, A., Kolter, J. Z., Gu, A., Fleuret, F., & Dao, T. (2025). Thinking Slow, Fast: Scaling Inference Compute with Distilled Reasoners. arXiv preprint.
  3. Wei, J. , Wang, X., Schuurmans, D., Bosma, M., Ichter, B., Xia, F., Chi, E., Le, Q., & Zhou, D. (2023). Chain-of-Thought Prompting Elicits Reasoning in Large Language Models. arXiv preprint.
  4. Gu, A. , & Dao, T. (2024). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv preprint.
  5. Snell, J. , Uesato, J., Huang, J., Sezener, E., Huang, P., Pasula, H., Agapiou, J., Chisholm, A., Borgeaud, S., Brock, A., Glaese, A., Cai, M., Rauh, N., Wainwright, M., Cabi, S., Quan, J., Jia, N., Bratko, A., Lazaridou, A., & Irving, G. (2024). Scaling Process Supervision for Mathematical Reasoning. arXiv preprint.

评论

发表回复

人生梦想 - 关注前沿的计算机技术 acejoy.com 🐾 步子哥の博客 🐾 背多分论坛 🐾 知差(chai)网