Warfield, D. (2024). Multi-Headed Self Attention — By Hand. Intuitively and Exhaustively Explained. https://iaee.substack.com/p/multi-headed-self-attention-by-hand✅
Vaswani, A. , Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., … & Polosukhin, I. (2017). Attention is all you need. Advances in neural information processing systems, 30.✅
1. 引言
在现代人工智能领域,多头自注意力机制(Multi-Headed Self Attention, MHSA)可以说是最重要的架构范式之一。它是Transformer模型的核心组件,而Transformer又是当前最先进的大型语言模型的基础架构。本文将深入浅出地解析多头自注意力机制的工作原理,通过手动计算的方式,让读者对其内部运作有一个直观而全面的理解。
2. 多头自注意力机制的背景
在深入MHSA之前,我们先简要回顾一下自然语言处理(NLP)领域的相关发展历程。早期的NLP模型主要依赖于循环神经网络(RNN)和长短期记忆网络(LSTM)等序列模型。这些模型虽然能够处理序列数据,但在处理长序列时存在长期依赖问题。
2017年,Google提出了Transformer模型,其核心就是多头自注意力机制。MHSA能够并行处理输入序列,捕捉序列中的长距离依赖关系,大大提高了模型的性能和效率。自此,MHSA成为了NLP领域的主流技术,被广泛应用于各种大型语言模型中。
3. 多头自注意力机制的工作原理
让我们通过一个具体的例子,step by step地计算多头自注意力机制的输出。我们将遵循以下步骤:
3.1 定义输入
MHSA可以应用于各种类型的数据,但通常情况下,输入是一个向量序列。在自然语言处理中,这通常是词嵌入(word embedding)与位置编码(positional encoding)的组合。
假设我们有一个简单的输入序列,包含3个词,每个词用4维向量表示:
这个4×3的矩阵代表了我们的输入序列。
3.2 定义可学习参数
MHSA主要学习三个权重矩阵,用于构造”查询”(Query)、”键”(Key)和”值”(Value)。在本例中,我们假设模型已经学习到了以下权重矩阵:
这些4×2的矩阵代表了模型的可学习参数。
3.3 计算查询、键和值
接下来,我们将输入与权重矩阵相乘,得到查询、键和值:
让我们计算Query:
同理可得Key和Value:
3.4 划分多个注意力头
多头自注意力机制的”多头”体现在这一步。我们将Query、Key和Value划分为多个子矩阵,每个子矩阵对应一个注意力头。在本例中,我们使用两个注意力头:
这样,我们就得到了两组Query、Key和Value,分别用于两个注意力头的计算。
3.5 计算Z矩阵
接下来,我们需要计算Z矩阵,这是构造注意力矩阵的中间步骤。Z矩阵由Query和Key的矩阵乘法得到。我们以第一个注意力头为例:
为了防止Z矩阵的值随着序列长度的增加而过大,我们通常会将Z矩阵除以序列长度的平方根。在本例中,序列长度为3,所以我们将Z_1除以$\sqrt{3}$:
同理可得Z_2。
3.6 掩码操作
在某些应用场景中,如语言模型预测下一个词时,我们需要进行掩码操作,以确保模型在预测时不会”看到”未来的信息。这通常通过将Z矩阵中的某些位置设置为负无穷来实现。在本例中,我们假设不需要掩码操作。
3.7 计算注意力矩阵
注意力矩阵是通过对Z矩阵的每一行进行softmax运算得到的。softmax函数的定义如下:
$softmax(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}$
让我们以Z_1的第一行为例计算softmax:
对Z_1的每一行都进行这样的计算,我们就得到了注意力矩阵Attention_1:
同理可得Attention_2。
3.8 计算注意力头的输出
得到注意力矩阵后,我们将其与Value相乘,得到每个注意力头的输出:
同理可得Output_2。
3.9 拼接多个注意力头的输出
最后,我们将所有注意力头的输出拼接起来,得到多头自注意力机制的最终输出:
这个3×2的矩阵就是多头自注意力机制的输出结果。
4. 多头自注意力机制的优势
通过上述计算过程,我们可以看出多头自注意力机制具有以下优势:
5. 多头自注意力机制的应用
MHSA在自然语言处理领域有广泛的应用,包括但不限于:
6. 结论
多头自注意力机制是现代人工智能,特别是自然语言处理领域的核心技术之一。通过本文的详细计算过程,我们深入了解了MHSA的工作原理。尽管实际应用中的计算规模要大得多,但基本原理是相同的。
理解MHSA的工作原理对于深入学习和应用先进的AI技术至关重要。随着技术的不断发展,我们可以期待MHSA在更多领域发挥重要作用,推动人工智能技术的进步。
参考文献