这篇论文探索了使用8位浮点数(FP8)来有效训练大型语言模型(LLM)。论文的主要贡献和要点总结如下:
- 提出了一个新的FP8混合精度训练框架,可以分阶段地引入8位梯度、优化器状态和分布式训练,逐步发挥FP8的优势,降低训练成本。
- 在这个框架下,实现了8位梯度交流、8位优化器和8位并行训练。具体来说:
- 为FP8梯度交流设计了自动缩放技术,解决了低位交流中的上下溢问题。
- 实现了FP8优化器,通过精度解耦找到哪些变量更适合低精度表达。
- 在张量并行、流水线并行和序列并行中支持FP8,降低激活传递的通信量。
- 在7B到175B参数规模的GPT模型上验证了该FP8训练方案的效果。结果显示,相比BF16训练,FP8训练可以显著降低GPU内存占用(29%~39%)、权重相关通信量(63%~65%),并提高吞吐量。模型性能不受影响。
- 将FP8训练应用到GPT模型的微调上,包括教学调整和强化学习。结果同样展现出计算和内存上的节约。
- 通过大量的分析实验对FP8训练的设计选择进行了验证,为后续研究提供了指导性结论。
- 本文是第一个将FP8计算、存储和通信全面渗透到大模型训练 entire pipeline 的工作,可视为推动下一代低精度训练系统的重要一步。
本文对利用FP8进行大规模语言模型的高效低精度训练做出了重要探索,在减少训练成本方面展现出令人鼓舞的潜力。论文的贡献具有重要的理论和实践价值。