元鉴
返回中文阅读流

NVIDIA Developer Blog

使用端到端 FP8 精度运行高通量强化学习训练

随着大语言模型(LLM)从简单的文本生成转向复杂的推理,强化学习(RL)发挥着核心作用。像 Group Relative Policy 这样的算法...

中文内容

已翻译official company source英文原文2026-05-26

随着 LLM 从简单的文本生成转向复杂推理,强化学习(RL)发挥着核心作用。Group Relative Policy Optimization(GRPO)等算法推动了这一转变,使推理级模型能够通过迭代反馈持续改进。与标准的监督式微调不同,RL 训练循环被分为两个截然不同的高强度阶段:一个具有严格延迟要求的生成阶段,以及一个需要高吞吐量的训练阶段。

为了让这些工作负载具备可行性,研究人员和工程师正转向 FP8 等低精度数据类型,以提升训练性能和面向吞吐量的生成能力。此外,在某些生成受 GPU 内存带宽限制的场景中,使用低精度参数可以通过减少每个参数所需的字节数来提升性能。

本文深入探讨低精度 RL 的系统性挑战,以及 NVIDIA NeMo RL——NVIDIA NeMo 框架内的一个开源库——如何在保持准确性的同时加速 RL 工作负载。

RL 中线性层的 FP8

我们的方案使用 DeepSeek-V3 Technical Report 中介绍的分块量化 FP8。表 1 给出了线性投影层中张量格式的详细信息。

TensorType of dataQuantization granularityScaling factor Type of ScalingWeightsFP8 (E4M3)[128, 128]FP32Block-wiseInput activationsFP8 (E4M3)[1, 128]FP32Block-wiseOutput gradientsFP8 (E4M3)[1, 128]FP32Block-wise
表 1. 线性投影层中的张量格式

采用该方案,线性层可以使用 FP8 数学计算,其峰值吞吐量是 BF16 数学计算的 2 倍。其他模块,包括注意力、归一化、非线性函数和输出投影,则使用 BF16 数学计算。

强化学习中数值不一致的挑战

RL 流水线通常使用独立的引擎:vLLM 用于 rollouts,NVIDIA Megatron Core 用于训练。每个引擎都使用独特的自定义 NVIDIA CUDA 内核来最大化性能。这本质上会引入数值差异,而在较低精度下,由于额外的量化和反量化逻辑,这些差异会累积放大。我们将这种数值差异量化为 token 乘性概率误差:

正文:\(\texttt{token-mult-prob-error} = \frac{1}{n}\sum_{i=1}^{n(\texttt{tokens})}exp(\left| \texttt{log-train-fwk}_i – \texttt{logprobs-inference-fwk}_i \right|)\)

完全对齐的得分为 1;在不使用任何额外技术时,我们通常认为“小于 1.03–1.05”的值是“可接受的”。

线性层中的端到端 FP8 可减少数值分歧

在开发 FP8 方案的过程中,我们试验了三种方案:

  • 基线方案:生成和训练均使用 BF16。
  • 候选方案 1:FP8 仅在生成过程中使用,而策略模型训练采用 BF16。
  • 最终方案:端到端 FP8:我们在生成和训练引擎中均使用 FP8。

我们观察到,与仅在生成阶段使用 FP8 的候选方案 1 相比,最终方案在生成与训练之间始终表现出更低的数值分歧。请注意,基线方案始终给出最低的数值分歧。图 1 展示了三种方案的 token 乘法概率误差指标。

Line chart of token multiplicative probability error across 500 training steps, where the end-to-end FP8 recipe shows lower numerical disagreement than FP8 generation-only, though the BF16 baseline remains the most stable near 1.01.Line chart of token multiplicative probability error across 500 training steps, where the end-to-end FP8 recipe shows lower numerical disagreement than FP8 generation-only, though the BF16 baseline remains the most stable near 1.01.
图 1. 三种方案中的 token 乘法概率误差

通过重要性采样缓解数值分歧

重要性采样用于校正生成数据的模型(即分布)与正在训练的模型(即分布)之间的分布不匹配。它是一个逐 token 权重,与损失相乘。有关重要性采样的详细理论背景,可以参阅我们的 GRPO 文档。

实验表明:

  • 对于候选方案 1(FP8 生成和 BF16 训练),重要性采样可以缩小与 BF16 RL 的准确率差距,但无法完全弥合这一差距。
  • 对于最终方案(端到端 FP8),重要性采样完全弥合了与 BF16 训练的差距。图 2 展示了不同方案在训练期间的验证准确率。
Line chart of validation accuracy over 4k GRPO training steps on Llama 3.1 8B Instruct, where end-to-end FP8 with importance sampling fully closes the accuracy gap to the BF16 baseline (~0.62), while FP8 generation with BF16 training narrowLine chart of validation accuracy over 4k GRPO training steps on Llama 3.1 8B Instruct, where end-to-end FP8 with importance sampling fully closes the accuracy gap to the BF16 baseline (~0.62), while FP8 generation with BF16 training narrow
图 2. 在 Llama 3.1 8B Instruct 模型和数学数据集上进行 GRPO 训练的验证准确率

FP8 线性层端到端结果

我们在稠密模型和混合专家模型上评估端到端 FP8 方案,衡量其相对于 BF16 基线的验证准确率和训练吞吐量。

稠密模型上的 FP8 端到端:Llama 3.1 8B Instruct

表 2 显示了在 Llama 3.1 8B instruct 模型和数学数据集上训练至 4000 步时,GRPO 训练中 FP8 端到端方案与 BF16 方案的准确率。

PrecisionBF16FP8 generation onlyFP8 End-to-EndValidation accuracy0.6160.5860.613
表 2:不同精度配置下 Llama3 8B 验证准确率的准确率结果

在加速方面,与 BF16 相比,FP8 方案实现了稳定超过 15% 的吞吐量提升。图 3 展示了两种方案在 1000 个步骤中的 GRPO 训练吞吐量(每个 GPU 每秒处理的 token 数)。

Line chart of training throughput in tokens per second per GPU over 1k steps, where the end-to-end FP8 recipe consistently achieves over 15% higher throughput than the BF16 baseline, averaging around 1700 vs 1400 tokens/sec/GPU.Line chart of training throughput in tokens per second per GPU over 1k steps, where the end-to-end FP8 recipe consistently achieves over 15% higher throughput than the BF16 baseline, averaging around 1700 vs 1400 tokens/sec/GPU.
图 3。两种方案的吞吐量(每个 GPU 每秒处理的 token 数)(蓝色:BF16,粉色:FP8 端到端)

尽管从理论上看,FP8 相比 BF16 的加速比为 2 倍,但在实践中会更低,因为只有线性层受益于更快的数学运算吞吐量,而注意力层和逐元素层保持不变。在线性层之前额外添加的量化内核会引入一些开销。15%-25% 的加速幅度与我们对 vLLM 的独立测试相符。通过进一步优化,例如在 vLLM 中融合量化内核,我们预计加速比可以进一步提升至 1.25 倍。

MoE 模型上的 FP8 端到端:Qwen3-30B

在混合专家(MoE)模型上进行了类似实验,Qwen3-30B 的结果显示出相匹配的准确率曲线。FP8 实现了与 BF16 相似的准确率。速度提升仍在研究中。

Line chart of validation accuracy over 600 GRPO training steps for Qwen3-30B MoE model, where the end-to-end FP8 recipe closely matches the BF16 baseline, both converging around 0.65 accuracy.Line chart of validation accuracy over 600 GRPO training steps for Qwen3-30B MoE model, where the end-to-end FP8 recipe closely matches the BF16 baseline, both converging around 0.65 accuracy.
图 4. 在 8 个 H100 节点上,使用 OpenMathInstruct-2 数据集进行 Qwen3-30B GRPO 的准确率曲线。蓝色为 BF16,粉色为 FP8 端到端

将 FP8 扩展到 KV 缓存和 attention

在 transformer 模型中,线性层并不是唯一的瓶颈。在输出序列长度(OSL)较长的 RL 工作流中,KV cache 的增长和 attention 计算往往主导端到端 rollout 时间,同时还会使内存带宽饱和并减慢 token 生成速度。这促使我们探索在 RL 循环中将 FP8 用于 KV cache 和 attention。这里使用的是按张量缩放的 FP8。

在 RL 场景中为 KV-cache 实现 FP8 具有独特挑战,因为策略权重会在每一步发生变化。不同于只需校准一次的静态推理,RL 需要动态处理量化尺度。

NeMo RL 采用以下方法来解决这一问题:

  1. 重新校准:在每个训练步骤结束时,训练器会使用更新后的策略权重重新校准 Query、Key、Value(QKV)的尺度。
  2. 数据选择:该校准使用训练数据(提示词和生成的响应)执行,以确保缩放因子反映当前分布。
  3. 同步:随后将新计算出的缩放因子同步到推理引擎(vLLM),用于后续的 rollout 阶段。
Flowchart of the RL training workflow with FP8 KV cache, cycling through five stages: Refitting Phase (sync weights and QKV scales), Rollout Phase (vLLM generation with FP8 KV cache), Compute Rewards, Training Phase (update policy weights),Flowchart of the RL training workflow with FP8 KV cache, cycling through five stages: Refitting Phase (sync weights and QKV scales), Rollout Phase (vLLM generation with FP8 KV cache), Compute Rewards, Training Phase (update policy weights),
图 5. 使用 FP8 KV cache 的 RL 工作流

这种设计确保 rollout 引擎始终使用由最新策略状态得出的最优量化缩放因子,从而最大限度地减少精度下降。校准开销很小,约占总步进时间的 2-3%。

TensorType of dataScaling factor Type of scalingQKV attention activationsFP8 (E4M3)FP32Tensor-wiseStored KV cacheFP8 (E4M3)FP32Tensor-wise
表 3:注意力激活和存储的 KV cache 的张量格式

KV cache 和注意力使用 FP8 的结果摘要

我们使用 GRPO 算法在 Qwen3-8B-Base 模型上运行了实验,其中 rollout 阶段应用 FP8,训练阶段使用 BF16。由于误差累积,在同时量化 KV cache 和注意力时,不匹配 KL 散度略高,但我们的方案缓解了不稳定性。通过启用 token 级截断重要性采样,同时对线性层 + KV cache + 注意力使用 FP8,可实现与 BF16 基线以及线性层(W8A8)使用 FP8 相一致的验证准确率。

Four-panel chart of Qwen3-8B-Base training metrics (response length, AIME2024 validation accuracy, training-inference mismatch KL, and rewards) over 400 steps, where FP8 with KV cache and attention closely tracks the BF16 baseline across alFour-panel chart of Qwen3-8B-Base training metrics (response length, AIME2024 validation accuracy, training-inference mismatch KL, and rewards) over 400 steps, where FP8 with KV cache and attention closely tracks the BF16 baseline across al
图 6. Qwen3-8B-Base 的训练准确率曲线

为 KV-cache 和 attention 操作同时启用 FP8,相比线性 W8A8 配置,在 rollout 阶段额外带来约 30% 的加速,从而相比 BF16 基线总体实现约 48% 的加速。这些收益在更长的响应长度下尤为明显,因为此时 attention 计算在整体工作负载中占比更大。QKV scale 重新校准过程大约消耗总 step 时间的 2-3%,相对于所实现的显著加速而言,这一成本较小。

Five-panel chart of rollout performance metrics for Qwen3-8B-Base over 400 steps, where FP8 with KV cache and attention achieves ~48% speedup over BF16, with gains most visible in generation time and tokens per second at longer response lenFive-panel chart of rollout performance metrics for Qwen3-8B-Base over 400 steps, where FP8 with KV cache and attention achieves ~48% speedup over BF16, with gains most visible in generation time and tokens per second at longer response len
图 7. Qwen3-8B-Base 模型的 rollout 性能曲线

使用 NVIDIA NeMo RL 尝试端到端 FP8

为了在生成和训练后端的线性层中启用 FP8,以下配置映射展示了每个调优参数如何传递给训练和生成后端。

Diagram showing how NeMo RL configuration parameters map to the vLLM generation engine, Megatron training backend, and importance sampling settings when enabling end-to-end FP8 for linear layers.Diagram showing how NeMo RL configuration parameters map to the vLLM generation engine, Megatron training backend, and importance sampling settings when enabling end-to-end FP8 for linear layers.
图 8. 在 NVIDIA NeMo RL 中为线性层启用 FP8

要为 KV cache 和 attention 启用 FP8,需要在策略的 vllm_cfg 中配置 kv_cache_dtype 参数,该参数会在训练器端自动处理 QKV 缩放因子的重新校准,并与 vLLM 后端同步。

policy:
  generation:
    vllm_cfg:
      precision: fp8       # Enable FP8 for linear layers
      kv_cache_dtype: fp8  # Enable FP8 for KV-cache

用于生成和训练的高级 FP8 配置选项

到目前为止,我们已经介绍了线性层以及 KV cache + attention 层的 FP8 实现。高级用户可以尝试该方案的变体。以下是其中一些功能的示例:

  • 在生成过程中将前 N 个和/或后 M 个 transformer 层保持为 BF16(N、M 为整数)
policy:
  generation:
    vllm_cfg:
      num_first_layers_in_bf16: N # replace N with an integer
      num_last_layers_in_bf16: M  # replace M with an integer
  • 将生成和/或训练配置为使用 2 的幂缩放因子类型,而不是 FP32
policy:
  generation:
    vllm_cfg:
      pow2_weight_scaling_factors: true
      pow2_activation_scaling_factors: true
  megatron_cfg:
    env_vars:
      NVTE_FP8_BLOCK_SCALING_FP32_SCALES: "0"
  • 开发者可以使用为 Megatron Core 后端预定义的 FP8 配方变体,而不是默认的分块量化 FP8 配方,如表 1 所示。详情请参阅参数文档字符串。
policy:
  megatron_cfg:
    fp8_cfg:
      fp8: "e4m3"
      fp8_recipe: "blockwise"

开始使用

用户可以先参考 NeMo RL GitHub 中的 llama-3.1-8b 和 moonlight-16b 配方。

致谢

这项工作是多个团队协作完成的。我们感谢 Jimmy Zhang、Victor Cui、Zhiyu Li 和 Lark Zhang 在 FP8 配方开发、实验以及集成到 NeMo RL 方面所做的工作。

Like

标签

原文标题

Run High-Throughput Reinforcement Learning Training with End-to-End FP8 Precision