破解AI模型速度瓶颈:一种全新的“分组查询注意力”方法

288次阅读
没有评论

你是否曾经对人工智能模型的运算速度感到不耐烦,同时又希望它能保持高质量的预测结果?这可能听起来像是一个无法两全的问题,但科研人员们并没有停下探索的脚步。今天,我们要介绍的这篇研究报告,就给出了一个行之有效的解决方案。这篇研究名为 “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints”,由来自 Google Research 的团队所撰写。他们提出了一种称为“分组查询注意力(Grouped-query attention, GQA)”的新方法,旨在解决 Transformer 模型中的一个关键问题,即如何在保持预测质量的同时,提高模型的运算速度。

首先,让我们理解一下这个问题的背景。在 Transformer 模型中,一个关键的计算过程就是自回归解码器推理。这个过程需要大量的内存带宽来加载解码器权重和所有注意力键值,这就大大限制了模型的运算速度。为了解决这个问题,研究者们提出了多查询注意力(Multi-query attention, MQA)方法,它只使用一个键值对来大幅度提高解码器推理的速度。然而,MQA 方法可能会导致预测质量下降,而且也不太适合用于训练单独的模型以提高推理速度。

在这样的背景下,Google Research 的团队提出了两个重要的贡献。首先,他们发现,可以使用少量的原始训练计算来将具有多头注意力(Multi-head attention, MHA)的语言模型检查点进行升级训练,使其能够使用 MQA,这是一种非常成本有效的方法,可以同时获得高速的 MQA 和高质量的 MHA 检查点。其次,他们提出了分组查询注意力(GQA)的概念,这是一种在多头注意力和多查询注意力之间的插值方法,它为每组查询头部共享一个键和值头部。

GQA 的工作原理是将查询头部分成若干组,每组共享一个键头和值头。具有 G 组的 GQA 被称为 GQA-G。GQA-1(具有一个组,因此具有一个键和值头)等同于 MQA,而具有等于头部数量的组的 GQA- H 等同于 MHA。通过使用中间数量的组,GQA 可以产生一个质量比 MQA 高,但速度比 MHA 快的插值模型。此外,对于大型模型,GQA 的优势更加明显,因此,我们期待 GQA 能在大型模型中提供一个特别好的权衡方案。

在实验部分,研究者们使用了基于 T5.1.1 架构的所有模型,并对 T5 Large 和 XXL 的多头注意力版本,以及使用多查询和分组查询注意力的升级版 T5 XXL 进行了主要实验。实验结果表明,使用 GQA 的 T5-XXL 模型在各种不同的数据集上,包括 CNN/Daily Mail, arXiv, PubMed, MediaSum, 和 MultiNews 等新闻摘要数据集,以及 WMT 英德翻译数据集和 TriviaQA 问答数据集上,都保持了与多头注意力模型相近的质量,同时又具有与多查询注意力模型相近的速度。

在 AI 领域,我们一直在寻找提高效率和质量的方法,而 GQA 的出现无疑为我们提供了一个新的可能。它不仅提高了模型的运算速度,而且还成功地保持了预测的质量。这使得 GQA 成为了提高 AI 模型性能的一种有力工具,我们有理由期待,这种方法将在未来的 AI 应用中发挥更大的作用。

总的来说,这项研究的重要性在于,它不仅提供了一种提高 AI 模型速度的有效方法,而且这种方法还能保持模型的预测质量。这使得我们可以在实际应用中实现更快、更准确的 AI 模型,从而在各种场景中提供更好的服务。

这就是今天的分享,希望你们能从中获取到有用的信息。我们将继续关注更多的人工智能研究,并与大家分享。感谢你们的倾听,我们下次见!

正文完
 
评论(没有评论)