DeepSeek-V2中的MLA详解

DeepSeek-V2是DeepSeek团队最新发布的MoE(Mixture of Experts)架构的LLM(大型语言模型)底座。该模型拥有236B的总参数量和21B的每个token激活参数量,支持128K tokens的上下文长度。DeepSeek-V2的一个核心创新点就是Multi-head Latent Attention(MLA)。

Multi-head Latent Attention(MLA)简介

MLA对传统Transformer中的多头注意力机制(MHA)进行了改进,主要目标是:

  1. 降低推理时KV Cache的存储开销;
  2. 缓解GQA(Grouped-Query Attention)和MQA(Multi-Query Attention)等方法导致的模型性能损耗。

标准的MHA结构

在标准的MHA结构中,每个token的query、key和value通过参数矩阵映射得到,并分割成多个注意力头。每个头独立计算注意力权重并得到输出,这个过程虽然能捕捉丰富的上下文信息,但在推理时需要缓存大量的KV Cache。

MLA如何改进?

MLA通过对keys和values进行低秩联合压缩来降低KV Cache:

  1. 低秩Key-Value联合压缩
    [
    \mathbf{c}_t^{KV} = W^{DKV} \mathbf{h}_t
    ]
    [
    \mathbf{k}_t^C = W^{UK} \mathbf{c}_t^{KV}
    ]
    [
    \mathbf{v}_t^C = W^{UV} \mathbf{c}_t^{KV}
    ]
    其中,(\mathbf{c}_t^{KV})表示压缩后的隐向量,(W^{DKV})是降维映射矩阵,(W^{UK})和(W^{UV})是升维映射矩阵。在推理时,只需要缓存隐向量(\mathbf{c}_t^{KV}),显著减少了KV Cache的容量。
  2. Queries的低秩压缩
    [
    \mathbf{c}_t^Q = W^{DQ} \mathbf{h}_t
    ]
    [
    \mathbf{q}_t^C = W^{UQ} \mathbf{c}_t^Q
    ]
    这样即便不能减少KV Cache,但可以降低训练过程中的激活内存。

代码实现

以下是MLA在DeepSeek-V2中的Python代码实现片段:


class DeepSeekV2Attention(nn.Module):
def init(self, config: DeepSeekV2Config, layer_idx: Optional[int] = None):

self.w_dq = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias)
self.w_uq = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False)
self.w_dkv = nn.Linear(self.hidden_size, self.dc, bias=config.attention_bias)
self.w_uk = nn.Linear(self.dc, self.num_heads * self.q_head_dim, bias=False)
self.w_uv = nn.Linear(self.dc, self.num_heads * self.q_head_dim, bias=False)

def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None,
            output_attentions: bool = False, use_cache: bool = False, **kwargs):
    bsz, q_len, _ = hidden_states.size()

    q = self.w_uq(self.q_a_layernorm(self.w_dq(hidden_states))).view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
    kv_seq_len = q.size(-2)
    if past_key_value is not None:
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

    compressed_kv = self.w_dkv(hidden_states)
    if past_key_value is not None:
        compressed_kv = past_key_value.update(compressed_kv)

    k = self.w_uk(compressed_kv).view(bsz, -1, self.num_heads, self.q_head_dim).transpose(1, 2)
        v = self.w_uv(compressed_kv).view(bsz, -1, self.num_heads, self.q_head_dim).transpose(1, 2)

        attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.softmax_scale
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if output_attentions:
            outputs = (attn_weights,)
        else:
            outputs = ()

        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)

        attn_output = self.out_proj(attn_output)
        outputs = (attn_output,) + outputs

        if use_cache:
            outputs = outputs + (past_key_value,)

        return outputs
```


结论
DeepSeek-V2通过引入Multi-head Latent Attention(MLA)结构,成功优化了传统的多头注意力机制(MHA),在保证模型性能的同时,显著降低了推理时KV Cache的存储开销。这不仅提高了模型的效率,也为未来的大模型架构设计提供了新的思路。

MLA的实现通过对queries、keys和values进行低秩压缩,减少了存储需求,缓解了因GQA和MQA方法导致的性能损耗。这种创新在深度学习模型的设计中具有重要的参考价值。

如果你对于DeepSeek-V2的MLA结构有更多的兴趣,建议查看其开源代码和详细文档,以便深入理解其工作机制和实现细节。

0 0 投票数
Article Rating
订阅评论
提醒
0 评论
最多投票
最新 最旧
内联反馈
查看所有评论
0
希望看到您的想法,请您发表评论x