Transformer模型原理、架构与复现技巧详解

从理论到实践,全面解析深度学习领域的革命性模型

info 模型简介与背景

Transformer模型是由Google研究团队在2017年发表的论文《Attention Is All You Need》中提出的一种基于注意力机制的深度学习模型架构。它完全摒弃了传统的循环神经网络(RNN)和卷积神经网络(CNN)结构,仅依靠注意力机制来处理序列数据,实现了并行计算,显著提高了训练效率。

auto_awesome 革命性架构:完全基于注意力机制
speed 并行计算:训练速度大幅提升
psychology 长距离依赖:有效捕捉序列中的远距离关系
trending_up 广泛应用:NLP、CV、语音处理等领域

Transformer模型的出现标志着深度学习领域的一次重大突破,它不仅成为了机器翻译等序列到序列任务的主流模型,还催生了BERT、GPT等一系列强大的预训练语言模型,对人工智能的发展产生了深远影响。

# Transformer模型的基本结构
class Transformer(nn.Module):
  def __init__(self, src_vocab_size, tgt_vocab_size, d_model, n_heads, n_layers, dropout):
    super(Transformer, self).__init__()
    self.encoder = Encoder(src_vocab_size, d_model, n_heads, n_layers, dropout)
    self.decoder = Decoder(tgt_vocab_size, d_model, n_heads, n_layers, dropout)
    self.fc = nn.Linear(d_model, tgt_vocab_size)
Transformer模型核心原理

Transformer模型核心原理

注意力机制、自注意力机制与多头注意力机制详解

visibility 注意力机制 (Attention Mechanism)

注意力机制是Transformer的核心创新,它允许模型在处理序列数据时动态关注不同位置的信息。注意力机制通过计算查询(Query)、键(Key)和值(Value)之间的相似度,确定每个位置的重要性权重。

Attention(Q, K, V) = softmax(QKT/√dk)V
1
计算相似度分数:通过查询向量Q与所有键向量K的点积,计算每个位置的相关性分数。
2
缩放与归一化:将分数除以√dk(dk是键向量的维度)进行缩放,然后通过softmax函数转换为概率分布。
3
加权求和:使用归一化后的权重对值向量V进行加权求和,得到最终的注意力输出。
def scaled_dot_product_attention(Q, K, V):
  # 计算注意力分数
  scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
  # Softmax归一化
  attention_weights = F.softmax(scores, dim=-1)
  # 加权求和
  output = torch.matmul(attention_weights, V)
  return output, attention_weights

loop 自注意力机制 (Self-Attention)

自注意力机制是Transformer的核心组件,它允许序列中的每个位置都能够关注到序列中的所有其他位置。在自注意力中,查询(Q)、键(K)和值(V)都来自同一个输入序列,使模型能够捕捉序列内部的依赖关系。

1
线性变换:输入序列X通过三个不同的权重矩阵WQ、WK和WV线性变换,生成查询(Q)、键(K)和值(V)向量。
2
注意力计算:使用上述注意力机制计算每个位置与其他位置的相关性。
3
并行计算:与RNN不同,自注意力机制可以并行处理整个序列,大大提高了计算效率。
class SelfAttention(nn.Module):
  def __init__(self, embed_dim):
    super(SelfAttention, self).__init__()
    self.query = nn.Linear(embed_dim, embed_dim)
    self.key = nn.Linear(embed_dim, embed_dim)
    self.value = nn.Linear(embed_dim, embed_dim)

  def forward(self, x):
    Q = self.query(x) # [batch_size, seq_len, embed_dim]
    K = self.key(x) # [batch_size, seq_len, embed_dim]
    V = self.value(x) # [batch_size, seq_len, embed_dim]
    return scaled_dot_product_attention(Q, K, V)

device_hub 多头注意力机制 (Multi-Head Attention)

多头注意力机制通过将查询、键和值投影到不同的子空间,并行执行多个自注意力计算,然后将结果拼接并线性变换。这使得模型能够同时关注不同位置和不同表示子空间的信息。

MultiHead(Q, K, V) = Concat(head1, ..., headh)WO
其中 headi = Attention(QWiQ, KWiK, VWiV)
1
线性投影:将Q、K、V分别通过h个不同的线性层投影到h个不同的子空间。
2
并行计算:在每个子空间上独立执行缩放点积注意力计算。
3
拼接与变换:将所有头的输出拼接起来,并通过一个线性层进行最终变换。
class MultiHeadAttention(nn.Module):
  def __init__(self, embed_dim, num_heads):
    super(MultiHeadAttention, self).__init__()
    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.head_dim = embed_dim // num_heads
    self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)
    self.out_proj = nn.Linear(embed_dim, embed_dim)

  def forward(self, x):
    batch_size, seq_len, _ = x.shape
    qkv = self.qkv_proj(x) # [batch_size, seq_len, 3*embed_dim]
    qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
    qkv = qkv.permute(2, 0, 3, 1, 4) # [3, batch_size, num_heads, seq_len, head_dim]
    Q, K, V = qkv[0], qkv[1], qkv[2]
    attn_output, _ = scaled_dot_product_attention(Q, K, V)
    attn_output = attn_output.permute(0, 2, 1, 3).reshape(batch_size, seq_len, -1)
    return self.out_proj(attn_output)
Transformer模型架构

Transformer模型架构

编码器、解码器、位置编码与前馈神经网络详解

architecture 整体架构

Transformer模型由编码器(Encoder)和解码器(Decoder)两部分组成,它们都由多层相同的子层堆叠而成。编码器负责处理输入序列并提取特征,解码器则基于编码器的输出和已生成的序列来生成目标序列。

Transformer模型架构图
input 编码器 (Encoder)
由N个相同的层堆叠而成,每层包含一个多头自注意力子层和一个前馈神经网络子层,每个子层都有残差连接和层归一化。
output 解码器 (Decoder)
同样由N个相同的层堆叠而成,除了包含编码器的两个子层外,还增加了一个多头交叉注意力子层,用于关注编码器的输出。

location_on 位置编码 (Positional Encoding)

由于Transformer模型没有循环或卷积结构,无法捕捉序列中的顺序信息。位置编码通过向输入嵌入中添加位置信息,使模型能够区分不同位置的词。

PE(pos,2i) = sin(pos/100002i/dmodel)
PE(pos,2i+1) = cos(pos/100002i/dmodel)

其中,pos是位置,i是维度索引,dmodel是嵌入维度。这种设计使得模型能够轻松学习到相对位置关系,因为对于固定的偏移量k,PEpos+k可以表示为PEpos的线性函数。

class PositionalEncoding(nn.Module):
  def __init__(self, d_model, max_len=5000):
    super(PositionalEncoding, self).__init__()
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                   (-math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    self.register_buffer('pe', pe)

  def forward(self, x):
    # x shape: [batch_size, seq_len, embedding_dim]
    x = x + self.pe[:x.size(1), :] # 添加位置编码
    return x

linear_scale 前馈神经网络 (Feed-Forward Network)

Transformer中的每个编码器层和解码器层都包含一个全连接的前馈神经网络,该网络由两个线性变换和一个非线性激活函数组成。这个子层分别对每个位置独立地应用相同的操作。

FFN(x) = max(0, xW1 + b1)W2 + b2

其中,W1和W2是权重矩阵,b1和b2是偏置向量。在原始论文中,输入和输出的维度为dmodel=512,内部层的维度为dff=2048,即扩展了4倍。

class PositionwiseFeedForward(nn.Module):
  def __init__(self, d_model, d_ff, dropout=0.1):
    super(PositionwiseFeedForward, self).__init__()
    self.w_1 = nn.Linear(d_model, d_ff) # 扩展层
    self.w_2 = nn.Linear(d_ff, d_model) # 投影层
    self.dropout = nn.Dropout(dropout)
    self.activation = nn.ReLU()

  def forward(self, x):
    # x shape: [batch_size, seq_len, d_model]
    x = self.w_1(x)
    x = self.activation(x)
    x = self.dropout(x)
    x = self.w_2(x)
    return x

add_link 残差连接与层归一化

每个子层(多头注意力和前馈神经网络)的输出都经过残差连接和层归一化处理。残差连接有助于缓解深度网络中的梯度消失问题,层归一化则稳定了训练过程。

LayerNorm(x + Sublayer(x))

其中,Sublayer(x)是子层本身实现的函数。在原始Transformer中,层归一化是在残差连接之后应用的,这种结构被称为"Post-LN"。后续研究表明,将层归一化放在残差连接之前("Pre-LN")可以提高训练稳定性,特别是在深层Transformer中。

class SublayerConnection(nn.Module):
  def __init__(self, size, dropout):
    super(SublayerConnection, self).__init__()
    self.norm = nn.LayerNorm(size)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, sublayer):
    # Post-LN: 先应用子层,然后残差连接,最后层归一化
    return self.norm(x + self.dropout(sublayer(x)))

# Pre-LN版本
def forward(self, x, sublayer):
    # Pre-LN: 先层归一化,然后应用子层,最后残差连接
    return x + self.dropout(sublayer(self.norm(x)))
Transformer模型设计思想

Transformer模型设计思想

创新点、设计理念与模型对比

lightbulb 核心设计理念

Transformer模型的设计理念源于对序列处理方式的根本性思考。其核心思想是完全依赖注意力机制来捕捉序列中的依赖关系,摒弃了传统的循环和卷积结构,实现了更高效的并行计算和更强的长距离依赖建模能力。

"我们认为注意力机制本身足以取代传统序列建模中的循环和卷积结构,并提出了Transformer模型,完全基于注意力机制来构建序列到序列的转换。"

— 《Attention Is All You Need》论文摘要

Transformer的设计体现了以下几个关键理念:

speed 并行计算
摒弃RNN的顺序处理方式,通过自注意力机制实现整个序列的并行处理,大幅提高训练效率。
timeline 长距离依赖
通过注意力机制直接建立序列中任意两个位置的联系,有效捕捉长距离依赖关系,解决了RNN中的梯度消失问题。
layers 多层次表示
通过堆叠多层编码器和解码器,学习不同层次的抽象表示,从低级特征到高级语义逐步构建。
auto_awesome 多视角学习
多头注意力机制允许模型同时关注不同位置和不同表示子空间的信息,从多个角度理解输入序列。

compare 与传统模型的对比

Transformer模型与传统的序列处理模型(如RNN、LSTM、GRU和CNN)在设计理念和实现方式上有显著区别。下表对比了Transformer与传统模型的主要差异:

特性 Transformer RNN/LSTM/GRU CNN
计算方式 并行计算 顺序计算 局部并行
长距离依赖 直接建模,距离无关 通过门控机制缓解 需要多层堆叠
位置信息 显式位置编码 隐式顺序处理 局部卷积核
计算复杂度 O(n²) O(n) O(k·n),k为核大小
可解释性 注意力权重可视化 较难解释 特征图可视化

Transformer模型的主要优势在于:

  • 计算效率:并行计算使得训练速度大幅提升,尤其适合大规模数据集和长序列处理。
  • 长距离依赖:直接建立任意位置之间的联系,不受距离限制,更好地捕捉全局信息。
  • 可扩展性:模型结构简单统一,易于扩展到大规模参数和海量数据。
  • 多任务适用性:不仅适用于NLP任务,还可扩展到CV、语音处理等多个领域。

psychology 设计哲学与影响

Transformer模型的设计哲学体现了简单性、有效性和可扩展性的原则。通过摒弃复杂的循环结构和门控机制,Transformer用相对简单的注意力机制实现了更强大的性能,这一设计思想对后续的深度学习研究产生了深远影响。

Transformer的设计哲学主要体现在以下几个方面:

architecture 模块化设计
将模型分解为注意力、前馈网络、归一化等标准模块,每个模块功能明确,易于理解和实现。
repeat 统一结构
编码器和解码器采用相似的层结构,减少了设计复杂性,便于模型扩展和优化。
add_link 残差连接
通过残差连接和层归一化,使深层网络训练更加稳定,为构建大型模型奠定基础。
settings 超参数简化
相比传统模型,Transformer减少了需要调整的超参数数量,使模型训练更加简单。

Transformer的设计思想不仅推动了NLP领域的发展,还启发了计算机视觉、语音处理、多模态学习等多个领域的研究。BERT、GPT、T5等大型预训练模型都基于Transformer架构,展示了其强大的可扩展性和通用性。可以说,Transformer不仅是一种模型架构,更是一种全新的深度学习设计范式。

复现Transformer时的技巧和经验分享

复现Transformer时的技巧和经验分享

解决复现过程中常见的挑战与问题

tune 超参数设置技巧

复现Transformer时,正确的超参数设置是获得与论文相近性能的关键。以下是论文中使用的超参数以及一些实用技巧:

straighten 模型维度
原始论文使用dmodel=512,这是模型的基础维度。前馈网络内部层维度dff=2048,是模型维度的4倍。这个比例在后续研究中被广泛采用。
view_comfy 注意力头数
论文使用h=8个注意力头,每个头的维度dk=dv=dmodel/h=64。确保dmodel能被h整除,否则可能导致维度不匹配问题。
layers 层数设置
编码器和解码器各使用N=6层。层数越多,模型表达能力越强,但也更容易过拟合。可以从较少层数开始,逐步增加。
opacity Dropout率
论文在多个位置应用Dropout:嵌入层后(Pdrop=0.1)、子层输出后(Pdrop=0.1)和注意力权重后(Pdrop=0.1)。适当调整Dropout率可以有效防止过拟合。
warning
常见误区

很多人在复现时忽略了标签平滑(label smoothing)技术,论文中使用εls=0.1的标签平滑,这对提高模型泛化能力很重要。

fitness_center 训练技巧与优化方法

Transformer模型的训练过程有其特殊性,以下是一些关键的训练技巧和优化方法,可以帮助你获得更好的性能:

speed 学习率调度
使用自定义的学习率调度器:lrate = dmodel-0.5 · min(step_num-0.5, step_num · warmup_steps-1.5)。这种先线性增加后衰减的学习率策略对Transformer训练至关重要。
gradient 梯度裁剪
论文使用梯度裁剪来防止梯度爆炸,将梯度范数限制在1.0以内。这在训练初期尤其重要,可以稳定训练过程。
memory 优化器选择
使用Adam优化器,参数设置为β1=0.9, β2=0.98, ε=10-9。这些参数值与标准Adam不同,对Transformer模型的收敛有显著影响。
format_align_center 批次大小
论文使用批次大小约为4096个token。可以通过梯度累积来模拟大批次训练,尤其是在GPU内存有限的情况下。
# 学习率调度器实现
class NoamOpt:
  def __init__(self, model_size, factor, warmup, optimizer):
    self.optimizer = optimizer
    self.warmup = warmup
    self.factor = factor
    self.model_size = model_size
    self._step = 0

  def step(self):
    self._step += 1
    lr = self.factor * (self.model_size ** (-0.5) *
          min(self._step ** (-0.5), self._step * self.warmup ** (-1.5)))
    for param_group in self.optimizer.param_groups:
      param_group['lr'] = lr
    self.optimizer.step()

code 实现细节与常见陷阱

在复现Transformer时,一些实现细节容易被忽略,但这些细节往往对模型性能有重要影响。以下是一些关键实现细节和常见陷阱:

swap_vert 层归一化位置
原始论文使用Post-LN(层归一化在残差连接之后),但后续研究表明Pre-LN(层归一化在残差连接之前)对深层Transformer更稳定。尝试两种方式,观察哪种更适合你的任务。
filter_center_focus 注意力掩码
解码器中的自注意力需要使用因果掩码(causal mask),防止当前位置关注到未来位置。同时,填充掩码(padding mask)用于忽略填充位置的影响。
functions 初始化策略
使用Xavier/Glorot初始化或PyTorch默认初始化。对于输出层,可以将权重初始化为输入嵌入权重的转置,这在某些任务中有助于收敛。
precision_manufacturing 混合精度训练
使用FP16混合精度训练可以显著减少内存使用并加速训练,但需要谨慎处理梯度缩放,避免数值不稳定问题。
warning
性能差距原因

如果你复现的Transformer比论文低两点,可能的原因包括:1) 数据预处理不一致;2) 优化器参数设置不正确;3) 学习率调度器实现有误;4) 未使用标签平滑;5) 位置编码实现有偏差。检查这些细节通常能解决大部分性能差距问题。

insights 调优经验与最佳实践

基于社区经验和研究进展,以下是一些Transformer模型调优的最佳实践,可以帮助你进一步提升模型性能:

data_usage 数据质量与数量
确保使用与论文相同的数据集和预处理方法。数据质量对模型性能影响巨大,清洗数据、统一格式、正确分词是基础工作。
architecture 归一化方法
尝试不同的归一化方法,如RMSNorm或DeepNorm,它们在某些场景下比标准LayerNorm表现更好,训练更稳定。
bolt 激活函数
原始论文使用ReLU,但后续研究表明GELU、Swish等激活函数在Transformer中表现更好。可以尝试不同的激活函数。
trending_up 监控与调试
使用TensorBoard等工具监控训练过程,特别关注梯度范数、学习率变化和注意力权重的分布,这些信息有助于诊断问题。

最后,记住复现研究是一个迭代过程。如果第一次尝试没有达到预期性能,不要气馁。仔细检查每个组件,参考开源实现,逐步调整,最终你将能够成功复现Transformer的卓越性能。

关键代码示例

关键代码示例

Transformer模型核心组件实现与细节

visibility 缩放点积注意力实现

缩放点积注意力是Transformer的核心组件,以下是PyTorch实现:

Python
import torch
import torch.nn.functional as F
import math

class ScaledDotProductAttention(torch.nn.Module):
  """缩放点积注意力机制"""
  def __init__(self, dropout=0.1):
    super(ScaledDotProductAttention, self).__init__()
    self.dropout = torch.nn.Dropout(dropout)

  def forward(self, q, k, v, mask=None):
    # q, k, v的形状: [batch_size, num_heads, seq_len, d_k]
    d_k = k.size(-1)

    # 1. 计算注意力分数
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

    # 2. 应用掩码(如果有)
    if mask is not None:
      scores = scores.masked_fill(mask == 0, -1e9)

    # 3. 计算注意力权重
    attn_weights = F.softmax(scores, dim=-1)
    attn_weights = self.dropout(attn_weights)

    # 4. 加权求和得到输出
    output = torch.matmul(attn_weights, v)

    return output, attn_weights

关键点:

  • 除以√dk进行缩放,防止点积值过大导致softmax梯度消失
  • 掩码处理用于忽略填充位置或防止解码器关注未来位置
  • 返回注意力权重可用于可视化和解释模型行为

device_hub 多头注意力实现

多头注意力允许模型同时关注不同位置和不同表示子空间的信息:

Python
class MultiHeadAttention(torch.nn.Module):
  """多头注意力机制"""
  def __init__(self, d_model, num_heads, dropout=0.1):
    super(MultiHeadAttention, self).__init__()
    self.d_model = d_model
    self.num_heads = num_heads
    self.d_k = d_model // num_heads
    self.d_v = d_model // num_heads

    # 线性变换层
    self.w_q = torch.nn.Linear(d_model, d_model)
    self.w_k = torch.nn.Linear(d_model, d_model)
    self.w_v = torch.nn.Linear(d_model, d_model)
    self.w_o = torch.nn.Linear(d_model, d_model)

    # 注意力机制
    self.attention = ScaledDotProductAttention(dropout)

    # Dropout和层归一化
    self.dropout = torch.nn.Dropout(dropout)
    self.layer_norm = torch.nn.LayerNorm(d_model)

  def forward(self, q, k, v, mask=None):
    batch_size = q.size(0)

    # 残差连接
    residual = q

    # 1. 线性变换
    q = self.w_q(q) # [batch_size, seq_len, d_model]
    k = self.w_k(k) # [batch_size, seq_len, d_model]
    v = self.w_v(v) # [batch_size, seq_len, d_model]

    # 2. 分割成多头
    q = q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
    k = k.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
    v = v.view(batch_size, -1, self.num_heads, self.d_v).transpose(1, 2)

    # 3. 应用注意力
    output, attn = self.attention(q, k, v, mask=mask)

    # 4. 合并多头
    output = output.transpose(1, 2).contiguous()
    output = output.view(batch_size, -1, self.d_model)

    # 5. 最终线性变换
    output = self.w_o(output)
    output = self.dropout(output)

    # 6. 残差连接和层归一化
    output = self.layer_norm(output + residual)

    return output, attn

关键点:

  • 使用view和transpose操作实现多头分割和合并
  • 残差连接和层归一化有助于稳定深层网络训练
  • 返回注意力权重可用于可视化和解释模型行为

location_on 位置编码实现

位置编码为模型提供序列中元素的位置信息,使用正弦和余弦函数生成:

Python
class PositionalEncoding(torch.nn.Module):
  """位置编码"""
  def __init__(self, d_model, max_len=5000):
    super(PositionalEncoding, self).__init__()

    # 计算位置编码
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                   (-math.log(10000.0) / d_model))

    # 偶数位置使用sin,奇数位置使用cos
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)

    # 注册为缓冲区,不作为模型参数
    self.register_buffer('pe', pe.unsqueeze(0))

  def forward(self, x):
    # x形状: [batch_size, seq_len, d_model]
    # 添加位置编码
    x = x + self.pe[:, :x.size(1), :]
    return x

关键点:

  • 使用register_buffer注册位置编码,不作为模型参数更新
  • 正弦和余弦函数的组合使模型能够学习相对位置关系
  • 位置编码可以直接加到词嵌入上,无需额外训练

input 编码器层实现

编码器层包含多头自注意力和前馈神经网络两个子层:

Python
class EncoderLayer(torch.nn.Module):
  """编码器层"""
  def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
    super(EncoderLayer, self).__init__()

    # 多头自注意力
    self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)

    # 前馈神经网络
    self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout)

  def forward(self, x, mask=None):
    # 1. 多头自注意力
    x, attn = self.self_attn(x, x, x, mask=mask)

    # 2. 前馈神经网络
    x = self.ffn(x)

    return x, attn

class PositionwiseFeedForward(torch.nn.Module):
  """前馈神经网络"""
  def __init__(self, d_model, d_ff, dropout=0.1):
    super(PositionwiseFeedForward, self).__init__()
    self.w_1 = torch.nn.Linear(d_model, d_ff)
    self.w_2 = torch.nn.Linear(d_ff, d_model)
    self.dropout = torch.nn.Dropout(dropout)
    self.activation = torch.nn.ReLU()

  def forward(self, x):
    # x形状: [batch_size, seq_len, d_model]
    x = self.w_1(x)
    x = self.activation(x)
    x = self.dropout(x)
    x = self.w_2(x)
    return x

关键点:

  • 编码器层由多头自注意力和前馈神经网络组成
  • 前馈神经网络内部层维度通常是模型维度的4倍(d_ff=4*d_model)
  • 每个子层都包含残差连接和层归一化(在MultiHeadAttention类中实现)

architecture 完整Transformer模型实现

将所有组件组合成完整的Transformer模型:

Python
class Transformer(torch.nn.Module):
  """完整的Transformer模型"""
  def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512,
              num_heads=8, num_layers=6, d_ff=2048, dropout=0.1):
    super(Transformer, self).__init__()

    # 编码器和解码器的嵌入层
    self.src_embedding = torch.nn.Embedding(src_vocab_size, d_model)
    self.tgt_embedding = torch.nn.Embedding(tgt_vocab_size, d_model)

    # 位置编码
    self.positional_encoding = PositionalEncoding(d_model)

    # 编码器
    self.encoder_layers = torch.nn.ModuleList([
      EncoderLayer(d_model, num_heads, d_ff, dropout)
      for _ in range(num_layers)
    ])

    # 解码器
    self.decoder_layers = torch.nn.ModuleList([
      DecoderLayer(d_model, num_heads, d_ff, dropout)
      for _ in range(num_layers)
    ])

    # 输出层
    self.fc_out = torch.nn.Linear(d_model, tgt_vocab_size)

    # 初始化参数
    self._init_weights()

  def _init_weights(self):
    # 使用Xavier初始化
    for p in self.parameters():
      if p.dim() > 1:
        torch.nn.init.xavier_uniform_(p)

  def encode(self, src, src_mask):
    # 嵌入和位置编码
    src = self.src_embedding(src) * math.sqrt(self.d_model)
    src = self.positional_encoding(src)

    # 通过所有编码器层
    for layer in self.encoder_layers:
      src, _ = layer(src, src_mask)

    return src

  def decode(self, tgt, encoder_output, tgt_mask, src_tgt_mask):
    # 嵌入和位置编码
    tgt = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
    tgt = self.positional_encoding(tgt)

    # 通过所有解码器层
    for layer in self.decoder_layers:
      tgt, _ = layer(tgt, encoder_output, tgt_mask, src_tgt_mask)

    # 输出层
    output = self.fc_out(tgt)
    return output

  def forward(self, src, tgt, src_mask, tgt_mask, src_tgt_mask):
    # 编码
    encoder_output = self.encode(src, src_mask)

    # 解码
    output = self.decode(tgt, encoder_output, tgt_mask, src_tgt_mask)

    return output

关键点:

  • 嵌入层输出乘以√d_model,与位置编码相加
  • 使用ModuleList堆叠多个编码器层和解码器层
  • Xavier初始化有助于模型稳定训练
  • 将编码和解码过程分离,便于灵活使用