引言
大语言模型(LLM)的训练一直是AI领域的热点话题。随着开源模型的不断涌现,如何对这些基础模型进行进一步优化和定制化训练成为了很多研究者和开发者关注的焦点。本文将介绍如何使用Firefly框架在单张V100 GPU上对Qwen1.5-7B模型进行SFT(Supervised Fine-tuning)和DPO(Direct Preference Optimization)训练,并探讨训练过程中的关键技术点和实验结果。
Firefly简介
Firefly是一个开源的大模型一站式训练框架,支持对各种主流大模型进行预训练、指令微调和DPO等训练。它支持全量参数、LoRA、QLoRA等多种训练方式,可以适应不同的硬件条件和训练需求。Firefly框架兼容包括Gemma、Qwen1.5、MiniCPM、Mixtral-8x7B、Mistral、Llama等在内的绝大多数主流大模型。
Qwen1.5模型介绍
Qwen1.5是阿里巴巴在2024年春节前开源的大语言模型,支持32K的上下文长度。该模型可以看作是Qwen2的beta版本,未来还会有Qwen2的正式版本发布。从各项评测结果来看,Qwen1.5各个尺寸的模型都显著优于同量级的Llama2。在2024年2月的SuperCLUE大模型榜单中,Qwen1.5也展现出了非常优秀的表现,在开源模型中处于领先地位。
大模型训练的三个阶段
大模型的训练通常可以分为以下三个主要阶段:
- 预训练(Pre-training): 使用超大规模文本对模型进行训练,训练任务为"预测下一个token"。这个阶段通常需要处理几万亿个token的数据量。
- SFT(Supervised Fine-tuning,指令微调): 使用指令数据对模型进行微调,使其输出格式与人类对齐,具备对话(chat)的能力。
- RLHF(Reinforcement Learning from Human Feedback,基于人类反馈的强化学习): 使用人类反馈或偏好数据来训练模型,使模型的输出更加符合人类的价值观或预期行为。
DPO简介
在RLHF阶段,传统的方法如PPO(Proximal Policy Optimization)存在流程繁琐、显存需求大等问题。相比之下,DPO(Direct Preference Optimization)方法绕过了奖励模型的构建,可以直接使用人类偏好数据对模型进行训练,且在训练时仅需加载策略网络和参考网络,极大地节省了显存占用。
DPO的训练数据包含三个字段:prompt、chosen和rejected。其损失函数计算过程具有对称性,公式如下:
L_DPO = -log(σ(β(r_θ(x,y) - r_θ(x,y')))) + log(σ(β(r_θ_ref(x,y) - r_θ_ref(x,y'))))
其中,r_θ表示策略网络,r_θ_ref表示参考网络,β是温度系数,σ是sigmoid函数。
在代码实现中,DPO损失函数的计算过程大致如下:
- 计算对数概率:将prompt分别与chosen和rejected拼接,然后分别输入策略网络和参考网络,得到4个对数概率。
- 计算策略网络的diff:策略网络的chosen对数概率 - rejected对数概率。
- 计算参考网络的diff:参考网络的chosen对数概率 - rejected对数概率。
- 计算损失函数:策略网络的diff - 参考网络的diff。
实验设置
本实验在Qwen1.5-7B的基础上,使用Firefly框架进行了SFT和DPO两阶段的训练。整个训练流程仅使用一张V100 GPU,采用QLoRA技术,在所有Linear层都添加adapter以提升训练效果。两个阶段均使用英文数据进行训练。
对话模板
Firefly与Qwen1.5官方的对话模板保持一致:
<|im_start|>system
You are a helpful assistant.
<|im_end|>
<|im_start|>user
hello, who are you?
<|im_end|>
<|im_start|>assistant
I am a AI program developed by Firefly
<|im_end|>
SFT阶段设置
使用Firefly对Qwen1.5进行SFT的启动命令:
python train.py --train_args_file train_args/sft/qlora/qwen1.5-7b-sft-qlora.json
SFT阶段的主要参数设置如下:
- num_epochs: 1
- learning_rate: 2e-4
- total_train_batch_size: 32
- max_seq_length: 2048
- optimizer: paged_adamw_32bit
- lr_scheduler_type: constant_with_warmup
- warmup_steps: 700
- lora_rank: 64
- lora_alpha: 16
- lora_dropout: 0.05
- gradient_checkpointing: true
- fp16: true
DPO阶段设置
使用Firefly对Qwen1.5进行DPO的启动命令:
python train.py --train_args_file train_args/dpo/qlora/qwen1.5-7b-dpo-qlora.json
DPO阶段采用ultrafeedback数据集,主要参数设置如下:
- num_epochs: 1
- learning_rate: 2e-4
- total_train_batch_size: 32
- max_seq_length: 1600
- max_prompt_length: 500
- optimizer: paged_adamw_32bit
- lr_scheduler_type: constant_with_warmup
- warmup_steps: 200
- lora_rank: 64
- lora_alpha: 16
- lora_dropout: 0.05
- gradient_checkpointing: true
- fp16: true
实验结果与分析
模型评测
在Open LLM Leaderboard上对模型进行评测,Firefly训练的模型表现显著优于官方的Qwen1.5-7B-Chat、Gemma-7B-it等模型。具体来说:
- 比Qwen1.5-7B-Chat高7.12分
- 比Gemma-7B-it高8.8分
经过DPO之后,模型的平均分还有接近1分左右的提升。这说明Firefly框架在单卡V100上通过SFT和DPO训练,成功地提升了Qwen1.5模型的性能。
DPO训练指标分析
在DPO训练过程中,我们关注了几个重要的训练指标:
- DPO训练loss:
训练过程中,loss呈现出总体下降的趋势,表明模型在逐步优化。 - Rewards/accuracies:
该指标表示较优回答的奖励大于较劣回答的奖励的频率的均值。在训练过程中,这个指标呈现上升趋势,说明模型越来越能够区分优质和劣质回答。 - Rewards/margins:
该指标表示较优回答的奖励与较劣回答的奖励二者之差的均值。这个指标也呈现上升趋势,表明模型对优质回答的偏好程度在不断增强。
这些指标的变化趋势都表明,DPO训练确实帮助模型学习到了人类的偏好,提升了模型输出的质量。
结论与展望
通过使用Firefly框架在单卡V100上对Qwen1.5-7B模型进行SFT和DPO训练,我们成功地提升了模型的性能,在Open LLM Leaderboard上取得了优于原始Qwen1.5-7B-Chat和Gemma-7B-it等模型的成绩。这个实验结果表明:
- Firefly框架为大模型的定制化训练提供了高效且易用的解决方案。
- QLoRA技术使得在有限的硬件资源上也能进行有效的大模型微调。
- DPO方法相比传统的RLHF方法,简化了训练流程,降低了资源需求,同时也取得了良好的效果。
未来的研究方向可以包括:
- 探索更优的超参数组合,进一步提升模型性能。
- 尝试在不同规模的模型上应用Firefly框架,研究其扩展性。
- 结合领域特定数据,探索Firefly框架在垂直领域的应用潜力。
- 研究如何进一步优化DPO算法,提高其训练效率和效果。
总的来说,Firefly框架为大模型的定制化训练提供了一个强大而灵活的工具,为AI研究者和开发者开辟了新的可能性。我们期待看到更多基于Firefly的创新应用和研究成果。
参考文献
- YeungNLP. (2024). 使用Firefly在单卡V100上对Qwen1.5进行SFT和DPO,大幅超越Qwen1.5和Gemma. 微信公众号文章.
- Rafailov, R., et al. (2023). Direct Preference Optimization: Your Language Model is Secretly a Reward Model. arXiv preprint arXiv:2305.18290.
- Qwen Team. (2024). Qwen1.5: An Open Source AI Model by Alibaba. https://github.com/QwenLM/Qwen
- Firefly Team. (2024). Firefly: An Open-source One-stop Large Language Model Training Framework. https://github.com/YeungNLP/firefly