REFRAG: 重新思考基于RAG的解码
简介
REFRAG(REpresentation For RAG)是一种针对检索增强生成(RAG)应用的高效解码框架。通过利用RAG上下文中固有的稀疏性和块对角注意力模式,REFRAG实现了显著的性能提升:
- 通过压缩、感知和扩展上下文表示来显著减少内存使用和推理延迟
- 利用RAG上下文中的稀疏性和块对角注意力模式,消除不必要的计算
- 支持在任意位置压缩令牌块,同时保持解码器的自回归性质
- 通过强化学习策略智能选择哪些上下文块需要扩展为原始令牌
问题背景与挑战
大型语言模型(LLMs)在利用外部知识增强多轮和代理应用中的响应方面表现出色,如检索增强生成(RAG)。然而,处理长上下文输入带来了显著的系统延迟和大量内存需求:
- 低效的令牌分配:RAG上下文通常包含稀疏信息,许多检索到的段落信息量低且在多次推理中重复使用。为所有令牌分配内存/计算是不必要的浪费
- 浪费的编码和其他信息:RAG中的检索过程已经预处理了上下文块,它们的编码和与查询的其他相关性已经可用,但在解码过程中被丢弃
- 异常结构和稀疏注意力:由于多样性和去重等操作,解码过程中的大多数上下文块不相关,导致块间交叉注意力主要为零
这些特性表明,在RAG解码过程中,对上下文的大部分计算是不必要的,可以在对性能影响最小的情况下被消除。因此,需要专门针对RAG系统的优化方法,而不是将其视为通用的LLM推理问题。
REFRAG核心原理与架构
REFRAG的核心思想是利用RAG上下文中的稀疏性和块对角注意力模式,通过压缩、感知和扩展上下文表示来显著减少内存使用和推理延迟。
使用轻量级编码器将检索到的上下文块压缩为紧凑的嵌入表示,而不是使用原始令牌作为输入
通过强化学习策略智能选择哪些上下文块需要扩展为原始令牌,哪些可以保持压缩状态
将压缩的块嵌入和选定的原始令牌输入到解码器,生成响应
与现有方法相比,REFRAG具有以下优势:
- 缩短了解码器的输入长度,提高了令牌分配效率
- 能够重用检索过程中预计算的块嵌入,消除了冗余计算
- 降低了注意力计算复杂度,现在与块数量而非上下文中的令牌数量成二次方关系
- 支持在任意位置压缩令牌块,同时保持解码器的自回归性质
# REFRAG核心处理流程
def refrag_decode(question, context_chunks, encoder, decoder, rl_policy):
# 步骤1: 压缩 - 使用编码器处理所有上下文块
chunk_embeddings = [encoder(chunk) for chunk in context_chunks]
# 步骤2: 感知 - 使用RL策略选择需要扩展的块
selected_chunks = rl_policy.select_chunks(chunk_embeddings, question)
# 步骤3: 扩展 - 准备解码器输入
decoder_input = []
decoder_input.extend(question_tokens) # 添加问题令牌
for i, chunk in enumerate(context_chunks):
if i in selected_chunks:
# 扩展选定的块为原始令牌
decoder_input.extend(chunk)
else:
# 使用压缩的块嵌入
decoder_input.append(chunk_embeddings[i])
# 生成响应
response = decoder.generate(decoder_input)
return response
方法论
持续预训练(CPT)方法
为了对齐编码器和解码器,REFRAG使用下一段落预测任务进行持续预训练。具体来说,对于每个数据点,它包含s + o = T个令牌,用于CPT以准备模型利用块嵌入的下游任务。
重建任务和课程学习
为了确保CPT阶段的成功,REFRAG提出了一种包含重建任务和课程学习方法的训练方案:
- 重建任务:将前s个令牌x1:s输入编码器,学习在解码器中重建令牌x1:s。在此任务中,冻结解码器模型,仅训练编码器和投影层
- 课程学习:逐步增加任务难度,使模型能够逐渐有效地获取复杂技能。对于重建任务,训练从重建单个块开始,然后逐渐增加难度
选择性压缩的强化学习方法
REFRAG引入了选择性令牌压缩,扩展重要的上下文块以改进答案预测。一个强化学习策略,以下一段落预测困惑度作为负奖励,决定哪些块保留其原始形式。编码器和解码器经过微调以处理压缩和未压缩块的混合输入。
# 选择性压缩的强化学习策略
class RLCompressionPolicy:
def __init__(self):
self.transformer = TwoLayerTransformer()
def select_chunks(self, chunk_embeddings, query):
# 使用transformer网络评估每个块的重要性
logits = self.transformer(chunk_embeddings)
# 使用GRPO风格基线减少方差
selected_chunks = []
for i in range(len(chunk_embeddings)):
if self.should_expand(logits[i], selected_chunks):
selected_chunks.append(i)
return selected_chunks
def should_expand(self, logit, selected_chunks):
# 使用mask机制防止重复选择
if i in selected_chunks:
return False
# 基于logit和已选择块决定是否扩展
probability = softmax(logit - mask)
return sample_bernoulli(probability)
实验结果与性能分析
REFRAG在多种长上下文任务上进行了严格验证,包括RAG、多轮对话和长文档摘要,涵盖了广泛的数据集。实验结果确认,REFRAG在各种上下文大小下,与LLaMA模型和其他最先进的基线相比,提供了显著的速度提升,且准确性没有损失。
在强检索器设置下,使用10个检索段落,REFRAG匹配LLaMA的性能,同时实现5.26倍的TTFT加速。在同等延迟条件下(REFRAG使用8个段落vs LLaMA使用1个段落),REFRAG在16个RAG任务上平均获得1.22%的提升。
在弱检索器设置下,使用10个段落,REFRAG相比LLaMA性能提升0.71%,TTFT加速5.26倍。在同等延迟条件下,REFRAG在16个RAG任务上平均获得1.93%的提升。
方法 | TTFT加速 | 内存使用 | 上下文扩展 | 任意位置压缩 |
---|---|---|---|---|
LLaMA | 1× | 高 | 不支持 | 不支持 |
CEPE | 8.26× | 中 | 有限 | 不支持 |
REFRAG | 30.85× | 低 | 16× | 支持 |
应用场景与未来展望
REFRAG的高效解码框架特别适用于需要处理大量上下文信息且对延迟敏感的应用场景:
在Web规模搜索等RAG应用中,REFRAG能够显著减少推理延迟,同时保持或提高答案质量,特别是在检索器性能较弱的情况下表现更佳。
在知识密集型多轮对话中,REFRAG能够处理更长的对话历史,避免因上下文窗口限制而丢失关键信息,从而提高回答质量。
在长文档摘要任务中,REFRAG能够处理更长的输入文档,提取更全面的信息,生成更准确的摘要,同时保持较低的推理延迟。
未来工作包括探索更高效的压缩算法、扩展到更多模态、优化强化学习策略以及将REFRAG与其他长上下文优化技术结合,进一步提升性能。
REFRAG为在延迟敏感、知识密集型应用中部署大型语言模型提供了一个实用且可扩展的解决方案,特别强调了专门处理基于RAG的系统的重要性,并为高效大上下文LLM推理开辟了新的方向。