大模型训练 · 混合精度 · 约 12 分钟阅读

Megatron FP8

FP8 量化训练:从数据格式到大规模落地

显存减半
~1.5–2×吞吐量提升
E4M3 / E5M2两种 FP8 格式
H100FP8 Tensor Core

训练大型语言模型需要巨量算力和显存。Megatron-LM 的 FP8 训练将最耗带宽的矩阵乘法的权重和激活从 16 位压缩到 8 位,在 H100 GPU 上解锁最高 2 倍的吞吐量峰值——同时模型质量与 BF16 基线几乎相同。

§1 为什么要 FP8 训练?

现代 LLM 是带宽瓶颈型计算:GPU 大部分时间花在 HBM 与计算单元之间搬运数据,而非真正做运算。将数据类型从 BF16(2 字节)缩窄到 FP8(1 字节),对矩阵乘法这一 Transformer 训练的主导 kernel 而言,等效内存带宽翻倍。

🔴 传统 AMP

  • 权重:BF16 前向 + FP32 master
  • 梯度:FP32 累积
  • 激活:BF16
  • 每参数 6 字节

🟢 FP8 混合精度

  • 权重:前向 GEMM 中用 FP8 E4M3
  • 权重梯度:FP8 E5M2
  • 激活:计算时 FP8
  • master weights / 优化器仍 FP32
💡 核心思想: FP8 训练不是“把所有东西都换成 8 位”。只有最耗带宽的 GEMM 输入(权重和激活)才以 FP8 存储。LayerNorm、Softmax、残差连接和优化器仍保持 BF16/FP32。这种精准替换在最大化吞吐量的同时保证数值稳定性。

一个实用的心算法是把收益拆成两部分:持久模型状态与热路径张量流量。FP8 混合精度对已存储的权重状态节省其实较温和,因为 FP32 master copy 仍然存在;但它把每次 GEMM 输入搬运的字节数相对 BF16 再砍半。这也是为什么真实加速往往更像带宽收益,而不只是参数显存的线性节省。

# 粗略估算:权重存储副本与 GEMM 流量
# 例子:700 亿参数
params = 70_000_000_000

fp32_full_bytes = 4 * params          # 一份 FP32 权重副本
amp_bf16_bytes = (2 + 4) * params   # BF16 计算副本 + FP32 master
fp8_mixed_bytes = (1 + 4) * params  # FP8 计算副本 + FP32 master

saved_vs_amp_gib = (amp_bf16_bytes - fp8_mixed_bytes) / (1024 ** 3)
bf16_gemm_bytes = 2
fp8_gemm_bytes = 1
effective_bw_gain = bf16_gemm_bytes / fp8_gemm_bytes

print(saved_vs_amp_gib)     # ≈ 65.2 GiB 权重状态节省
print(effective_bw_gain)   # 同一总线下每秒元素吞吐提升 2.0 倍
方案 GEMM 权重字节数 Master 副本 近似权重存储字节/参数 实际影响
FP32 全精度 4 B 4 B 数值简单,但每次 GEMM 的字节搬运成本最高
AMP BF16 2 B FP32 master(4 B) 6 B 常见基线:流量低于 FP32,优化器状态稳定
FP8 混合精度 1 B FP32 master(4 B) 5 B 保留优化器安全性,同时 GEMM 输入密度相对 BF16 再提升 2 倍

▸ 互动演示:内存总线 — BF16 vs FP8

相同的物理内存总线,相同的总线带宽。BF16 每个元素占 2 字节;FP8 每个元素占 1 字节——因此 FP8 每秒可传输 2 倍数量的元素。观察数据包竞速:

running…

FP8 的硬件支撑是具体可量化的:NVIDIA H100 SXM 引入了第四代 Tensor Core,原生支持 FP8 计算。其峰值 FP8 矩阵乘法吞吐量为 3,958 TFLOPS,恰好是同一芯片 BF16 Tensor Core 1,979 TFLOPS 的两倍。H200 保留了相同的计算单元布局,同时将 HBM3e 带宽从 3.35 TB/s 提升至 4.8 TB/s,进一步放大了 FP8 的内存密度优势。正是这个 2× 的原始算力比,使 FP8 不仅仅是一种推理技巧,而是现代大规模训练的一等公民格式。

GPU FP8 Tensor TFLOPS BF16 Tensor TFLOPS FP8 / BF16 倍率 HBM 带宽
H100 SXM(80 GB HBM3) 3,958 TFLOPS 1,979 TFLOPS 2.00× 3.35 TB/s
H200 SXM(141 GB HBM3e) 3,958 TFLOPS 1,979 TFLOPS 2.00× 4.8 TB/s
A100 SXM(80 GB HBM2e)— 无原生 FP8 N/A(仿真) 312 TFLOPS 2.0 TB/s
💡 第四代 Tensor Core 关键细节: H100 的 Tensor Core tile 原生接受 FP8 输入并累加到 FP32,因此不存在软件仿真开销。2× 吞吐量增益来自每条 warp 指令中能装入两倍数量的 FP8 操作数(相比 BF16),而非时钟频率的变化。这也意味着 FP32 累加器的精度与 BF16 训练完全一致;只有乘法操作数变窄了。

§2 FP8 格式:E4M3 与 E5M2

8 个比特分为三个字段:符号位(S)、指数(E)和尾数(M)。FP8 标准定义了两种变体——点击下方按钮,动画展示比特是如何分配的。

E4M3FN

正常 FP8 数遵循 IEEE 754 编码规则:

$$\text{value} = (-1)^{S} \cdot 2^{E_{\text{int}} - \text{bias}} \cdot \left(1 + \frac{M_{\text{int}}}{2^{\#M}}\right)$$
格式 SE bitsM bits 偏置 最大值 典型用途
E4M3FN1437448 前向:权重、激活
E5M21521557344 反向:权重梯度
BF16187127~3.4×1038 传统 AMP 基线

比特布局会直接决定可用动态范围。若指数位有 k 位,则偏置通常为 2^(k-1)-1;最大正常指数来自最大的非保留指数编码;而最大有限值则是 (2 - 2^{-M}) · 2^{e_max}。这也是为什么多一位指数位带来的范围扩展,往往远大于多一位尾数位带来的收益。

# 从比特布局推导 bias 与最大有限值
def fp_format_stats(exp_bits, mant_bits):
    bias = (2 ** (exp_bits - 1)) - 1
    max_exp_code = (2 ** exp_bits) - 2   # 全 1 指数通常保留给特殊值
    e_max = max_exp_code - bias
    max_finite = (2 - 2 ** (-mant_bits)) * (2 ** e_max)
    return bias, e_max, max_finite

print(fp_format_stats(4, 3))  # E4M3 -> (7, 8, 448.0)
print(fp_format_stats(5, 2))  # E5M2 -> (15, 15, 57344.0)
格式 偏置 (bias) e_max 最大有限值 最小正规数
E4M3FN 7 8 448.0 2−6 ≈ 0.015625
E5M2 15 15 57344.0 2−14 ≈ 6.10×10−5
BF16 127 127 ~3.39×1038 2−126 ≈ 1.18×10−38

上表数值直接来自 fp_format_stats 的计算结果。注意 E5M2 的最小正规数远小于 E4M3(2−14 vs 2−6)——多出来的指数位让它在低端获得了更大的覆盖范围,这正是稀疏梯度分布所需要的。BF16 巨大的动态范围(bias = 127)使其自然成为主权重和优化器状态的首选格式。

上述 max_exp_code = 2**exp_bits - 2 中的 -2,是将全 1 指数编码(如 E4M3 的 1111)保留给特殊值。在标准 IEEE 754(以及 E5M2)中,全 1 指数 + 零尾数 = ±Inf,全 1 指数 + 非零尾数 = NaN。E4M3FN 则打破了这一惯例,以换取一个额外的有限值:全 1 指数 + 零尾数被重定义为 448(即最大有限值)而非 Inf,全 1 指数 + 非零尾数仍然是 NaN。实际影响是 E4M3FN 没有 Inf 表示——任何溢出都会饱和截断到 ±448,而不会流向 ±∞。E5M2 遵循标准 IEEE 惯例,同时保留 Inf 和 NaN。

💡 为何前向用 E4M3,反向用 E5M2? 前向激活和权重集中在小范围 → 更多尾数位(E4M3)提供更细粒度。反向梯度稀疏且峰值大 → 更多指数位(E5M2)防止溢出。

§3 可表示值的数轴分布

E4M3 有更多尾数位,其可表示值在 1.0 附近更密集。E5M2 则将值更均匀地分布在更宽的范围。下方可视化在对数坐标轴上展示所有正数可表示值——拖动目标值滑块,实时观察量化误差。

量化误差并不是均匀分布的。在同一个指数桶内,相邻 FP8 数之间的间距是均匀的,因此绝对间隔会随指数增大而增大。靠近 0 时,次正规数会填补一部分空隙,避免从最小正规数直接跳到 0;但由于可用的小刻度仍然很少,相对误差依旧很大。靠近范围上界时,最糟糕的问题不再是舍入,而是裁剪:一旦目标值超过 FP8_MAX,误差会从平滑增长变成突增。

E4M3FN
E5M2
x = 1.000
# 枚举所有非负 FP8 可表示值
def enumerate_fp8_values(exp_bits, mant_bits):
    bias = (2 ** (exp_bits - 1)) - 1
    values = []
    for e in range(2 ** exp_bits):
        for m in range(2 ** mant_bits):
            if e == 0 and m == 0:
                values.append(0.0)
            elif e == 0:
                subnormal = (m / (2 ** mant_bits)) * (2 ** (1 - bias))
                values.append(subnormal)
            elif e == (2 ** exp_bits) - 1:
                continue  # 跳过保留指数编码
            else:
                normal = (1 + m / (2 ** mant_bits)) * (2 ** (e - bias))
                values.append(normal)
    return sorted(set(values))

e4m3 = enumerate_fp8_values(4, 3)
print(e4m3[:12])   # 观察靠近 0 的细小刻度
print(e4m3[-8:])   # 观察最大的有限值
💡 次正规数很关键: 在 FP8 中,指数全 0 且尾数非 0 的值表示次正规数。它们在 0 与最小正规数之间提供渐进下溢,比把微小梯度直接冲成 0 更好。但它们的精度并不高:次正规数没有隐含的前导 1,因此数轴上相对误差最差的区域往往就出现在靠近 0 的地方。

为使数轴讨论更具体,下表列出四个代表性张量元素在缩放因子为 1.0 时经过 E4M3 量化后的结果。靠近 1.0 的小就绝对误差很小;接近或超过 FP8_MAX 的大就发生灘轴式裁剪。

输入倦小 最近 E4M3(s=1) 最近 E5M2(s=1) E4M3 绝对误差 E4M3 相对误差 说明
0.001 0.000977 0.000977 2.3×10−5 2.3 % 次正规区域;小将较粗
1.5 1.5 1.5 0.0 0 % 两种格式均可精确表示
100.0 100.0 96.0 0.0 (E4M3), 4.0 (E5M2) 0 % (E4M3), 4 % (E5M2) E4M3 在此处密度足够;E5M2 大小时较粗
400.0 400.0 384.0 0.0 (E4M3), 16.0 (E5M2) 0 % (E4M3), 4 % (E5M2) 接近 E4M3 最大値(448);安全余量将尽

§4 Tensor 缩放因子

FP8 的最大值只有 448(E4M3)——远小于 BF16 的 3.4×10³⁸。直接类型转换几乎会导致每个真实张量溢出。解决方案:为每个张量维护缩放因子 $s$,在转换前归一化。

$$x_{\text{fp8}} = \mathrm{cast\_fp8}\!\left(\frac{x}{s}\right), \qquad \hat{x} = x_{\text{fp8}} \cdot s, \qquad s^{*} = \frac{\mathrm{amax}(x)}{\text{fp8\_max}}$$

在不预留额外余量时,最优缩放因子 s* 就是把张量最大值映射到 FP8_MAX 的那个比例。若 s 比它更小,最大元素会被裁剪;若显著更大,又会让大量数值挤进少数粗糙刻度。实际系统通常先从 s* 出发,再按需要乘上一个 margin,为下一步的波动预留溢出余量。

import torch

def compute_fp8_scale(x, fp8_max=448.0, margin=0):
    amax = x.abs().max().clamp_min(1e-12)
    s_star = amax / fp8_max                     # 不留余量时的最优缩放因子
    s = s_star * (2.0 ** margin)       # 预留 2^margin 的安全余量
    x_scaled = x / s
    x_clipped = x_scaled.clamp(-fp8_max, fp8_max)
    return s, s_star, x_clipped

x = torch.tensor([-920.0, 3.5, 511.0], dtype=torch.float32)
s, s_star, x_q = compute_fp8_scale(x, margin=1)
print(s_star, s)
print(x_q)
策略 何时计算缩放 延迟代价 对尖峰的鲁棒性 典型用途
即时缩放 在 cast 前立即计算当前 amax 最高:需要额外张量扫描与同步点 对当前 batch 的数值适配最好 调试、离线分析、很小的张量
延迟缩放 使用前几步的 amax 历史 低:元数据更新可摊销,无需 GEMM 前扫描 整体较好,但一步滞后可能裁剪突发尖峰 Megatron / Transformer Engine 的默认训练路径
静态缩放 通过校准或经验规则一次性选定 运行时最低 最弱:分布漂移时很容易失效 推理阶段或分布高度可控的工作负载
💡 Margin 参数: 正的 margin 会故意让 s 比数学上最紧的 s* 更大,也就是在转成 FP8 之前先把张量再缩小一点。你今天会牺牲一部分尾数分辨率,换来对下一步更大 amax 的安全余量。margin 过大本质上就是过度缩放:更安全,但也更浪费,因为更多数值会塌缩到同一批 FP8 刻度上。

下方互动演示生成一个随机 BF16 张量,其值范围远超 FP8 所允许的范围。切换原始值/缩放后视图,了解缩放因子的必要性。

Before scaling

在其内部 fp8_meta 管理层以 C++/CUDA 实现缩放计算。amax 归约通过专用的融合 kernel 完成,可以通过 cublasLtMatmul epilogue 回调携带,也可为独立归约 kernel。amax 以 FP32 标量存入 fp8_meta 字典,逆缩放因子 scale_inv = 1/s 被预先计算并缓存,让反量化只需一次元素乘法。

💡 TE 中的 fp8_meta 布局: Transformer Engine 为每个开启 FP8 的模块维护一个 fp8_meta 字典。每个条目包含三个 FP32 张量:scale(形状 [num_gemms])、scale_inv(形状 [num_gemms])和 amax_history(形状 [history_len, num_gemms])。预先计算 scale_inv 意味着反量化可以在读取 FP8 输出的同一 kernel 内完成融合乘法。

§5 延迟缩放:复用历史 amax

在每次 GEMM 前计算 amax 需要完整扫描张量——增加了延迟和同步开销。Megatron-LM 通过延迟缩放绕过了这个问题:为每个张量维护一个包含最近 W 个 amax 值的循环缓冲区,然后从缓冲区的最大值推导当前缩放因子。开销几乎为零,并且缩放值在各步骤间保持数值稳定。

⚡ 为何称为「延迟」?

在步骤 t,缩放因子 st 是根据步骤 t−W … t−1 的 amax 历史计算得出的。当前步骤的 amax 在 GEMM 之后才被记录——它只从步骤 t+1 开始影响缩放因子。这个一步的滞后就是「延迟」。

🔁 4 阶段循环(每次 GEMM 调用)

  1. 读取缩放因子 s = max(history) / fp8_max
  2. 转换 x → FP8(使用 s)
  3. GEMM 在 Tensor Core 上以 FP8 执行
  4. 记录 amax(x) 到历史缓冲区

在实现层面,延迟缩放本质上就是附着在每条 FP8 张量流上的一个小状态机:先读取历史窗口中的最大值,为本次调用推导缩放因子,执行 GEMM,然后用新观测到的 amax 覆盖循环缓冲区中的一个槽位。关键细节在于:缓冲区更新发生在 kernel 之后,因此当前步骤无法“回头”修正自己这一步的溢出风险。

# 延迟缩放的循环缓冲区伪代码
class AmaxHistory:
    def __init__(self, W, fp8_max):
        self.buf = [0.0] * W
        self.ptr = 0
        self.fp8_max = fp8_max

    def current_scale(self):
        hist_amax = max(self.buf)
        return hist_amax / self.fp8_max if hist_amax > 0 else 1.0

    def run_gemm(self, x, w):
        scale = self.current_scale()
        x_fp8 = quantize_to_fp8(x, scale)
        w_fp8 = quantize_to_fp8(w, scale)
        y = fp8_gemm(x_fp8, w_fp8, out_dtype="bf16")

        # GEMM 结束后才记录当前步的 amax
        self.buf[self.ptr] = amax(x)
        self.ptr = (self.ptr + 1) % len(self.buf)
        return y
历史窗口 W 缩放行为 主要优势 主要代价 典型用法
W = 1 只使用最近一次 amax;响应几乎是即时的。 能快速跟踪分布漂移;几乎没有陈旧历史。 缩放因子抖动更大;当激活随 batch 波动时更不稳定。 快速变化的小规模实验,或用于调试灵敏度。
W = 16 使用较短滚动历史;在响应速度与平滑性之间折中。 常见的 bring-up 默认值;既能记住近期尖峰,又不会引入太大滞后。 对更长时间尺度的漂移和周期性罕见离群值仍然不够稳。 初始 FP8 实验,以及中等规模训练。
W = 1024 使用很长的保守历史;缩放因子更容易被罕见的大 amax 主导。 缩放因子非常稳定,在长训练中对溢出保护更强。 当分布收缩后适应较慢,小值会更长时间地损失尾数精度。 大规模 GPT 训练,优先保证稳定性而不是瞬时自适应。

一个很实用的心智模型是:较小的 W 让 delayed scaling 更像高增益控制器,而较大的 W 则更像保守的低通滤波器。Megatron 通常更偏向保守一侧,因为偶尔一步“过度缩放”的代价,通常小于反复发生 FP8 饱和带来的代价。

下方动画逐步展示 3 次训练迭代,显示历史缓冲区状态、推导的缩放因子和 GEMM 执行。「尖峰」场景(梯度突然激增)演示了延迟缩放能处理和不能处理的情况。

step 0 / phase A
⚠️ 尖峰场景: 当步骤 t 发生梯度尖峰时,步骤 t 的缩放因子是根据 history[t−W : t−1] 计算的,可能太小 → 值裁剪到 FP8_MAX。但步骤 t 的大 amax 立即被添加到历史缓冲区,因此到步骤 t+1 时缩放因子已经修正。一步的裁剪通常是可接受的——优化器层面的梯度裁剪会处理剩余影响。
step 0
💡 amax_compute_algo='max': Megatron 取历史窗口中 amax 值的最大值——保守策略,永不低估范围,防止溢出,代价是偶尔过度缩放。

在 Transformer Engine 的 Python 层,延迟缩放状态存储在 fp8_meta 中:amax_history(形状 [history_len, num_gemms])和 scale(形状 [num_gemms])。每次前向传播完成后,新测量的 amax 写入循环缓冲区的 step % history_len 位置;下次 GEMM 前从 amax_history.max(dim=0) 重新推导 scalescale_inv

# Transformer Engine fp8_meta 延迟缩放更新 — PyTorch 伪代码
import torch

# 状态张量(FP32,存活在 fp8_meta 中)
history_len = 16            # 由 amax_history_len 配置
num_gemms   = 3             # 前向权重、前向激活、反向梯度
amax_history = torch.zeros(history_len, num_gemms)  # [history_len, num_gemms]
scale        = torch.ones(num_gemms)               # [num_gemms]
scale_inv    = torch.ones(num_gemms)               # [num_gemms]  = 1 / scale
step         = 0

def update_fp8_meta(new_amax: torch.Tensor, fp8_max: float = 448.0,
                        algo: str = 'max', margin: int = 0):
    global amax_history, scale, scale_inv, step
    # 1. 将新 amax 写入循环缓冲区槽位
    slot = step % history_len
    amax_history[slot] = new_amax

    # 2. 从历史记录中计算有效 amax
    if algo == 'max':
        eff_amax = amax_history.max(dim=0).values  # 保守策略
    else:   # 'most_recent'
        eff_amax = amax_history[slot]                  # 响应式策略

    # 3. 推导新的 FP32 scale 并预计算逆值
    scale     = (eff_amax / fp8_max) * (2.0 ** margin)
    scale_inv = 1.0 / scale.clamp_min(1e-12)  # 缓存,供反量化时使用
    step += 1

§6 前向与反向 GEMM

每个线性层执行三次矩阵乘法。FP8 训练将操作数替换为适当类型的 FP8 张量。点击前向/反向查看各自流程,或点击动画播放。

$$\underbrace{Y = X\,W^T}_{\substack{\text{前向}\\\text{(E4M3×E4M3)}}} \qquad \underbrace{dX = dY\,W}_{\substack{\text{输入梯度}\\\text{(E5M2×E4M3)}}} \qquad \underbrace{dW = dY^T X}_{\substack{\text{权重梯度}\\\text{(E5M2×E4M3)}}}$$

这三个 GEMM 在线性代数形式上相同,但数值风险画像并不一样。前向的激活和权重通常在归一化后分布较温和,因此可以安全地放在 E4M3;而反向路径中的梯度张量则更依赖 E5M2,以换取更大的指数动态范围。累加器以及对外暴露的输出仍保持在 BF16/FP32,从而避免量化误差沿归约维度持续累积。

GEMM 公式 FP8 操作数类型 典型形状 累加 / 输出精度
前向 Y = X W^T X:E4M3,W:E4M3 [M, K] × [N, K]^T → [M, N] FP32 累加,BF16/FP32 输出
输入梯度 dX = dY W dY:E5M2,W:E4M3 [M, N] × [N, K] → [M, K] FP32 累加,BF16/FP32 输出
权重梯度 dW = dY^T X dY^T:E5M2,X:E4M3 [N, M] × [M, K] → [N, K] FP32 累加,通常以 BF16/FP32 梯度形式输出
# 简化版 Transformer Engine GEMM 调度
def te_linear(x, w, fp8_meta, grad_output=None):
    x_fp8 = cast_to_fp8(x, fp8_meta.x_scale, format="E4M3")
    w_fp8 = cast_to_fp8(w, fp8_meta.w_scale, format="E4M3")

    if grad_output is None:
        return cublas_fp8_gemm(
            x_fp8, w_fp8, accumulate_dtype="fp32", out_dtype="bf16"
        )

    dy_fp8 = cast_to_fp8(grad_output, fp8_meta.dy_scale, format="E5M2")
    dx = cublas_fp8_gemm(dy_fp8, w_fp8, accumulate_dtype="fp32", out_dtype="bf16")
    dw = cublas_fp8_gemm(transpose(dy_fp8), x_fp8, accumulate_dtype="fp32", out_dtype="fp32")
    return dx, dw
Forward pass
💡 累加器保持 BF16/FP32: 只有 GEMM 的输入是 FP8。输出累加器保持高精度以防止误差累积——H100 的 FP8 Tensor Core 内部以 FP32 进行累加。

cuBLAS 通过 cublasLtMatmul 发起 FP8 GEMM,在 Ampere/Hopper 上把操作数布局设置为 CUBLASLT_ORDER_COL32_2R_4R4——这种格式将数据按 32 元素列主序分条打包,专门针对硬件 warp 级显存子系统优化。FP8 路径还通过 cublasLtMatmulDescSetAttribute 暴露了 epilog 融合 机制:偏置加法、ReLU、GELU 或自定义逐点函数可以直接融合进 GEMM epilog 阶段,省去单独的 kernel 启动以及中间 BF16 结果的全局显存往返。Transformer Engine 在 MLP 投影的 bias+GELU 融合中正是利用了这一机制。

cuBLAS 针对 FP8 与 BF16 会选取不同的 tile 大小。BF16 常用 128×128 或 256×128;FP8 因数据类型更窄、寄存器预算更充足,可选用 256×256 的更大 tile,从而提升算术强度。Hopper 上全部三次 FP8 GEMM 的底层硬件指令是 wgmma(warp-group matrix multiply-accumulate)——它原生支持 64×8×16 或 64×16×16 的 FP8 tile,并以 FP32 累加。wgmma 在 warp group 共享显存流水线上持续执行,矩阵足够大时数据移动与算术运算几乎完全重叠而无额外开销。

§7 Megatron-LM API

Megatron-LM 通过 NVIDIA transformer_engine 库提供 FP8 支持,只需几个参数:

# Launch script flags
--fp8-format hybrid           # E4M3 fwd + E5M2 bwd
--fp8-amax-history-len 1024    # rolling amax window (default; reduce to 16 for bring-up)
--fp8-amax-compute-algo max   # conservative: take window max

# Python API
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

fp8_recipe = DelayedScaling(
    fp8_format=Format.HYBRID,   # E4M3 fwd / E5M2 bwd
    amax_history_len=16,
    amax_compute_algo="max",
)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    output = model(input_ids, attention_mask)
参数 类型 / 可选值 默认值 含义 何时调整
--fp8-format {hybrid, e4m3} hybrid 选择 FP8 数值格式家族;hybrid 表示前向友好张量走 E4M3,反向敏感张量走 E5M2。 只有在做对照实验、希望统一格式时才考虑改成 e4m3
--fp8-recipe {delayed, tensorwise, blockwise} delayed 选择缩放 recipe 家族:历史延迟缩放、当前步按张量缩放,或在新硬件上的块级缩放。 当 delayed scaling 响应过慢,或迁移到 Blackwell/MXFP8 时再切换。
--fp8-amax-history-len 正整数 1024 控制 delayed scaling 在循环缓冲区中保留多少个历史 amax 样本。 bring-up 时可降到 16;若更在意长程稳定性而非快速适应,则增大。
--fp8-amax-compute-algo {max, most_recent} max 定义如何从历史窗口中得到当前 amax;max 是更保守的默认选择。 只有在需要更快响应且能接受更大缩放抖动时,才尝试 most_recent
--fp8-margin 整数,≥ 0 0 为指数范围预留安全余量,使缩放因子对溢出更保守。 只有在反复溢出时才提高;margin 过大则会浪费尾数精度。
--fp8-interval 正整数 1 控制 FP8 元数据更新频率;1 表示每个训练步都更新。 只有在实验性地牲牲自适应性、换取更少元数据更新时才增大。
--fp8-wgrad 开关标志(存在即开启) 默认关闭 当后端路径支持时,使权重梯度 GEMM 也能够使用 FP8。 当基础 FP8 跑法已经稳定、希望榨取最后一点 GEMM 吞吐时再开启。

上表标志直接对应 transformer_engine.common.recipe.DelayedScaling 的构造函数参数:--fp8-formatfp8_format--fp8-amax-history-lenamax_history_len,以此类推。这些参数已在 Transformer Engine ≥ 1.7、CUDA ≥ 12.1、cuBLAS ≥ 12.1 的组合下验证。完整最新的 API 文档请参阅 transformer-engine.readthedocs.io

一个稳妥的训练中途 BF16→FP8 迁移流程通常是:先用 BF16 训练到 loss 和优化器动量稳定,保存检查点,再带上 --fp8-format hybrid 及相应 recipe 参数恢复训练,并把最初的 100–500 步视为校准阶段。在这段预热期内应继续开启梯度裁剪,观察是否出现一次性的 loss spike,并避免同时改动太多超参数——因为权重本身已经学到有用结构,FP8 主要需要的是重建新的 amax 历史,而不是从头重新学习模型。

💡 张量并行: FP8 与 Megatron TP/PP 无缝配合。激活 amax 通过 all-reduce 在 TP 组内同步,确保所有 GPU 共享一致的缩放因子。不支持 FP8 的操作(如 FlashAttention)自动降级为 BF16。

§8 性能基准

点击 ▶ 动画观察吞吐量柱状图增长,悬停查看详细数字。数据来自官方 Megatron-LM 报告,H100 SXM 80GB,TP=8。

一个公平的 FP8 基准应当固定模型结构、序列长度、全局 batch size、张量/流水并行配置以及优化器设置,只比较算术路径本身。在官方 Megatron 设置中,吞吐量是在 H100 SXM 80GB、Tensor Parallelism = 8 的系统上报告的;下表反映的是 FP8 初始校准窗口结束后的稳定训练阶段,而不是刚切换时那几步噪声较大的 bring-up 迭代。

模型规模 BF16 吞吐(TFLOPS) FP8 吞吐(TFLOPS) 加速比
GPT-3 126M 183 239 1.31×
GPT-3 5B 274 365 1.33×
GPT-3 175B 312 432 1.38×

可复现元数据: 硬件:8× NVIDIA H100 SXM 80 GB。软件栈:CUDA 12.2、cuBLAS 12.2、Transformer Engine 1.7、Megatron-LM 主分支(2024-Q1)。每行模型的 TP/PP 配置:126 M — TP=1, PP=1;5 B — TP=2, PP=1;175 B — TP=8, PP=4。所有实验均使用序列长度 2048、全局 batch size 256,BF16 主权重 + FP8 GEMM 路径。TFLOPS 为单 GPU 属性,取 FP8 校准预热后第 100–500 步的稳态平均值。计算公式:TFLOPS = (6 × 参数量 × seq_len × batch_size) / (step_time × 1012)。

# 复现基准:GPT-3 5B 在 8×H100 上运行(TP=2,PP=1)
torchrun --nproc_per_node=8 pretrain_gpt.py \
    --tensor-model-parallel-size 2 \
    --pipeline-model-parallel-size 1 \
    --num-layers 24 --hidden-size 4096 --num-attention-heads 32 \
    --seq-length 2048 --micro-batch-size 4 --global-batch-size 256 \
    --fp8-format hybrid --fp8-amax-history-len 1024 --fp8-amax-compute-algo max \
    --train-iters 500 --log-interval 10 \
    --bf16  # 主权重用 BF16;GEMM 输入用 FP8

对上述数字的几点说明:(1)所有数字均为近似值,会随驱动版本、batch size 和序列长度不同而有所变化。(2)对于 126M 等小模型,每层 FP8 缩放元数据的开销相对于计算量占比较大,因此加速效果较高收益(~1.31×)。(3)对于 175B 等大模型,任务几乎完全是计算密集型,H100 Tensor Core 对 FP8 2×吸射利率优势被充分发挥,相比 BF16 可获得近 1.4× 的持续吸射率提升。

# Megatron FP8 训练启动命令(H100 / TP=8 示例)
torchrun --nproc_per_node=8 pretrain_gpt.py \
    --tensor-model-parallel-size 8 \
    --pipeline-model-parallel-size 1 \
    --sequence-parallel \
    --bf16 \
    --fp8-format hybrid \
    --fp8-recipe delayed \
    --fp8-amax-history-len 1024 \
    --fp8-amax-compute-algo max \
    --fp8-margin 0 \
    --fp8-wgrad
悬停查看详情
📊 模型质量: 在 Pile/MMLU/HellaSwag 上,FP8 训练的 loss 曲线和任务指标与 BF16 差异在 0.1% 以内。梯度裁剪阈值可略微放宽(1.0→1.5)以适应梯度中的量化噪声。

§9 Transformer Engine te.Linear 内部机制

在 Megatron-Core 内部,启用 FP8 并不是简单地给普通 nn.Linear 外面套一层 cast。模型是通过 ModuleSpec 规格树构建的,规格会把标准线性层替换成 Transformer Engine 包装器,例如 TEColumnParallelLinearTERowParallelLinearTELayerNormColumnParallelLinear。这些包装器一方面保留 Megatron 的张量并行切分、参数初始化与通信契约,另一方面把真正的 FP8 GEMM 委托给 te.Linear

🧩 规格层替换

Megatron 在模型构建阶段就选定 TE 模块,而不是在运行时对 nn.Linear 做模式匹配替换。这一点很关键,因为 QKV 投影、MLP 上投影和输出投影分别需要不同的张量并行切分语义。

⚙️ 包装器职责

  • 列并行:在 TP rank 之间切分输出特征
  • 行并行:切分输入特征并归约部分和
  • LayerNorm+Linear:把归一化元数据与 FP8 感知投影融合
  • TE 后端:管理 FP8 元数据、缩放因子与 GEMM 调度
# Megatron's transformer spec with TE modules
from megatron.core.extensions.transformer_engine import (
    TEColumnParallelLinear,
    TERowParallelLinear,
    TELayerNormColumnParallelLinear,
    TEDotProductAttention,
    TENorm,
)

# MLP spec: replaces nn.Linear with TE FP8-aware layers
mlp_spec = MLPSubmodules(
    linear_fc1=TEColumnParallelLinear,  # TP column-parallel + FP8
    linear_fc2=TERowParallelLinear,     # TP row-parallel + FP8
)

# Attention spec
attn_spec = SelfAttentionSubmodules(
    linear_qkv=TEColumnParallelLinear,
    core_attention=TEDotProductAttention,
    linear_proj=TERowParallelLinear,
)

梯度累积阶段有一个很关键但常被忽略的优化。多个 microbatch 在优化器更新之前会复用完全相同的权重,如果每个 microbatch 都重新把权重量化成 FP8,就会白白浪费带宽和 kernel。于是 Transformer Engine 会在第一个 microbatch 后缓存 FP8 化的权重。Megatron 将 is_first_microbatch 传给 TE 包装器:True 表示重新量化并刷新 FP8 缓存,False 表示复用缓存的 FP8 权重,None 表示关闭缓存路径、每次都重新量化。

前向步骤 TE 内部动作
1. 读取 recipe 状态 读取当前激活与权重对应的 scale、inverse scale 以及 amax 历史元数据。
2. 量化权重或命中缓存 如果是第一个 microbatch,就把 BF16 master weight 转成 FP8 并缓存;否则直接复用缓存的 FP8 副本。
3. 量化输入激活 使用当前调用对应的激活缩放因子,把输入激活张量转换为 FP8。
4. 启动 cuBLAS FP8 GEMM 在 FP8 Tensor Core 上执行矩阵乘法,同时保持累加器为更高精度。
5. 返回 BF16/FP32 输出 GEMM 结果会以 BF16/FP32 友好的形式返回给 Megatron 计算图,供后续算子继续使用。
6. 记录当前 amax kernel 路径会回写本步观测到的 amax,供 recipe 更新后续步骤的缩放决策。
读取 scalerecipe 状态 / 历史
权重 FP8 转换或命中缓存
激活 FP8 转换按调用缩放
cuBLAS FP8 GEMMFP32 累加
写回 amax供下步使用
# FP8 weight caching across microbatches
for micro_idx in range(num_microbatches):
    is_first = (micro_idx == 0)
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        output = model(input, is_first_microbatch=is_first)
        # micro_idx=0: quantize weight → FP8, cache it
# micro_idx>0: reuse cached FP8 weight (skip quantization)
    loss = criterion(output, target)
    loss.backward()
# optimizer.step() after all microbatches

fp8_meta 本质上是每个模块各自维护的运行时缓存,而不是需要长期保存的检查点状态。Transformer Engine 会在第一次 FP8 前向时懒分配它,在每个 optimizer step 的第一个 microbatch 刷新权重侧缓存,在后续 microbatch 复用已经量化好的 FP8 张量;一旦参数更新、模块重载或从检查点恢复,这份缓存就会重新构建。真正持久的“真相源”仍然是 BF16/FP32 master weight,fp8_meta 只是围绕它动态再生成的临时状态。

# Inspect fp8_meta on a TE module
for name, mod in model.named_modules():
    if hasattr(mod, 'fp8_meta'):
        print(name)
        print(mod.fp8_meta['scaling_fwd'].amax_history)
        print(mod.fp8_meta['scaling_fwd'].scale)
        break
💡 为什么权重缓存很重要: 在大规模 GPT 训练中,一个 optimizer step 往往包含几十个 microbatch。如果每次都把同一份 GB 级权重分片重新量化成 FP8,会吃掉大量显存带宽,抵消一部分 FP8 提速收益。复用缓存后的 FP8 权重意味着只有第一个 microbatch 付出转换代价,后续 microbatch 基本只承担 GEMM 本身的成本。

在同一个模型的多个 Transformer Engine 模块之间协调 FP8 状态,是 TE Python 层内部单例 FP8GlobalStateManager 的职责。当你进入 te.fp8_autocast(enabled=True, fp8_recipe=recipe) 上下文时,这个上下文管理器会把当前 recipe 和 FP8 启用标志注册到全局状态管理器,从而让模型计算图中任意位置的 te.Linearte.LayerNormte.MultiheadAttention 都能读取到同一份权威状态,而无需通过构造函数参数显式传递。这对 Megatron 的 ModuleSpec 模型尤为关键,因为 TE 包装器深埋在 TransformerLayer 内部,无法直接访问外层训练循环的局部变量。

FP8GlobalStateManager 还能优雅处理嵌套的 fp8_autocast 上下文:内层上下文会递增一个引用计数,并继承外层 recipe,除非明确指定了不同的;内层退出时计数器递减并恢复外层 recipe。这使得流水线并行阶段内部的 TE 模块能正确工作,即便 Megatron 的流水线调度把每个 microbatch 的前向调用套在各自的上下文中。该管理器还暴露了若干类方法,可用于调试查询全局状态。

# 运行时查询 FP8GlobalStateManager
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    # 当前上下文中 FP8 是否激活?
    print(FP8GlobalStateManager.is_fp8_enabled())   # True
    # 获取当前 recipe
    print(FP8GlobalStateManager.get_fp8_recipe())   # DelayedScaling(...)
    # 嵌套深度(便于流水线阶段调试)
    print(FP8GlobalStateManager.fp8_autocast_depth())  # 1
    with te.fp8_autocast(enabled=True):
        print(FP8GlobalStateManager.fp8_autocast_depth())  # 2
    print(FP8GlobalStateManager.fp8_autocast_depth())  # 退出内层后恢复 1

§10 缩放策略演进:Delayed → Current → Block (MXFP8)

Transformer Engine 的 FP8 缩放策略已经从单一方案演进为一组精度逐步提高的方法。DelayedScaling 是 H100 时代最常用的基础方案:它复用前几步的 amax 统计,几乎没有额外运行时开销。Float8CurrentScaling 则会在 GEMM 前扫描当前张量,消除一步滞后,但要额外做一次归约。到了 Blackwell,MXFP8BlockScaling 把粒度继续细化:每连续 32 个元素共享一个 E8M0 缩放因子,这正是 OCP 规范中的 Microscaling(MX)格式。

特性 DelayedScaling Float8CurrentScaling MXFP8BlockScaling
缩放因子计算 来自历史缓冲区 当前步即时扫描 每 32 元素一个 block
精度 较好(有 1 步滞后) 更好(基于当前张量精确计算) 最好(细粒度)
开销 近乎为零 一次完整张量扫描 块级运算
硬件要求 H100+ H100+ Blackwell+
历史缓冲区 是(W 个条目)

⏱️ Delayed 与 Current

DelayedScaling 提供 fp8_formatamax_history_lenamax_compute_algomarginfp8_mha 等参数。Float8CurrentScaling 保留 format 与 margin 等控制项,也支持 fp8_mha,但不再需要历史缓冲区,因为缩放值直接来自当前张量。

🧱 MXFP8 的意义

当一个张量同时包含平稳区域和局部异常值时,per-tensor scaling 仍然可能过粗。MXFP8 为每 32 个值分配一个缩放因子,因此某个局部尖峰不再迫使整个张量采用保守缩放。这也是 Blackwell 在长序列注意力和超大 MLP 激活上能保留更多 FP8 精度的原因。

from transformer_engine.common.recipe import (
    DelayedScaling, Float8CurrentScaling, MXFP8BlockScaling, Format
)

# Strategy 1: Delayed (original, lowest overhead)
delayed = DelayedScaling(
    fp8_format=Format.HYBRID,
    amax_history_len=16,
    amax_compute_algo="max",
    margin=0,
)

# Strategy 2: Current (more accurate, slightly more overhead)
current = Float8CurrentScaling(
    fp8_format=Format.HYBRID,
    margin=0,
)

# Strategy 3: MXFP8 Block (finest granularity, Blackwell only)
mxfp8 = MXFP8BlockScaling(
    fp8_format=Format.HYBRID,
)
# Pseudocode: aggregate amax across TP ranks before choosing a tensorwise scale
local_amax = amax(abs(tensor_shard))
global_amax = local_amax.clone()
dist.all_reduce(global_amax, op=dist.ReduceOp.MAX, group=tp_group)
scale = fp8_max / clamp(global_amax, min=eps)
fp8_tensor = cast_to_fp8(tensor_shard, scale)

一个看起来容易让人困惑的参数是 margin。它的本质是预留安全余量,主动缩小可用的 FP8 动态范围:TE 不再把张量缩放到绝对的 FP8 上限,而是近似按 scale = amax / (fp8_max / 2^margin) 来计算。更大的 margin 会降低溢出概率,但也会牺牲一部分尾数分辨率,因为数值只占据可表示范围的一小部分。

对按张量的 delayed/current scaling 来说,被同步的数据载荷很小,但对时延非常敏感:每个 amax 通常只是一个 FP32 标量(4 字节)。因此在 P 个 TP rank 上做一次 ring all-reduce,每个 rank 的链路流量大致是 2 * (P - 1) / P * 4B;当 P=8 时,原始载荷只有约 7 字节,所以真正先成为瓶颈的通常不是带宽,而是反复发起 all-reduce 的启动开销与同步时延。

💡 margin 是安全旋钮,不是免费午餐:margin 从 0 调到 2,等价于为张量额外预留 4× 的防溢出空间,但量化器也会因此变得更保守。实践中,稳定训练通常先保持 0;只有当你观察到持续 FP8 溢出或某些层反复出现 loss spike 时,再把它调高。

MXFP8BlockScaling 背后,用于表示每个 block 共享缩放因子的是 E8M0 格式。E8M0 是一种 8 位纯指数浮点格式:它没有尾数位也没有符号位,因此每个可表示值都是 2 的正整次幂,范围从 2-127 到 2127。由于恰好 256 个不同的可表示值覆盖了 255 个连续的 2 的幂次,E8M0 可以以单个无符号字节存储,应用时不需任何浮点运算——乘以一个 E8M0 缩放因子等价于把数据元素的有偏差指数字段加上一个常量。正是这种简洁性使得硬件加速的分块缩放在不增加逐元素乘法器的前提下成为可能。

根据 OCP MX 规范,每个 E8M0 缩放因子恰好覆盖 32 个连续元素。这意味着:对于一个 N 字节的 FP8 一维数据缓冲区,配套的 E8M0 缩放因子缓冲区占 N/32 字节,元数据开销恰好为 1 字节 / 32 数据字节,即 3.125%。在模型尺度下这可以忽略不计:一个 16K×16K 的 FP8 权重矩阵占数据 256 MB,配套的 E8M0 缩放因子仅占 8 MB。固定 32 元素的块大小同时也是硬件缩放应用单元能以编译期确定的融合指令执行而非可配置 kernel 参数的关键。

§11 FP8 与分布式训练:TP/PP/CP 集成

只有在 Megatron 的分布式栈中跑通之后,FP8 才真正有价值。在张量并行下,每个 rank 只持有逻辑张量的一部分,因此本地 amax 不足以推导全局一致的缩放因子。Megatron 会先在 TP 组内对 amax 元数据执行 all-reduce(max),再进行 FP8 转换,从而保证所有 rank 使用相同的 scale,得到数值兼容的 FP8 分片。

GPU 0:TP 分片本地张量切片
all-reduce(amax, max)在 TP 组内
共享 scale一致的转换规则
FP8 GEMM本地 Tensor Core 矩阵乘
all-reduce(output, sum, BF16)行并行结果合并

📡 TP 中的 FP8 通信

Megatron 对 all-reduce、all-gather 这类 collective 通信默认仍采用 BF16,因为这是最稳妥、最通用的路径。但 Transformer Engine 可以在张量并行线性层中用 FP8 对前向激活做 all-gather,把通信量大致减半。这一优化被封装在 TEColumnParallelLinearTERowParallelLinear 内部,模型代码通常不需要额外分支。

🧱 PP 与 CP 边界

跨 pipeline stage 发送激活时,通常会先把 FP8 张量上转换回 BF16,因为不同 stage 可能维护不同的 FP8 缩放元数据和 recipe 状态。Context parallelism 也能与 FP8 共存,但如果未开启 FP8 attention,或当前后端没有为该序列切分提供 FP8 路径,注意力 kernel 会回退到 BF16。

# Pseudo-flow for TP-aware FP8 casting
local_amax = tensor_amax(input_shard)
dist.all_reduce(local_amax, op=dist.ReduceOp.MAX, group=tp_group)
shared_scale = compute_scale(local_amax)
fp8_input = cast_to_fp8(input_shard, shared_scale)
fp8_out = fp8_gemm(fp8_input, fp8_weight)

# Optional TP forward communication in FP8
gathered = tp_all_gather_fp8(fp8_out)

# PP boundary: convert back before send
pp_tensor = to_bf16(gathered)
send_to_next_pipeline_stage(pp_tensor)
import torch
import torch.distributed as dist
from megatron.core import parallel_state

# ── Option A: Megatron parallel_state (preferred in Megatron codebase) ──
tp_group = parallel_state.get_tensor_model_parallel_group()

# ── Option B: manual group (standalone / debugging) ──
# tp_group = torch.distributed.new_group(ranks=[0, 1, 2, 3])

def sync_amax(input_shard, tp_group, fp8_max):
    # NCCL requires GPU tensors — .cuda().float() is mandatory
    local_amax = torch.amax(input_shard.abs()).float().cuda()
    global_amax = local_amax.clone()

    # Non-blocking all_reduce: overlaps with next host-side computation
    handle = dist.all_reduce(
        global_amax,
        op=dist.ReduceOp.MAX,
        group=tp_group,
        async_op=True,
    )
    # ... interleave CPU/host work here if needed ...
    handle.wait()  # synchronize before using global_amax

    scale = fp8_max / global_amax.clamp(min=1e-12)  # guard against zero amax
    return global_amax, scale

# ── Pack multiple amax scalars → one all_reduce (3× fewer collectives) ──
# Instead of three separate all_reduce calls:
#   sync_amax(act, ...)   # all_reduce #1
#   sync_amax(wt, ...)    # all_reduce #2
#   sync_amax(grad, ...)  # all_reduce #3
# Pack them into a single tensor and do one collective:
all_amax = torch.stack([
    torch.amax(act.abs()).float().cuda(),
    torch.amax(wt.abs()).float().cuda(),
    torch.amax(grad.abs()).float().cuda(),
])
dist.all_reduce(all_amax, op=dist.ReduceOp.MAX, group=tp_group)
amax_act, amax_wt, amax_grad = all_amax[0], all_amax[1], all_amax[2]

global_amax, shared_scale = sync_amax(input_shard, tp_group, fp8_max=448.0)
fp8_input = cast_to_fp8(input_shard, shared_scale)
# Per-rank amax consistency check inside the TP group
rank_amax = global_amax.detach().clone()
world_size = dist.get_world_size(group=tp_group)
gathered_amax = [torch.zeros_like(rank_amax) for _ in range(world_size)]
dist.all_gather(gathered_amax, rank_amax, group=tp_group)
amax_span = max(x.item() for x in gathered_amax) - min(x.item() for x in gathered_amax)
assert amax_span == 0.0, f"TP amax mismatch: {amax_span}"
💡 通信节省会快速放大: 在 TP all-gather 中把激活载荷减半,不仅节省带宽,也提升通信与计算的重叠空间。当 collective 变短后,Megatron 更容易把它们隐藏在计算之后,因此端到端收益往往比简单的 2× 载荷缩减更明显。

性能说明:每个张量(激活、权重、梯度)都有自己的 amax 标量,朴素实现意味着每层每步需要三次 all-reduce。在 TP=8、走 InfiniBand 的场景下,即使是微小的 collective 也有延迟代价。解决办法是把同一层所有 amax 标量打包成一个一维 GPU float32 张量,发起一次 all_reduce。Transformer Engine 通过 fp8_meta 张量打包自动完成这一操作——理解这个机制有助于调试自定义包装层导致 amax 意外发散的问题。另外,all_reduce 的 async_op=True 选项允许 NCCL 在 CPU 继续构建下一个算子的计算图时异步发起 collective,将通信延迟隐藏在计算之后。TE 在内部同时利用了这两种技巧。

§12 FP8 训练实践指南

在启动训练时,Megatron-LM 通过少量 CLI 参数把绝大多数 FP8 行为映射到 Transformer Engine 的 recipe 与 kernel 选择上。对于生产级训练来说,关键不是“把 FP8 打开”这么简单,而是根据模型规模、硬件代际和稳定性预算选择合适的 recipe。

参数 取值 默认值 含义
--fp8-format {hybrid, e4m3} hybrid 选择 FP8 数值格式。hybrid 表示前向张量用 E4M3、对反向更敏感的路径用 E5M2;e4m3 则尽可能统一使用 E4M3。
--fp8-amax-history-len 正整数 1024 延迟缩放使用的 amax 历史窗口长度,用于平滑缩放更新。
--fp8-amax-compute-algo {max, most_recent} max 延迟缩放如何从历史中生成当前 amax:取窗口最大值,或只使用最近一次样本。
--fp8-margin 整数 0 通过 2^margin 调整有效缩放因子,为溢出预留安全余量。
--fp8-interval 正整数 1 两次 FP8 recipe 元数据更新之间间隔多少训练步。
--fp8-wgrad 布尔开关 默认关闭 当后端与 recipe 支持时,使权重梯度 GEMM 也走 FP8 路径。
--fp8-recipe {delayed, tensorwise, blockwise} delayed 选择缩放 recipe 家族:历史延迟缩放、按张量即时缩放,或块级 MX 风格缩放。

≤7B 模型

对小模型而言,BF16 往往已经足够省显存且数值宽松。FP8 当然也能工作,但吞吐收益通常有限,因为优化器、数据加载和通信等开销在单步时间中的占比更高。

13B–70B 模型

这是 H100 上 FP8 的甜点区。建议从 --fp8-format hybrid--fp8-recipe delayed--fp8-amax-compute-algo max 开始,并保持 margin=0。通常既能显著减轻显存压力,也能提升 Tensor Core 吞吐,同时不会明显破坏训练稳定性。

70B+ 模型

在超大模型规模下,FP8 往往不只是加速选项,更是容量解锁器。对 Blackwell 而言,块级 MXFP8 尤其有吸引力,因为巨大的 hidden state 和超长序列更需要细粒度缩放。FP8 节省下来的显存余量可以重新投入到更长序列、更大 batch 或更灵活的并行策略上。

较为保守的 bring-up 路线通常也是最容易得到稳定 FP8 训练的方案。先用 BF16 预热约 100–500 步,等优化器统计量和激活尺度稳定后再切到 FP8。梯度裁剪建议设在 1.0–1.5 左右,并持续观察是否出现突发的 loss spike;如果不稳定持续存在,可以增大 amax_history_len,或者从 delayed scaling 切换到 Float8CurrentScaling。使用梯度累积时一定要正确传递 is_first_microbatch,否则每步都会悄悄重复量化权重,白白损失性能。

  1. 先跑一段短 BF16 baseline,记录 loss、tokens/s 和梯度范数,再切换精度。
  2. 优先在单机开启 FP8,从 --fp8-format hybrid--fp8-recipe delayed 起步。
  3. 确认 TE 层已经正确替换、is_first_microbatch 传递无误,且关键 attention 路径没有静默回退。
  4. 切到 FP8 后至少连续观察 500–1000 步的逐层 amax、scale、loss 和梯度范数。
  5. 只有在单机稳定后,再扩展到 TP/PP/DP 规模,并重新检查通信与计算重叠。
  6. 在调整 recipe 家族、margin 或 history length 之前先打一个干净检查点,确保回滚成本足够低。
# Launch FP8 training
python -m torch.distributed.launch \
    --nproc_per_node=8 \
    pretrain_gpt.py \
    --tensor-model-parallel-size 4 \
    --pipeline-model-parallel-size 2 \
    --fp8-format hybrid \
    --fp8-amax-history-len 1024 \
    --fp8-amax-compute-algo max \
    --fp8-margin 0 \
    --fp8-wgrad \
    --num-layers 80 \
    --hidden-size 8192 \
    --num-attention-heads 64 \
    --seq-length 4096 \
    --micro-batch-size 2 \
    --global-batch-size 512 \
    --lr 1.5e-4 \
    --clip-grad 1.5 \
    --bf16
# Lightweight FP8 health monitor
def log_fp8_health(model, step):
    rows = []
    for name, mod in model.named_modules():
        if hasattr(mod, 'fp8_meta'):
            fwd = mod.fp8_meta['scaling_fwd']
            rows.append((name, fwd.amax_history[-1].item(), fwd.scale.item()))
    rows = sorted(rows, key=lambda x: x[1], reverse=True)[:5]
    print(f"[step {step}] top FP8 amax layers")
    for name, amax, scale in rows:
        print(f"  {name:48s} amax={amax:9.3f} scale={scale:9.3e}")
⚠️ 常见 FP8 坑点: 大多数问题并不是神秘的数值灾难,而是工程配置不匹配:忘记先做 BF16 预热、把 margin 设得过于激进、在尖峰梯度场景下使用过短的历史窗口、没有启用 microbatch 权重缓存,或者误以为所有 attention 后端都在跑 FP8——实际上有些 kernel 早已回退到 BF16。排查 FP8 问题时,先检查这些配置,再去怀疑数据格式本身。

在启动 FP8 训练之前,建议对一系列硬件与软件前置条件进行核查。漏掉其中任何一项,通常会导致静默回退到 BF16 kernel、运行时报错,或者难以事后诊断的数値不稳定训练。

前置条件 验证方法 缺失时的后果
GPU 架构 ≥ sm_89(Ada Lovelace)或 sm_90(Hopper) python -c "import torch; print(torch.cuda.get_device_capability())" ——应返回 (8,9) 或 (9,0) TE 会静默回退到 BF16,默认不抛出异常。
CUDA ≥ 12.0 nvcc --versiontorch.version.cuda CUDA 12 前缺少 cuBLAS FP8 GEMM API(cublasLtMatmul FP8 epilog),TE 会在导入或运行时抛出错误。
Transformer Engine ≥ 1.7 python -c "import transformer_engine; print(transformer_engine.__version__)" 旧版本缺少 Float8CurrentScalingMXFP8BlockScalingis_first_microbatch 缓存逻辑也可能缺失或有缺陷。
确认 cuBLAS FP8 支持 python -c "import transformer_engine.pytorch as te; te.fp8_available()" 应返回 True 若返回 False,说明 cuBLAS 未启用 FP8 支持(CUDA 或驱动版本不匹配),所有 TE 层静默使用 BF16。
张量并行 amax 同步需要 NVLink / NVSwitch nvidia-smi topo -m 应显示 TP 对等之间有 NVLink 连接 没有 NVLink,每个分片 amax 的 all-reduce 将用 PCIe,在每个 FP8 缩放步骤引入延迟尖峰,并可能造成性能激降。
已启用梯度裁剪 启动脚本中 --clip-grad 设为 1.0–1.5,通过 args.clip_grad 确认 FP8 E5M2 梯度可表示很大的指数;缺少裁剪时,单个异常梯度可能引发跨层的波及性溢出。

§13 FP8 × 激活检查点

激活检查点(梯度检查点)在前向传播时丢弃中间激活,在反向传播时重新计算它们。引入 FP8 后,这一交互变得更加微妙:当 Transformer Engine 层被 torch.utils.checkpoint 包裹时,激活会在反向传播中以 FP8 格式重新计算——但原始前向传播的 FP8 缩放元数据(amax 值、缩放因子、缩放倒数)必须跨检查点边界保留,否则重新计算的激活会使用过期或默认的缩放因子,从而静默地降低精度。

💾 显存节省

FP8 激活每个元素占 1 字节,BF16 则为 2 字节——即使不使用激活检查点,激活显存也大约减少 50%。与激活检查点结合后,节省效果会叠加:需要保存的激活本就更少,而每个保存的激活只占 1 字节。过去需要开激活检查点才能塞进显存的大模型,现在在相同 batch size 下可能不需要它了。

🔧 Megatron 参数

  • --recompute-activations:对 attention/dropout 启用选择性重算
  • --recompute-granularity full|selective:重算整个 transformer 层还是仅重算 attention 子层
  • --recompute-num-layers N:需要重算的层数(与 pipeline stage 配合)

Transformer Engine 通过在检查点段上注册自定义 pack_hookunpack_hook 回调来处理检查点边界问题。当 torch.utils.checkpoint 保存和恢复段状态时,TE 的 hook 确保 FP8 元数据张量(存储在 fp8_meta 中的 scale、amax 历史、scale 倒数)与激活一起被打包,并在重算前向之前解包。这保证了重算的 FP8 GEMM 使用与原始前向完全相同的缩放因子,从而维持数值一致性。

前向传播计算并保存缩放因子
检查点边界丢弃激活,保留缩放因子
反向重算用已保存的缩放因子重算激活
梯度 FP8 GEMM权重梯度与输入梯度
# Megatron activation checkpointing with FP8
# Launch flags:
#   --recompute-activations
#   --recompute-granularity selective   # recompute attention only
#   --recompute-num-layers 24           # recompute bottom 24 layers

# TE handles FP8 metadata preservation via pack/unpack hooks.
import transformer_engine.pytorch as te

# fp8_autocast manages metadata lifecycle across checkpoint boundaries
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    # torch.utils.checkpoint wraps TE layers — TE hooks preserve fp8_meta
    output = torch.utils.checkpoint.checkpoint(
        te_transformer_layer,
        hidden_states,
        use_reentrant=False,  # recommended for TE compatibility
    )
# Small pack/unpack hook example with a TE module
import torch
import transformer_engine.pytorch as te

layer = te.Linear(hidden_size, hidden_size, bias=False).cuda()

def pack_hook(tensor):
    return tensor, layer.get_extra_state()

def unpack_hook(packed):
    tensor, fp8_state = packed
    layer.set_extra_state(fp8_state)
    return tensor

with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)
# Verification: checkpointed path should stay numerically close
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    ref_out = layer(x.clone())
    ckpt_out = torch.utils.checkpoint.checkpoint(layer, x.clone(), use_reentrant=False)

torch.testing.assert_close(ref_out, ckpt_out, rtol=5e-2, atol=5e-2)
assert torch.isfinite(layer.fp8_meta['scaling_fwd'].scale).all()

理解 get_extra_state / set_extra_stateget_extra_state() 返回一个 Python 字典,包含该 TE 模块的 fp8_meta 张量——具体来说是前向缩放张量(scaling_fwd.scalescaling_fwd.scale_invscaling_fwd.amax_history)以及对应的反向版本。这些均为 GPU float32 张量,与层参数在同一 CUDA 设备上。set_extra_state() 接受相同的字典结构,并在重算前向开始之前恢复模块的 fp8_meta 状态。saved_tensors_hooks 合约很直接:pack_hook 接收张量,可返回任意可 pickle 序列化的 Python 对象(此处为张量与状态字典的元组);unpack_hook 接收 pack_hook 的返回内容,必须返回张量。设备注意:extra-state 张量在 CUDA 上;如果检查点策略会将它们卸载到 CPU 以节省显存,请在 unpack_hook 中调用 set_extra_state() 之前显式执行 .to(device),否则恢复后的 scale 将在错误设备上,导致后续 FP8 GEMM 静默失败或抛出设备不匹配报错。

  1. 始终对 TE 层使用 use_reentrant=False——可重入模式不会触发 pack/unpack hook,从而静默绕过 FP8 元数据保存。
  2. 对检查点段内的每个 TE 模块分别注册 pack/unpack hook,而不是只为最外层包装器注册;嵌套的 TE 层各自维护独立的 fp8_meta
  3. 开发阶段用激活检查点后,用 assert_close(如上方代码所示)验证数值一致性;对 FP8 精度允差使用 rtol=5e-2, atol=5e-2 作为安全阈値。
  4. 警惕静默精度退化:在前 1000 步并行运行开关检查点的两个实验,对比 loss 曲线;持续大于 0.5% 的差距就是警示信号。
  5. 如果你对 TE 层有自定义包装器,请在包装器中重写 get_extra_state / set_extra_state,迭代所有内嵌 TE 子模块,确保状态能正确传递。
⚠️ 静默精度陷阱: 如果在使用激活检查点时 TE 的 pack/unpack hook 没有正确注册——例如在较旧的 TE 版本上使用 use_reentrant=True,或手动对内部调用 TE 层的非 TE 包装器做检查点——重算前向会使用新初始化的或全局共享的缩放因子,而非原始前向的每步缩放因子。模型或许仍能收敛,但速度明显变慢且最终效果更差:这是一种没有明显报错信号的静默精度退化。

§14 FP8 调试与监控工具箱

FP8 的问题很少会以清晰的报错形式出现,更常见的是 loss 突增、精度缓慢下降、莫名的速度退化,或者在训练数千步后才出现的 NaN 梯度。诊断这些问题需要了解哪些手段能暴露 TE 缩放机制的内部状态。

🔍 环境变量

  • NVTE_DEBUG=1:开启 TE 详细 FP8 日志,记录 kernel 调度、缩放因子更新和格式回退
  • NVTE_FP8_DPA_BWD=0/1:控制反向注意力 kernel 是否使用 FP8
  • CUDA_LAUNCH_BLOCKING=1:串行化 GPU kernel 启动,便于同步定位报错

📊 溢出 / 下溢检测

当 amax 超过 FP8 最大可表示值(E4M3:448,E5M2:57344)时,缩放因子应当饱和。如果没有饱和,Inf 或 NaN 就会在网络中传播。监控梯度范数和逐层激活幅度;范数突然飙升通常是 FP8 溢出事件的前兆。

症状 可能原因 解决方法
切换 FP8 时 loss 突增 缩放因子尚未预热 先用 BF16 预热 100–500 步
精度缓慢下滑 margin 设置过于激进 将 margin 从 4 调低到 0
梯度中出现 NaN E4M3 激活溢出 将梯度格式切换为 E5M2,检查 amax
训练比 BF16 还慢 Kernel 回退到 BF16 检查 NVTE 日志,确认 SM≥89
Loss 反复震荡 历史窗口过短 将 amax_history_len 增大到 1024+
# Inspect FP8 metadata on each TE module
for name, mod in model.named_modules():
    if hasattr(mod, 'fp8_meta'):
        fwd = mod.fp8_meta['scaling_fwd']
        print(f"{name}: amax={fwd.amax_history[-1].item():.2f}, scale={fwd.scale.item():.4f}")
        # flag if scale looks degenerate
        if fwd.scale.item() > 1e6 or fwd.scale.item() < 1e-6:
            print(f"  warning: degenerate scale detected: {name}")
# Sample NVTE_DEBUG output
[NVTE][DEBUG] layer=decoder.layers.12.mlp.linear_fc1 fp8_fwd=True format=E4M3
[NVTE][DEBUG] layer=decoder.layers.12.mlp.linear_fc1 amax=183.25 scale=2.4453e+00 scale_inv=4.0894e-01
[NVTE][DEBUG] layer=decoder.layers.12.self_attention fallback=bf16 reason=unsupported_head_dim
[NVTE][DEBUG] layer=decoder.layers.12.mlp.linear_fc2 cache_hit=True is_first_microbatch=False
# ── Multi-node NVTE_DEBUG setup ──
# In your torchrun launch wrapper (e.g. launch.sh or SLURM sbatch script):
export NVTE_DEBUG=1
export NVTE_DEBUG_LEVEL=1  # 0=errors only, 1=info, 2=verbose kernel dispatch

# Logs go to stdout, which torchrun redirects per-rank:
#   torchrun --redirects 1:train_log_rank0.txt,2:err_rank0.txt ...
#   (or Megatron captures to log_dir/rank_N/stdout.log)

# ── Filter log for fallback / amax / scale events ──
grep "[NVTE]" train_log.txt | grep -E "(fallback|amax|scale)"

# Example hit:
#   [NVTE][DEBUG] layer=...self_attention fallback=bf16 reason=unsupported_head_dim
#   ^ means: FP8 attention kernel does NOT support this head_dim — silently uses BF16
#   Fix: set head_dim to a supported value (64, 128) or disable fp8_mha
# Check GPU compute capability (SM≥89 required for FP8 Tensor Cores)
nvidia-smi --query-gpu=compute_cap --format=csv,noheader
# Expected output:
#   8.9   ← L40S / RTX 4090 (minimum for FP8 Tensor Cores)
#   9.0   ← H100 SXM / H100 PCIe
#   10.0  ← B100 / B200
# If compute_cap < 8.9: FP8 kernels silently fall back to BF16
#   ─ no error, no warning; throughput regresses to BF16 level
#   ─ NVTE_DEBUG=1 will show: fallback=bf16 reason=unsupported_compute_capability
# Nsight Systems commands for FP8 profiling
nsys profile --trace=cuda,nvtx,osrt \
    --sample=none \
    --capture-range=nvtx \
    --capture-range-end=stop \
    -o fp8_profile \
    python pretrain_gpt.py ...

nsys stats --report cuda_gpu_kern_sum,gpu_metric_utilization fp8_profile.nsys-rep

推荐的分析工作流程:(1)开启 NVTE_DEBUG=1,运行 200 个训练步,然后用 grep 过滤日志中的 fallback、amax 和 scale 事件——这个步骤只需几秒,却能拦截大多数问题。(2)使用 nsys 的 --capture-range=nvtx 并配合 NVTX 标记,只采样第 100–110 步——窬取窄窄窗口可避免生成数十 GB 的大文件脑型。(3)采样完成后,执行 nsys stats --report cuda_gpu_kern_sum:查看 kernel 名称列中 gemm_(BF16)与 fp8_gemm_ 条目——如果 BF16 GEMM kernel 与 FP8 kernel 共存,说明存在混合精度回退。(4)如果发现混合 kernel,对照 NVTE_DEBUG 日志定位具体是哪些层及其原因(head_dim 不匹配、计算能力不足或 recipe 标志设置错误)。

💡 最常见的调试误区: 误以为所有层都在跑 FP8,而实际上某些 attention kernel 已经静默回退到 BF16。这种情况发生在 flash-attention 后端对当前序列长度或 head 配置没有 FP8 路径时、recipe 中设置了 fp8_mha=False 时,或者在 SM<89 的 GPU 上运行时。模型仍能正确运行,但这些层的吞吐收益已经消失——分析器会显示 FP8 和 BF16 矩阵乘 kernel 奇怪地混杂在一起。

§15 FP8 训练 vs INT8 / INT4 推理量化

FP8 与 INT8/INT4 常被笼统归为"低比特格式",但它们在模型生命周期中扮演着根本不同的角色。FP8 是训练时(training-time)格式:前向和反向传播都以降低精度运行,需要足够的动态范围来表示在不同层和训练步骤中变化的梯度幅度。INT8 和 INT4 则是推理时(inference-time)格式:仅有前向传播,通过静态量化或校准(calibration)完成量化,无需表示梯度。

维度 FP8 (E4M3/E5M2) INT8 INT4 (GPTQ/AWQ)
阶段 训练 + 推理 推理 推理
表示方式 浮点数 定点数 定点数
动态范围 宽(有指数位) 窄(均匀量化) 非常窄
校准方式 在线(amax 追踪) 训练后校准 (PTQ) 训练后校准 (PTQ)
反向传播 ✅ 支持(E5M2 梯度) ❌ 不支持 ❌ 不支持
显存节省 相对 BF16 节省 2× 相对 FP16 节省 2× 相对 FP16 节省 4×
精度影响 loss 增加 <0.5% 下降 0.5–1% 下降 1–3%
硬件要求 Hopper+ (SM≥89) 大多数 GPU (Turing+) 大多数 GPU
框架支持 Megatron+TE、FSDP TensorRT-LLM、vLLM GPTQ、AWQ、vLLM

🏋️ 训练时量化(FP8)

FP8 训练保持 master 权重为 BF16/FP32;只有 GEMM 输入(激活、权重、梯度)在计算时临时降为 FP8。优化器状态、层归一化和 loss 计算仍保持更高精度。这意味着 FP8 训练与 loss scaling、梯度裁剪等混合精度技术完全兼容。

🚀 推理时量化(INT8/INT4)

INT4 量化(GPTQ、AWQ)将 BF16 检查点压缩为每权重 4 位,使推理时权重显存减少 4 倍。经过 FP8 训练的模型对量化特别友好,因为模型已经学会了对精度降低的鲁棒性。推荐工作流:FP8 训练 → 保存 BF16 检查点 → 应用 GPTQ/AWQ → 通过 TensorRT-LLM 或 vLLM 部署。

# Example export pipeline: FP8-trained checkpoint -> INT8 / GPTQ artifacts
python tools/checkpoint/convert.py \
    --loader megatron \
    --saver huggingface \
    --load-dir /ckpts/fp8_iter_200000 \
    --save-dir /tmp/hf_bf16_export \
    --target-dtype bf16

trtllm-build \
    --checkpoint_dir /tmp/hf_bf16_export \
    --weight_only_precision int8 \
    --output_dir /tmp/trtllm_int8

python quantize_gptq.py \
    --model /tmp/hf_bf16_export \
    --bits 4 \
    --group-size 128 \
    --dataset c4 \
    --output /tmp/gptq_int4
  1. 在同一份验证集上对比导出前后的困惑度或任务精度。
  2. 确认 tokenizer 文件、RoPE 配置和 special token ID 在检查点转换后仍然一致。
  3. 准备一组固定提示词,逐条对比生成结果,而不只看聚合指标。
  4. 完成 INT8/GPTQ 导出后重新测量时延、吞吐与显存占用,因为不同 kernel 的速度/质量权衡不同。
  5. 保留一份 BF16 导出权重,方便之后更换 group size 或校准数据重新量化,而不必回炉训练。
💡 互补而非竞争: FP8 训练与 INT4 推理是互补关系,而非竞争关系。最优的工作流是:用 FP8 预训练 → 用 FP8 微调 → 量化为 INT4 用于部署。FP8 也可以直接在推理阶段使用(TensorRT-LLM 的 FP8 模式),实现 2× 吞吐量提升,精度损失极小,且无需校准数据——直接复用训练得到的 FP8 缩放因子即可。

硬件支持是选择 FP8 与 INT8 的决定性因素。FP8 E4M3/E5M2 需要 Hopper 级硬件(SM≥89,即 H100/H800)方可获得 Tensor Core 原生加速——Ampere GPU 会回退到软件模拟。相比之下,对称量化 INT8 通过 DP4A 指令从 Turing+ 起即获支持(RTX 2080、A100、H100),自 CUDA 10.2 起在 TensorRT/cuDNN 中已经相当成熟。非对称 INT8 额外引入了每张量的零点(zero-point),在同代硬件上同样可用。对于没有 Hopper 资源的团队,INT8 推理量化是务实之选;对于拥有 H100 集群的团队,FP8 训练 + INT8/INT4 推理导出是推荐的完整工作流。

对比维度 FP8(E4M3/E5M2) INT8(对称量化) INT8(非对称量化)
动态范围 宽:E4M3 最大值=448,E5M2 最大值=57344 窄:scale × [-127, 127] 窄:scale × [-128, 127] + 零点
硬件支持 Hopper+(SM≥89),Tensor Core FP8 Turing+(SM≥75),DP4A 指令 Turing+(SM≥75),DP4A 指令
校准需求 在线 amax 追踪(无需离线校准) 需要训练后校准(PTQ) PTQ + 每张量零点调整
训练阶段支持 ✅ 完整支持(前向+反向,E5M2 梯度) ⚠️ 有限(仅 Q-AT,实践中少见) ❌ 训练阶段不使用
推理加速(相比 FP16) 1.5–2×(TensorRT-LLM FP8 模式) 1.3–1.7×(仅权重 INT8) 1.3–1.6×(kernel 略复杂)
框架成熟度 Megatron+TE、PyTorch FSDP(2024+) TensorRT-LLM、vLLM、ONNX Runtime(成熟) TensorRT-LLM、ONNX Runtime(良好支持)

§16 Blackwell 上的 MXFP8:分块缩放 FP8

MXFP8 是 OCP Microscaling(MX)格式,其中每连续 32 个元素共享一个 E8M0 缩放因子——8 位指数、无尾数、无符号位,表示从 2^(-127) 到 2^(127) 的 2 的幂次。块内的每个独立元素仍采用标准的 E4M3 或 E5M2。这是 NVIDIA Blackwell(B100/B200)上的原生硬件格式:Tensor Core 在硬件级别原生支持分块缩放 FP8 GEMM,E8M0 缩放向量在累加时由硬件直接施加,软件侧的反缩放(descaling)开销为零。

🔬 为什么需要分块缩放

当一个张量同时包含平稳区域和局部异常值(outlier)时,按张量(per-tensor)缩放(Hopper 采用的方式)过于粗糙。某一通道的单个异常值可能迫使整个张量采用保守缩放,浪费了所有其他通道的动态范围。分块(每 32 个元素一块)缩放能更好地捕捉局部动态范围:异常值所在块使用小 scale,平稳块使用大 scale,整个张量的量化误差大幅降低。

📡 张量并行与分块缩放的交互

有了分块缩放,每个 TP rank 的分片独立维护自己的块 scale,无需跨 rank 的 scale 同步(张量并行)。这是相对于按张量缩放的显著优势——后者在转换前需要在 TP 组内做 all-reduce 来获取全局 amax。MXFP8 彻底消除了 scale 计算的 TP 通信开销。

按张量(Hopper)每张量 1 个 scale
按块(Blackwell)每 32 元素 1 个 scale
按元素(理论极端)类似 FP16,无共享
from transformer_engine.common.recipe import MXFP8BlockScaling
import transformer_engine.pytorch as te

# MXFP8 recipe — Blackwell only (SM≥100)
recipe = MXFP8BlockScaling()

with te.fp8_autocast(fp8_recipe=recipe):
    output = model(input)
    # Block scales computed on-the-fly by hardware
    # No amax history, no margin, no warm-up steps needed

# E8M0 scale format: 8 exponent bits, no mantissa, no sign
# Represents 2^(-127) to 2^(127) — powers of 2 only
# Applied per-block of 32 elements during GEMM accumulation
# MXFP8 verification on Blackwell
capability = torch.cuda.get_device_capability()
assert capability[0] >= 10, f"MXFP8 requires Blackwell, got SM{capability[0]}{capability[1]}"

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = model(x)

print(type(recipe).__name__, y.dtype)
assert type(recipe).__name__ == 'MXFP8BlockScaling'
GPU 按张量 FP8 CurrentScaling MXFP8 分块缩放 说明
H100 ✅ 支持 ✅ 支持 ❌ 无原生支持 Hopper 时代的按张量 FP8,TP 下仍需同步 amax
B100 ✅ 支持 ✅ 支持 ✅ 原生支持 Blackwell Tensor Core 在硬件中原生支持 E8M0 分块 scale
B200 ✅ 支持 ✅ 支持 ✅ 原生支持 最适合大序列场景下的 MXFP8 训练与推理
属性 DelayedScaling(Hopper) MXFP8BlockScaling(Blackwell)
缩放粒度 按张量 每 32 元素一块
Scale 格式 FP32 E8M0(8 位指数)
Scale 开销 可忽略(每张量 1 个) ~3%(每 32 元素 1 个)
量化误差 较高(对异常值敏感) 较低(局部自适应)
硬件要求 Hopper (SM≥89) Blackwell (SM≥100)
amax 追踪 必须(历史窗口) 不需要(块内局部)
TP 通信 需要 amax all-reduce 无需 scale 同步

Microscaling 的开销很好估算。每个块存放 32 个 FP8 数值(32 字节)外加 1 个 8-bit 的 E8M0 scale(1 字节),因此原始元数据开销就是 1 / 32 = 3.125%。例如一个包含 1,048,576 个 FP8 元素的张量,数值本体占 1,048,576 字节,块 scale 额外占 32,768 字节,总计 1,081,344 字节——相比为每个元素单独保存浮点 scale,这个成本要低得多。

💡 MXFP8 消除了脆弱的调参环节: MXFP8 消除了 FP8 训练中最脆弱的部分——延迟缩放(delayed scaling)方案及其 amax 历史调优。有了硬件在线计算的块内局部 scale,'margin'、'历史窗口'和'预热步数'这些旋钮全都不再需要。这极大地降低了在 Blackwell 上使用 FP8 训练的门槛:一个无需任何参数的 MXFP8BlockScaling() 就能替代 Hopper 上所有繁琐的手动调优。

在 Hopper 上,计算 FP8 GEMM 需要软件循环追踪每张量 amax、施加 margin 并延迟一步更新 scale。在 Blackwell B200/B100(第五代 Tensor Core,SM≥100)上,E8M0 缩放向量直接由 Tensor Core 在 GEMM 累加期间施加,软件竧无需任何 amax 管理。NVIDIA 公布 B200 的 FP8 浮点吸量大约达到 9 PFLOPS,是 H100(3.9 PFLOPS)的约 2.3×;HBM3e 内存子系统提供约 8 TB/s 带宽,是 H100(3.35 TB/s)的约 2.4×,对内存带宽限制型的 attention 与 all-gather 操作尤为关键。

§17 FP8 注意力核:Flash Attention 与 TE DPA

并非所有注意力(attention)后端都支持 FP8 输入。Transformer Engine 的 DotProductAttention(DPA)是 Megatron+TE 中主要支持 FP8 的注意力实现,它会调度到 cuDNN 的融合注意力核(fused attention kernel),能在 FP8 精度下执行 QKT 及 softmax·V 计算。标准 PyTorch 的 F.scaled_dot_product_attention 以及 Tri Dao 的原版 Flash Attention 均不支持 FP8 输入——它们只接受 BF16 或 FP16。

TE DPA 的内部机制:TEDotProductAttention 封装了 cuDNN 的融合多头注意力(fused multi-head attention)。开启 FP8 注意力后,在前向传播中,Q/K/V 会在 QKT 矩阵乘之前被转换为 FP8。Softmax 始终以 FP32 计算(从不量化)。输出投影(V·softmax)可以使用 FP8。注意力的反向传播由环境变量 NVTE_FP8_DPA_BWD 控制——默认以 BF16 运行,因为注意力梯度对量化噪声尤为敏感。

Q / K / V 投影E4M3 转换并更新 amax
QKT 融合注意力cuDNN FP8 核
Softmax始终为 FP32
softmax·V 与输出FP8 前向 / BF16 回退

实际调度规则是:只有当 Megatron 走 TE 的 DPA 路径、运行在 Hopper 级 GPU 上,并且 head 维度、mask 类型、dropout 模式与序列布局都满足 cuDNN 约束时,FP8 注意力才会被启用。只要其中任一条件不满足——例如 head 维度不受支持、因果或掩码布局不兼容,或者代码路径走到了 PyTorch SDPA 而不是 TE DPA——内核就会静默回退到 BF16。应将 FP8 注意力视为“受约束的快速路径”,而不是必然命中的默认路径。

🚩 NVTE_FP8_DPA_BWD 标志

NVTE_FP8_DPA_BWD=0(默认):注意力反向以 BF16 运行——更安全,在敏感任务上保留精度。NVTE_FP8_DPA_BWD=1:注意力反向以 FP8 运行——注意力核速度提升约 10–15%,但在敏感任务上可能导致 loss 增加 0.1–0.3%。这一开关对 FP8 训练质量的单开关影响最大。

⚠️ 静默回退行为

当 FP8 注意力无法被调度时(不支持的 head 维度、不支持的掩码类型、因果掩码限制等),TE 会静默地回退到 BF16 注意力。唯一能检测到这一点的方式是开启 NVTE_DEBUG=1 日志。这种静默回退是 "FP8 比预期慢" 这类问题的常见根源——FP8 注意力的 scale 管理开销仍然存在,但 FP8 核的吞吐收益已经消失。

后端 FP8 前向 FP8 反向 硬件要求 备注
TE DPA(cuDNN) 可选(NVTE_FP8_DPA_BWD) Hopper+ Megatron 默认后端
Flash Attention v2(Tri Dao) Ampere+ 仅支持 BF16/FP16
PyTorch SDPA 任意 内部调度到 Flash/cuDNN
cuDNN 原生 Hopper+ 底层接口,被 TE 使用
import os
os.environ["NVTE_DEBUG"] = "1"
# In TE debug logs, look for:
#   [FP8] DPA forward: fp8=True, backend=cudnn
#   [FP8] DPA backward: fp8=False (NVTE_FP8_DPA_BWD=0)

# To enable FP8 backward for attention:
os.environ["NVTE_FP8_DPA_BWD"] = "1"

# Megatron attention spec uses TE's DPA
from megatron.core.extensions.transformer_engine import TEDotProductAttention
attn_spec = SelfAttentionSubmodules(
    core_attention=TEDotProductAttention,  # FP8-aware
    # vs nn.MultiheadAttention → no FP8
)
注意力反向传播是 FP8 训练中精度最敏感的操作。dV = softmaxT · dO 和 dQ = dO · KT 这两个矩阵乘涉及将极小的 softmax 概率值与梯度相乘,产生的数值跨越多个数量级。这正是 NVTE_FP8_DPA_BWD=0(BF16 反向)作为安全默认值的原因——切换到 FP8 反向可节省约 10–15% 的注意力核时间,但在敏感任务上可能导致 loss 可测量地上升 0.1–0.3%。

要在 Transformer Engine 中启用 FP8 注意力,需在 FP8 recipe 中设置 fp8_mha=True。启用后,FP8 注意力会在 QKT 矩阵乘之前将 Q 和 K 量化为 FP8,并将 softmax·V 输出量化为 FP8,但 softmax 本身始终以 FP32 运行。FP8 注意力仅在序列长度达 8K 以上时才能带来显著提升。

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Enable FP8 MHA (multi-head attention) in the TE recipe
recipe = DelayedScaling(
    fp8_format=Format.HYBRID,       # E4M3 forward, E5M2 backward
    amax_history_len=16,
    amax_compute_algo="max",
    fp8_mha=True,                   # <-- enable FP8 attention
)

# Precision breakdown with fp8_mha=True:
#   Q cast: BF16 -> E4M3 (before QK^T matmul)
#   K cast: BF16 -> E4M3 (before QK^T matmul)
#   QK^T:   E4M3 x E4M3, FP32 accumulator
#   Softmax: ALWAYS FP32 (never quantized)
#   V cast: BF16 -> E4M3
#   softmax*V: E4M3 x E4M3, FP32 accumulator
#   Output projection: FP8 forward, BF16 by default for backward

# For very short sequences (e.g., seq_len <= 2048), disable FP8 MHA:
#   fp8_mha=False  (default) -- attention memory is not the bottleneck

# Verify FP8 attention is active (requires NVTE_DEBUG=1):
#   export NVTE_DEBUG=1
#   Look for: [FP8] DPA forward: fp8=True, backend=cudnn
#   If fp8=False appears, attention silently fell back to BF16

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    output = model(input_ids)  # QK^T in FP8, softmax in FP32, V*softmax in FP8

§18 FP8 的损失缩放与优化器交互

与 FP16 混合精度训练必须使用动态损失缩放(dynamic loss scaling)来防止梯度下溢不同,使用 E5M2 格式的 FP8 梯度具有足够的动态范围,通常不需要损失缩放。E5M2 的最大值为 57344——比 E4M3 多 5 个指数位——因此梯度下溢(gradient underflow)极为少见。Megatron 仍然支持 --fp16-lgs(损失梯度缩放器)以应对边缘情况。E5M2 用于梯度正是因为其更宽的动态范围能应对不同层梯度幅度跨越多个数量级的情况。

优化器状态精度(optimizer state precision):无论是否启用 FP8 训练,Adam/AdamW 始终以 FP32 维护一阶矩(m)和二阶矩(v)。主权重(master weights)也是 FP32。FP8 只影响 GEMM 计算,不影响优化器状态的存储。这正是 FP8 节省显存 "仅有 2×" 的原因——体现在激活和通信带宽上,而非总 GPU 显存的 8× 节省。

✅ 策略 A:默认不启用损失缩放

适用于标准 FP8 hybrid 训练:梯度使用 E5M2,并立即在 BF16 或 FP32 中累积。这样实现最简单,不需要额外的溢出 bookkeeping,也最符合 Hopper 上常见的 Megatron+TE 配方。

🛟 策略 B:为边缘场景保留缩放器

只有在梯度异常尖峰的场景下才重新开启损失缩放——例如冷启动恢复检查点、课程切换导致分布突变,或某些自定义算子延迟上转精度时。此时缩放器更像安全护栏,而不是 FP8 的基础依赖。

一个显存计算示例:假设模型规模为 13B,即 130 亿参数。无论是否启用 FP8,优化器状态大约都需要 130 亿 ×(4B 主权重 + 4B 一阶矩 + 4B 二阶矩)≈ 156 GB。计算副本从 BF16 的 26 GB 降到 FP8 的 13 GB,只节省 13 GB。若当前 batch 和序列形状下激活显存为 48 GB(BF16),则 FP8 可降至约 24 GB,再节省 24 GB。最终总训练显存大约从 230 GB 下降到 193 GB——这很有价值,但显然不是 8×,因为优化器状态依然是主导项。

组件 BF16 训练 FP8 训练 节省
模型参数(主权重) FP32(4 字节) FP32(4 字节)
模型参数(计算用) BF16(2 字节) FP8(1 字节)
Adam 一阶矩(m) FP32(4 字节) FP32(4 字节)
Adam 二阶矩(v) FP32(4 字节) FP32(4 字节)
激活值 BF16(2 字节) FP8(1 字节)
梯度(通信) BF16(2 字节) E5M2(1 字节)
每参数总计(优化器) 16 字节 16 字节
主权重FP32
转换为 FP8前向 GEMM
E5M2 梯度反向 GEMM
上转为 FP32累积前
Adam 更新FP32 优化器步
# Precision flow in FP8 training with Adam
# Master weights: FP32 (never quantized)
# Forward: weights cast to FP8 E4M3, activations cast to FP8 E4M3
# GEMM: FP8 inputs, FP32 accumulator
# Backward: gradient GEMMs produce E5M2 gradients
# Gradient accumulation: upcast to BF16/FP32 immediately
# Optimizer step: Adam updates master weights in FP32

# Megatron flags:
#   --fp8-format hybrid      # E4M3 forward, E5M2 backward
#   --no-fp16-lgs            # no loss scaling needed for FP8
#   --accumulate-allreduce-grads-in-fp32  # safe gradient accumulation
一个常见误解是 FP8 训练相比 FP32 节省 8× 显存。实际上,优化器状态(Adam 的 m、v 以及主权重)仍然以 FP32 存储——这些每参数占 12 字节,是显存的主导项。FP8 节省的是激活值显存(相比 BF16 节省 2×)和通信带宽(相比 BF16 节省 2×),但优化器状态不受影响。对于 70B 模型而言,FP8 节省约 40% 的激活显存,但仅节省约 15% 的总 GPU 显存。

尽管使用 E5M2 梯度的 FP8 训练很少需要损失缩放,Megatron 仍然集成了 GradScaler 兼容性。当 FP8 与 BF16 累加组合时,E5M2 梯度可即时上转 BF16 后进行 all-reduce。E5M2 指数位比 E4M3 多 1 位(最大可表示值 57344),溢出频率远低于 FP16,损失缩放器即使启用也可很激进甚至完全禁用。

§19 逐层 FP8 敏感度与选择性精度

不同的 transformer 层对 FP8 量化的敏感度差异很大(逐层敏感度分析)。前几层和最后几层(嵌入投影、最终的 LM head)通常最为敏感,因为它们处理完整的词汇表分布。中间层则更为鲁棒。在一个 transformer block 内部,注意力 QKV 投影的激活值幅度方差往往大于 MLP 层,这使得 MLP 层通常对 FP8 更友好。

Megatron+TE 通过 ModuleSpec 系统支持逐模块控制(选择性精度)。可以在同一模型中混合使用 TE FP8 模块和标准 BF16 模块。一种实用的选择性策略:嵌入层/LM head 使用 BF16,前 2 层和后 2 层 transformer 使用 BF16,中间层使用 FP8。逐层 amax 分析——检查每个模块的 fp8_meta——可以找出 amax 方差最大(不稳定性指标)的层。步间 amax 变化超过 10× 的层是 BF16 回退的候选对象。

收集逐层 amax 统计
按不稳定度排序模块cv / 异常率
将边缘层保留为 BF16
将稳定的中间层调度到 FP8

一个实用的自动化流程是:先用 BF16 或混合模式跑一个短窗口,记录逐层 amax 变异系数,再据此生成 ModuleSpec 的白名单或黑名单。每隔几千步或恢复训练后重新计算一次排序;如果原本稳定的层跨过阈值,就自动切回 BF16。这样可以避免永久硬编码层号,并让精度调度始终由真实的不稳定性指标驱动。

✅ FP8 友好层

  • 中间 transformer 块(第 3 层到倒数第 2 层)
  • 所有层的 MLP 投影(fc1/fc2)
  • 注意力输出投影(value projection)
  • 中间层的注意力 QKV

⚠️ 建议使用 BF16 的层

  • 嵌入层 / LM head(全词汇表范围)
  • 前 2 层和后 2 层 transformer
  • 浅层的注意力 QKV(amax 方差大)
  • 任何步间 amax 变化超过 10× 的层
# Per-layer FP8 sensitivity analysis via amax variance
amax_stats = {}
for name, mod in model.named_modules():
    if hasattr(mod, 'fp8_meta'):
        fwd = mod.fp8_meta['scaling_fwd']
        hist = fwd.amax_history  # [history_len] tensor
        amax_stats[name] = {
            'mean': hist.mean().item(),
            'std':  hist.std().item(),
            'max':  hist.max().item(),
            'cv':   hist.std().item() / (hist.mean().item() + 1e-7),
        }

# Sort by coefficient of variation (higher = more unstable)
for name, s in sorted(amax_stats.items(), key=lambda x: -x[1]['cv']):
    print(f"{name}: cv={s['cv']:.2f}, mean={s['mean']:.1f}, max={s['max']:.1f}")
# Layers with cv > 2.0 are candidates for BF16 fallback
层类型 典型 amax 范围 amax 变异系数 FP8 建议
嵌入投影 10–500 >3.0 建议 BF16
注意力 QKV(第 0–1 层) 5–200 2.0–3.0 BF16 更安全
注意力 QKV(中间层) 2–50 0.5–1.0 FP8 可用
MLP fc1/fc2(所有层) 1–30 0.3–0.8 FP8 极佳
注意力输出投影 1–20 0.4–0.9 FP8 可用
LM head 20–1000 >4.0 强烈建议 BF16
LM head(最终词汇表投影层)是 FP8 中最危险的层。其输出跨越完整的词汇表 logit 范围——从强负值到强正值——产生的极端动态范围使得 E4M3 的最大值 448 在不进行激进缩放的情况下无法表达。如果你发现训练初期(logit 最不稳定时)出现特定的 loss 突刺,请尝试将 LM head 保留在 BF16 精度,而其他所有层使用 FP8。

以下代码片段展示了如何将逐层缩放因子监控直接嵌入训练循环。以固定间隔(例如每 50 步)遍历所有模块,从 fp8_meta['scaling_fwd'] 中读取当前 scale 张量,与所有层的运行中位数进行比较,并将缩放因子超过中位数 10 倍的层标记为 BF16 回退的候选异常层。

import torch
import statistics

def monitor_fp8_scales(model, step, log_interval=50):
    """Iterate over TE modules, extract per-layer FP8 scales,
    detect outlier layers where scale > 10x median."""
    if step % log_interval != 0:
        return

    scales = {}
    for name, mod in model.named_modules():
        if not hasattr(mod, 'fp8_meta'):
            continue
        fwd = mod.fp8_meta['scaling_fwd']
        # scale shape: [num_fp8_tensors] — typically 3 (input, weight, output)
        scale_vals = fwd.scale.float().tolist()
        amax_vals  = fwd.amax_history[0].float().tolist()
        scales[name] = {
            'scale_input':  scale_vals[0],
            'scale_weight': scale_vals[1],
            'amax_input':   amax_vals[0],
            'amax_weight':  amax_vals[1],
        }

    if not scales:
        return

    # Compute per-tensor-slot median across all layers
    all_scale_input  = [v['scale_input']  for v in scales.values()]
    all_scale_weight = [v['scale_weight'] for v in scales.values()]
    median_input  = statistics.median(all_scale_input)  + 1e-9
    median_weight = statistics.median(all_scale_weight) + 1e-9

    outliers = []
    for name, v in scales.items():
        ratio_i = v['scale_input']  / median_input
        ratio_w = v['scale_weight'] / median_weight
        if ratio_i > 10.0 or ratio_w > 10.0:
            outliers.append((name, ratio_i, ratio_w))
            print(f"[FP8-WARN] step={step} OUTLIER layer: {name} "
                  f"scale_input×median={ratio_i:.1f}, scale_weight×median={ratio_w:.1f}")
            print(f"  → candidate for BF16 fallback (amax_input={v['amax_input']:.3f})")

    return outliers  # caller may use this list to dynamically rebuild ModuleSpec

# --- Usage inside training loop ---
for step, batch in enumerate(dataloader):
    loss = model(batch).loss()
    loss.backward()
    optimizer.step()
    monitor_fp8_scales(model, step, log_interval=50)

§20 FP8 检查点与迁移实战手册

经 FP8 训练的模型以 FP32/BF16 格式保存主权重(master weights)检查点——FP8 元数据(scales、amax 历史)默认不保存到检查点中。FP8 检查点与 BF16 检查点格式兼容:可以在 BF16 模式下恢复 FP8 训练的模型(反之亦然),无需任何格式转换。这种格式无关性是 FP8 最被低估的实用优势之一。

训练中途从 BF16 迁移到 FP8:(1)保存 BF16 检查点,(2)使用 --fp8-format hybrid 及 FP8 配方参数重启,(3)前约 100 步将从零开始校准 amax 历史。这相当于 FP8 的「热启动」——权重已经过训练,量化噪声能被很好地容忍。从 FP8 迁移回 BF16:直接移除 FP8 参数即可,主权重已是 FP32,不会损失精度。跨配方迁移(DelayedScaling → Float8CurrentScaling)同样安全——amax 历史会重置,预计约 100 步重新校准。

检查点类型 保存的权重 优化器状态 FP8 元数据 恢复行为
BF16 基线 BF16 或 FP32 主权重 保存 可直接以 BF16 恢复,或切换到 FP8
FP8 延迟缩放 BF16 或 FP32 主权重 保存 通常不保存 恢复后重新建立 amax 历史
FP8 当前缩放 BF16 或 FP32 主权重 保存 按步临时 scale 可安全恢复;前几步重新校准
自定义 FP8 持久化 BF16 或 FP32 主权重 保存 由用户代码显式保存 恢复最快,但格式会与配方绑定
BF16 训练保存检查点
使用 FP8 恢复amax 校准
FP8 训练保存检查点
转换为 HF 格式megatron_to_hf.py
部署TRT-LLM / vLLM
# Step 1: Save BF16 checkpoint (standard Megatron)
# ... training running in BF16 ...
# Checkpoint saved at iteration 10000

# Step 2: Resume with FP8 enabled
python pretrain_gpt.py \
    --load /checkpoints/iter_10000 \
    --fp8-format hybrid \
    --fp8-amax-history-len 1024 \
    --fp8-amax-compute-algo max \
    --bf16  # master weights still BF16
    # amax history starts fresh — first ~100 steps recalibrate

# Step 3: Convert FP8-trained checkpoint to HuggingFace
python tools/checkpoint/convert.py \
    --model-type GPT \
    --loader megatron \
    --saver huggingface \
    --load-dir /checkpoints/fp8_iter_50000 \
    --save-dir /hf_model/
# Output is a standard HF model — FP8 metadata is not included

💾 检查点中保存的内容

  • 主权重(FP32/BF16)
  • 优化器状态(Adam m、v)
  • 当前迭代数
  • 学习率调度状态
  • 随机数生成器状态

🚫 检查点中未保存的内容

  • FP8 缩放因子(恢复时重新计算)
  • amax 历史(从零重新累积)
  • FP8 权重缓存
  • FP8 配方配置
# Optional: persist fp8_meta explicitly for fast resume
def collect_fp8_meta(model):
    meta_state = {}
    for name, mod in model.named_modules():
        if hasattr(mod, "fp8_meta"):
            meta_state[name] = {
                "scale_fwd": mod.fp8_meta["scaling_fwd"].scale.clone(),
                "amax_fwd": mod.fp8_meta["scaling_fwd"].amax_history.clone(),
                "scale_bwd": mod.fp8_meta["scaling_bwd"].scale.clone(),
                "amax_bwd": mod.fp8_meta["scaling_bwd"].amax_history.clone(),
            }
    return meta_state

checkpoint["fp8_meta"] = collect_fp8_meta(model)

# On resume, restore only if recipe and module layout still match
if "fp8_meta" in checkpoint:
    for name, mod in model.named_modules():
        if name not in checkpoint["fp8_meta"]:
            continue
        saved = checkpoint["fp8_meta"][name]
        mod.fp8_meta["scaling_fwd"].scale.copy_(saved["scale_fwd"])
        mod.fp8_meta["scaling_fwd"].amax_history.copy_(saved["amax_fwd"])
关于 FP8 训练最令人宽慰的事实:检查点完全与格式无关。经 FP8 训练的检查点与经 BF16 训练的检查点无法区分——两者都包含 FP32 主权重。你可以在训练的任意时刻自由切换 BF16、FP8 DelayedScaling、FP8 CurrentScaling 和 MXFP8,也可以无需任何特殊处理地转换为任何推理格式(HF、TensorRT-LLM、vLLM)。FP8 是纯粹的训练加速——它在保存的模型中不留任何痕迹。
将 BF16 检查点加载到 FP8 模型时有一个常见陷阱:BF16 state_dict 中不包含 fp8_meta 键(amax_historyscalescale_inv)。Transformer Engine 会优雅地处理这一情况——缺失的 fp8_meta 键会默默重新初始化为默认值(scale = 1.0,amax_history = 全零)。但用户应预料大约 100–500 个优化器步的「重新校准窗口」,即 amax 历史从观测到的激活値重新累积的期间。在此期间,loss 可能出现轻微升高的噪声或小幅临时突刺。调用 load_state_dict 时务必使用 strict=False,以避免缺失 fp8_meta 条目导致的 KeyError。一旦 amax 历史达到饱和(经过 fp8_amax_history_len 步),训练动态即恢复正常。

§21 FP8 × 混合专家模型 (MoE)

MoE 模型将不同的令牌路由到不同的专家 MLP。每个专家都是标准线性层,可通过 Transformer Engine 使用 FP8。挑战在于:不同专家接收不同的令牌子集,激活分布各异,因此每张量的缩放必须是逐专家的,而非全局共享的。

在专家并行(EP)中,每块 GPU 持有一部分专家。FP8 缩放因子本身就是逐模块(即逐 TE Linear)的,因此 EP 天然兼容——每块 GPU 上的每个专家独立维护自己的 amax 历史和缩放因子,无需跨 EP rank 同步缩放因子(而张量并行需要 all-reduce 来同步每张量的 amax)。MoE 的令牌路由导致每步每个专家的批量大小不固定:同一步内某些专家可能接收 10 个令牌,另一些则接收 1000 个,使得逐专家的 amax 波动剧烈。Float8CurrentScaling 对 MoE 更为鲁棒,因为它每步都从实际输入张量计算缩放因子。

路由器为词元打分
分发到专家EP all-to-all
逐专家 FP8 MLP独立 amax / scale
汇聚专家输出
共享合并与残差

逐专家 FP8(路由专家)

每个专家拥有独立的 amax/缩放因子。每步批量大小不固定,基于历史的缩放可能滞后于负载突刺。建议对路由专家使用 CurrentScaling,以确保缩放因子始终反映每步实际输入分布。

共享专家 FP8(DeepSeek 风格)

共享专家每步处理所有令牌,其 amax 分布比路由专家稳定得多。任何缩放配方都适用——共享专家是 MoE 中对 FP8 最友好的组件,也是评估 MoE FP8 时最适合首先启用的候选对象。

# Megatron MoE + FP8 configuration
# Each expert is a TEColumnParallelLinear + TERowParallelLinear pair
# Expert parallelism: experts distributed across EP ranks
# FP8 scales are per-module → per-expert automatically

# Key flags:
#   --num-experts 8
#   --expert-model-parallel-size 4   # EP=4, 2 experts per GPU
#   --fp8-format hybrid
#   --moe-token-dispatcher-type alltoall

# Each expert's te.Linear maintains independent fp8_meta:
# expert_0.linear_fc1.fp8_meta['scaling_fwd'].amax_history
# expert_1.linear_fc1.fp8_meta['scaling_fwd'].amax_history
# ... completely independent scaling per expert
并行策略 FP8 缩放范围 需要跨 rank 同步? 备注
专家并行(EP) 逐专家、逐 GPU 每块 GPU 上的专家独立缩放
张量并行(TP) 逐张量(TP 组内共享) 是(amax all-reduce) TP 分片必须使用相同缩放因子
EP + TP TP 组内逐专家 是(每个专家在 TP 组内同步) 每个专家内部做 TP all-reduce
流水线并行(PP) 逐流水线阶段 PP 边界处上转精度
专家 FFN 形状 每步每专家词元数 BF16 激活字节数 FP8 激活字节数 说明
4096 → 14336 128 ~3.7 MB ~1.8 MB 专家批量较小;元数据开销更显著
4096 → 14336 512 ~14.7 MB ~7.3 MB 常见负载区间;接近理想 2× 节省
8192 → 28672 1024 ~58.7 MB ~29.4 MB 大专家从 FP8 激活压缩中获益最大
MoE 的动态令牌路由对 FP8 而言既是挑战也是机遇。挑战在于:专家负载不均衡意味着某些专家只看到少量令牌,其 amax 不具代表性——一个突然接收大量异常值令牌的「冷专家」可能在某一步产生严重错误的缩放因子。机遇在于:MoE 的专家隔离性意味着某个专家中的 FP8 量化误差不会传播到其他专家,使得 MoE 天然比稠密模型对逐专家 FP8 噪声更具鲁棒性。

一个微妙但重要的内部细节:每个用作专家子层的 TE Linear 模块都拥有自己独立的 fp8_meta 字典,与所有其他专家完全独立。该字典每个 GEMM 方向(前向和反向)各保存三个 FP8 张量:形状为 [amax_history_len, num_gemm_tensors]amax_historyscalescale_inv。对于某步接收词元极少的路由专家,观测到的 amax 可能是异常值甚至为零,导致通过历史缓冲区持续到未来步的陈旧 scale。两个实用缓解方法:(1)对 MoE 模型降低 fp8_amax_history_len(例如由 1024 改为 256),使陈旧历史衰减更快;(2)将 DelayedScaling 与最小缩放下界配合使用(将 scale 阐位到至少 1e-4),防止零负载专家出现除零错误。在专家并行中,跳过节点间的 amax 同步(即 TP 所需的 all-reduce):每个 EP rank 处理自己的专家分片,逐专家 scale 天然是局部的,保留了 EP 的带宽优势。

§22 FP8 微调、LoRA 与 PEFT

FP8 全量微调与 FP8 预训练的工作方式完全相同:加载 BF16 检查点,使用 FP8 参数恢复训练。由于模型已经过预训练,激活分布从第 1 步起就相对稳定,因此热身步数可以更短(50–100 步,而预训练需要 200–500 步)。amax 历史校准速度更快,因为模型不处于随机初始化状态。

FP8 与 LoRA 的交互:LoRA 在冻结的基础权重上添加低秩适配器(A、B 矩阵)。冻结基础权重的前向传播可以使用 FP8 GEMM——权重是静态的,因此 FP8 权重缓存效果最佳(is_first_microbatch 缓存在整个训练过程中有效)。LoRA 适配器矩阵(A: d×r,B: r×d)非常小,以 BF16/FP32 训练——将其量化为 FP8 会适得其反,因为 r 很小(8–128),通过适配器的梯度信号必须保持精确。这种混合方案在保留适配器训练质量的同时,为昂贵的基础模型 GEMM 提供 FP8 张量核吞吐量。

🎯 微调策略:全模型 FP8

适合追求最高适配质量、且模型稳定性足以快速重校准的场景。如果任务分布与预训练数据差异较大,可将 embedding 或 LM head 等敏感模块保留在 BF16。

⚡ 微调策略:FP8 基座 + BF16 适配器

这是 LoRA 或 PEFT 中速度与质量最均衡的方案。冻结骨干网络享受 FP8 张量核吞吐与缓存复用,小规模可训练适配器则在 BF16 或 FP32 中保留更精确的梯度与优化器行为。

加载 BF16 检查点
冻结基础权重
挂载 LoRA 适配器BF16
FP8 转换冻结权重缓存一次
前向:FP8 GEMM(基础)+ BF16 GEMM(适配器)
反向:FP8 dX + BF16 dA、dB
Adam:仅更新适配器
方法 基础权重 适配器权重 前向 GEMM 反向 GEMM 显存(70B)
BF16 全量微调 BF16(训练) N/A BF16 BF16 ~280 GB
FP8 全量微调 FP8(训练) N/A FP8 FP8 ~200 GB
BF16 LoRA BF16(冻结) BF16(训练) BF16 BF16(仅适配器) ~160 GB
FP8 LoRA FP8(冻结,已缓存) BF16(训练) FP8+BF16 FP8(dX)+ BF16(dA,dB) ~140 GB
QLoRA(NF4) NF4(冻结) BF16(训练) 反量化→BF16 BF16(仅适配器) ~40 GB
# Enable FP8 cache reuse for a frozen backbone in LoRA or PEFT
import transformer_engine.pytorch as te

recipe = fp8_recipe
for p in base_model.parameters():
    p.requires_grad = False

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    outputs = model(
        input_ids,
        is_first_microbatch=(microbatch_id == 0),
    )

# For frozen weights the cached FP8 copy stays valid across steps.
# Only adapter weights keep updating in BF16 or FP32.
与 FP8 全量微调和 QLoRA 相比,FP8 LoRA 具有独特优势:冻结的基础权重只被量化为 FP8 一次,并在整个训练过程中缓存(而非每个梯度累积周期重算)。这意味着基础模型每步的量化开销为零——FP8 权重缓存命中率实际上达到 100%。结合 BF16 适配器训练,在保持小而关键的适配器梯度全精度的同时,99% 以上的 FLOP 享受 FP8 张量核吞吐量。

🔬 全量 FP8 微调

  • 权重: 所有参数均参与更新;主权重保存为 FP32,GEMM 以 FP8 执行
  • amax/scale: 每步重新校准;从 BF16 检查点热身约 50–100 步
  • 优化器状态: 每个参数均为 FP32 的完整 Adam m、v——与 BF16 训练相同
  • 显存(70B): 约 200 GB——相比 BF16 的节省仅来自 FP8 激活値
  • 适合: 任务分布接近预训练时求最佳适配质量;若任务分布偏离较大,建议 LM head 保留 BF16

⚡ LoRA + FP8 基座

  • 基座权重: 以 FP8 冻结——在第一个微批次仅量化一次并缓存整个训练过程;每步额外开销为零
  • amax/scale: 在缓存填充时一次性计算(is_first_microbatch=True);权重冻结后永不再更新
  • LoRA 适配器: A、B 矩阵(秩 r = 8–128)以 BF16/FP32 训练;优化器状态仅针对小适配器参数
  • 显存节省叠加: FP8 基座激活値(~2× vs BF16)× 仅适配器优化器(~100× 更少参数)= 最低总显存占用
  • 显存(70B): 约 140 GB——是 BF16 全量微调的 1/3,是 BF16 LoRA 的 1/1.4

§23 为什么 FP8 有效:量化噪声与收敛性

E4M3 格式的 FP8 仅有 256 个可表示值,而 FP16 有 65536 个。直觉上,可表示值少 256 倍应该会使训练灾难性退化。然而实验结果表明精度损失小于 0.5%。关键在于 FP8 的舍入操作每步对权重和激活值引入了小幅随机扰动。这种噪声均值为零(对最近偶数舍入而言),方差与量化步长成正比——在数学上类似于向梯度添加高斯噪声,这是一种经过深入研究的正则化技术。噪声防止优化器过拟合损失曲面中的微小特征。

信噪比(SQNR,Signal-to-Quantization-Noise Ratio):对于动态范围为 D、使用适当缩放的 FP8 张量,SQNR ≈ 6.02 × 尾数位数 + 1.76 dB。E4M3(3 位尾数):SQNR ≈ 20 dB;E5M2(2 位尾数):SQNR ≈ 14 dB;BF16(7 位尾数):SQNR ≈ 44 dB。关键洞察:20 dB 对随机优化已经足够,因为 SGD 本质上就是有噪声的——批量采样噪声通常本身就引入 10–15 dB 的噪声。FP8 量化噪声与已有且被容忍的优化噪声相当甚至更小。

一个紧凑的推导方式是:设缩放后的张量以量化步长 Δ 被量化。在常见的高分辨率近似下,舍入噪声可视为分布在 [-Δ/2, Δ/2] 的均匀噪声,因此其方差为 Δ²/12。若信号方差为 σ²,则 SQNR = 10 log10(σ² / (Δ²/12))。对于已归一化的浮点尾数,每多 1 个尾数位,Δ 就大致减半,从而 SQNR 约提升 6.02 dB。这正是经验公式的来源。

收集参考张量
执行 FP8 缩放与量化
反量化并计算误差
估计信号功率与噪声功率
输出 dB 形式的 SQNR

FP8 为何理论上应该失效

256 个可表示值 vs FP16 的 65536 个。3 位尾数意味着每个值约 12.5% 的相对误差。数千步的梯度误差累积。极端异常值在不裁剪的情况下无法表示。

FP8 为何实际上有效

SGD 本来就有噪声(小批量方差)。量化噪声相对于梯度噪声微乎其微。缩放消除了裁剪(最严重的误差模式)。噪声起正则化作用,防止陷入尖锐最小值。逐张量缩放确保数值适配 FP8 范围。

格式 可表示值数量 尾数位数 SQNR(dB) 相对误差 训练是否足够?
FP32 ~40 亿 23 140 ~0.000001% 是(金标准)
BF16 65536 7 44 ~0.8% 是(标准)
FP16 65536 10 62 ~0.1% 是(需损失缩放)
E4M3 256 3 20 ~12.5% 是(需张量缩放)
E5M2 256 2 14 ~25% 是(用于梯度)
INT8 256 7(均匀量化) 50 ~0.4% 否(无反向传播)
import torch

def measure_sqnr(tensor, fp8_tensor, scale):
    """Measure Signal-to-Quantization-Noise Ratio in dB"""
    # Dequantize: fp8_value * scale → approximate original
    reconstructed = fp8_tensor.float() * scale
    noise = tensor.float() - reconstructed
    signal_power = (tensor.float() ** 2).mean()
    noise_power = (noise ** 2).mean() + 1e-10
    sqnr_db = 10 * torch.log10(signal_power / noise_power)
    return sqnr_db.item()

# Typical results for a well-scaled transformer layer:
# Activations (E4M3): SQNR ≈ 18-22 dB
# Weights (E4M3):     SQNR ≈ 22-28 dB (smoother distribution)
# Gradients (E5M2):   SQNR ≈ 12-16 dB (wider range needed)
FP8 训练有效的最深层原因是:基于 SGD 的优化本质上就是一个有噪声的过程。小批量梯度估计相对于真实梯度已经引入了约 15 dB 的噪声。FP8 在激活值和权重上约 20 dB 的 SQNR 意味着量化噪声与已有且被容忍的优化噪声相当甚至更小。换言之,FP8 并没有引入本质上全新的误差来源——它只是略微放大了优化器本就被设计来应对的现有噪声底线。这正是 FP8 训练不需要任何算法改动(无需新优化器,无需课程学习)的原因——标准训练配方能透明地吸收 FP8 噪声。

FP8 量化的理论误差分析:对于均匀分布在某个指数桶内的值,最坏情况下的相对量化误差以 1/(2^(m+1)) 为界,其中 m 是尾数位数。对于 E4M3(m=3),单元素最坏相对误差为 1/16 = 6.25%;对于 E5M2(m=2),为 1/8 = 12.5%。然而,这些单元素界限大大高估了对 K 项点积的实际影响。由于 K 个独立元素的量化舍入误差近似满足均值为零的独立同分布,在求和时会部分相消:点积的期望绝对误差以 sqrt(K) × ε_elem 而非 K × ε_elem 增长。具体地说,对于 4096 维隐状态(K=4096),sqrt(K) 放大因子仅为 64,而朴素相乘给出的是 4096。正是这种 sqrt(K) 累积界在数学上解释了为何大张量上的 FP8 GEMM 能温和退化:输出神经元的相对误差随输入元素维度增大而以 ε_elem / sqrt(K) 缩小。

核心公式:对于点积 y = ∑k ak bk,当两个向量均量化为 m 位尾数时,期望相对误差 |Δy| / |y| ≤ ε / sqrt(K),其中 ε = 1/(2^(m+1))。对于 E4M3,K=4096 时:6.25% / 64 ≈ 0.098%。这一低于 0.1% 的相对 GEMM 误差远小于小批量采样引入的逐步梯度噪声,这从理论上解释了为何 FP8 训练在大模型上的收敛性能与 BF16 匹配。该公式还意味着:更宽(更大 K)的层在 FP8 下比窄层数值上更稳定。

§24 FP8 性能剖析与优化

FP8 加速来自三个来源:(1)张量核吞吐量——H100 FP8 张量核的矩阵乘 FLOPS 是 BF16 的 2 倍;(2)显存带宽——FP8 权重/激活值体积减半,带宽受限操作的 HBM 读取速度提升 2 倍;(3)通信——FP8 梯度 all-reduce 数据量减半。实际加速效果取决于模型的瓶颈所在。小模型(<1B)通常受显存带宽限制,加速约 1.6–1.8 倍;大模型(>10B)受计算限制,加速约 1.4–1.6 倍。并行策略复杂的超大模型可能加速更少,因为通信成为瓶颈。

从 Nsight Systems 性能分析报告中提取的关键指标:FP8 GEMM 比例(目标:总 GEMM 时间中 >90% 使用 FP8 张量核);缩放开销(amax 规约和缩放计算所用时间,应小于步时间的 2%);通信重叠(梯度 all-reduce 是否与计算重叠?);显存吞吐量(FP8 应提高有效 HBM 利用率)。使用 nsys profile 配合 NVTX 标记捕获 GPU 核函数跟踪,查找名称中含有 fp8e4m3 的核函数——GEMM 核函数名称中缺少 FP8 标记表明存在静默 BF16 回退。

🧭 系统级时间线工具

Nsight Systems 最适合看整步执行解剖:GPU 核函数、NCCL 重叠、NVTX 区间以及流水线气泡。当你怀疑存在静默回退或 TP、PP、DP 通信重叠不足时,应优先使用它。

🔬 核函数级工具

Nsight Compute 与 TE 调试日志更适合回答单核函数问题:实际 occupancy、张量核利用率、内存事务以及精确的调度决策。通常应先用 Systems 找到慢层或慢核,再用它们继续下钻。

问题 症状 诊断 修复
静默 BF16 回退 FP8 比预期慢 nsys 在 TE 层中显示 BF16 GEMM 核函数 检查 NVTE_DEBUG,验证 head_dim/seq_len 兼容性
缩放计算开销 GEMM 之间有大量小核函数 nsys 显示频繁的 amax_and_scale 核函数 使用 CurrentScaling(融合缩放计算)或 MXFP8
FP8 转换开销 每次 GEMM 前有额外核函数 nsys 显示 cast_to_fp8 核函数 确保权重缓存已激活(is_first_microbatch)
通信瓶颈 计算阶段之间 GPU 空闲 nsys 在 all-reduce 期间显示空白 启用梯度分桶,重叠通信与计算
小批量低效 GPU 利用率低 nsys 显示短暂的 GEMM 核函数伴有间隙 增大微批量大小以充分利用张量核
# Profile a Megatron FP8 training step with Nsight Systems
nsys profile -w true \
    -t cuda,nvtx \
    --capture-range=cudaProfilerApi \
    -o fp8_profile \
    python pretrain_gpt.py \
        --fp8-format hybrid \
        --profile \
        --profile-step-start 10 \
        --profile-step-end 12

# Analyze: open fp8_profile.nsys-rep in Nsight Systems GUI
# Look for:
#   - Kernel names containing "fp8" or "e4m3" → FP8 GEMMs
#   - Kernel names with "gemm" but no "fp8" → BF16 fallback
#   - "amax" kernels → scale computation overhead
#   - "cast" kernels → FP8 quantization overhead

# Enable TE debug logging for kernel dispatch info
NVTE_DEBUG=1 python pretrain_gpt.py --fp8-format hybrid 2>&1 | \
    grep -E "(FP8|fallback|dispatch)" | head -50
# Add NVTX ranges around FP8-sensitive regions
import torch
import nvtx

def fp8_step(model, batch):
    with nvtx.annotate("forward.fp8", color="green"):
        output = model(**batch)
    with nvtx.annotate("loss.backward", color="red"):
        output.loss.backward()
    with nvtx.annotate("optimizer.step", color="blue"):
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

# Nsight Systems will show these names directly on the timeline.
Nsight 采集2-3 步
识别 GEMM 核函数
分类 FP8 与 BF16
测量缩放开销
检查通信重叠
优化修复回退 / 调整批量 / 启用缓存
对 FP8 训练而言,最有效的单一性能调试步骤是:对 2–3 个训练步运行 Nsight Systems,然后检查 FP8 GEMM 命中率。在理想的 FP8 运行中,>95% 的 GEMM 时间应使用 FP8 张量核。如果看到大量 BF16 GEMM 时间,说明静默回退正在蚕食你的加速收益。NVTE_DEBUG=1 日志会精确告诉你哪些层发生了回退及其原因——在调整其他任何东西之前先修复这些问题。注意力层中的单个 BF16 回退就能将整体 FP8 加速从 1.5 倍削减到 1.1 倍。

对 FP8 GEMM 区域进行精细的 NVTX 标注:使用 torch.cuda.nvtx.range_push/pop 包妆每个 TE 层的前向和反向传播过程。这将使 Nsight Systems 以阶段颜色显示每层的计时条,轻松发现哪些层运行较慢或退化为 BF16。将此与 TE 内置的逐层 FP8 统计功能结合使用:在启动训练前设置环境变量 NVTE_FP8_STATS=1,Transformer Engine 将每隔 N 步输出每层的 amax 值、缩放因子及溢出/下溢次数表格。此外,te.pytorch.get_fp8_context_id() 返回当前 FP8 Recipe 的上下文标识符,对于记录首次观测到数値异常时激活的是哪个 Recipe 非常有用。

import torch
import transformer_engine.pytorch as te
import transformer_engine.common as te_common

# ── NVTX annotation for per-layer FP8 GEMM profiling ──────────────────────────
class NVTXFp8Layer(te.TransformerLayer):
    def forward(self, hidden_states, **kwargs):
        ctx_id = te.get_fp8_context_id()  # e.g. recipe hash for logging
        torch.cuda.nvtx.range_push(
            f"TE.forward|ctx={ctx_id}|layer={self._layer_number}"
        )
        out = super().forward(hidden_states, **kwargs)
        torch.cuda.nvtx.range_pop()
        return out

# ── NVTE_FP8_STATS: per-layer amax / scale / overflow dump ────────────────────
# Set before launching training:
#   export NVTE_FP8_STATS=1          # print per-layer stats every 100 steps
#   export NVTE_FP8_STATS_FREQ=50    # override the print frequency
#
# Example console output (one row per TE layer per step):
#   [TE FP8 Stats | step=100 | layer=0 | fwd]
#     input  amax=1.8432   scale=0.5430  overflows=0  underflows=0
#     weight amax=0.3107   scale=3.2174  overflows=0  underflows=0
#     output amax=2.1054   scale=0.4750  overflows=0  underflows=2  ← warning!

# ── Programmatic access to per-layer FP8 state ────────────────────────────────
def log_fp8_stats(model):
    for name, module in model.named_modules():
        if hasattr(module, 'fp8_meta'):
            meta = module.fp8_meta
            for key in ('scaling_fwd', 'scaling_bwd'):
                if key in meta:
                    amax = meta[key].amax_history[0]  # most recent window
                    scale = meta[key].scale
                    print(f"[FP8] {name}.{key}: amax={amax.max():.4f} scale={scale.min():.4f}")

§25 分布式训练中的 FP8 通信优化

在标准 BF16 数据并行中,梯度全归约每个参数传输 2 字节。使用 FP8 后,梯度可以用 E5M2 格式(1 字节)通信,将全归约带宽减半。Megatron 通过 --tp-comm-overlap 和 TE 内置通信钩子支持此功能。梯度在全归约前量化为 FP8,之后反量化——全归约本身直接操作原始字节。

对于张量并行列并行层,激活输出的全收集可以用 FP8 而非 BF16 进行,将 TP 组内 GPU 间带宽减半。当启用 --tp-comm-overlap 时,TE 支持 FP8 TP 通信。关键细节:缩放因子必须与 FP8 数据一同广播,以便接收端正确反量化。对于张量并行行并行层,归约散射的处理方式不同——先在 BF16/FP32 中完成归约散射(保留求和精度),然后在本地将结果转换为 FP8。

操作 BF16(字节/参数) FP8(字节/参数) 节省 约束
DP 梯度全归约 2 1(E5M2) 2x 需广播缩放因子
TP 全收集(列并行) 2 1(E4M3) 2x 缩放因子须随数据传输
TP 归约散射(行并行) 2 2(保持 BF16) 求和精度要求 BF16
PP 发送/接收 2 1 2x 流水线阶段边界处上转精度
CP 环形全归约(KV) 2 1 2x 每块独立缩放因子

FP8 安全通信

梯度全归约(E5M2)、TP 全收集(E4M3)、PP 发送/接收、点对点传输。这些是逐元素操作或简单的数据搬运——相对于已量化的张量,FP8 是无损的。

必须使用 BF16 的通信

归约散射(需要求和)、跨节点梯度累积、任何涉及对通信值进行算术运算的操作。直接对 FP8 值求和会造成不可接受的精度损失。

计算 GEMM(FP8)
量化梯度(E5M2)
全归约(半带宽)
反量化
累积(FP32)
优化器步骤
# Megatron FP8 communication optimization flags
# --tp-comm-overlap              # overlap TP comm with compute
# --tp-comm-overlap-cfg file.yaml # fine-grained overlap config
# --overlap-grad-reduce          # overlap gradient reduce with backward
# --overlap-param-gather         # overlap param gather with forward
#
# TE handles FP8 TP communication internally:
# Column-parallel forward: all-gather output in FP8
# Row-parallel forward: reduce-scatter stays BF16 (sum accuracy)
# Backward: symmetric communication pattern
#
# Example: 8xH100 with TP=8, FP8 TP comm
# BF16 all-gather for 7B model: ~3.5 GB per layer
# FP8 all-gather for 7B model:  ~1.75 GB per layer -> 2x less NVLink traffic
# Broadcast FP8 scales across DP ranks before dequantization
import torch.distributed as dist

def broadcast_scale_and_grad(fp8_grad, scale, dp_group, src=0):
    scale_buf = scale.clone()
    dist.broadcast(scale_buf, src=src, group=dp_group)
    dist.broadcast(fp8_grad, src=src, group=dp_group)
    dequant_grad = fp8_grad.float() * scale_buf
    return dequant_grad

# In practice use the same source rank or all-gather the scale tensor first
# so every DP rank reconstructs identical gradient magnitudes.
FP8 通信优化最大的收益不在于原始带宽节省,而在于计算通信重叠效率。当通信在一半时间内完成时,造成流水线气泡的“通信尾延迟”也相应缩短。对于 TP=8 的 4 阶段流水线,FP8 通信可将流水线气泡开销从约 12% 降至约 7%,这一收益会随规模扩大而叠加。这往往比张量核吞吐量提升带来更大的实际加速效果。

一个关键区别:对梯度进行 FP8 全归约虽然将节点间流量减半,但在 GEMM 量化误差的基础上引入了第二层量化误差。这种叠加误差不可忽视:一个在 GEMM 前已经经过 E5M2 量化的梯度,在全归约前再次经过 E5M2 量化,将累积两层尾数截断,进一步降低有效 SQNR。因此,Megatron-LM 目前不对 DP 梯度全归约使用 FP8:梯度在 DP 全归约前会被上转为 BF16 或 FP32,保留已累积梯度信号的全精度。这是一个镜鑿的设计决策,以通信带宽换取数値安全性。

前沿研究探索了带误差反馈的 FP8 全归约:微软的 1-bit Adam 及其后续关于 FP8/INT8 梯度通信的工作(如微软研究院的 "FP8-LM")证明:如果每个节点维护一个残差误差缓冲区——在本地累积量化舍入误差并在下一步加回——则 FP8 全归约可以在大规模 LLM 训练中匹配 BF16 的收敛效果。该机制是一个误差反馈循环:在第 t 步,将 (g + residual) 量化为 FP8 并通信,然后更新 residual ← (g + residual) − dequantize(FP8(g + residual))。这保证了长期通信梯度的总和与 BF16 情况完全一致。截至 2026 年初,Megatron 尚未将此技术合并入主分支,但微软 Azure 多次大规模训练已对该技术实现了生产验证。

§26 FP8 × 上下文并行与长序列

上下文并行(CP)将序列维度分割到多个 GPU 上——每个 GPU 处理序列的一部分。对于注意力机制,这需要跨 GPU 的 KV 交换(环形注意力模式)。当 KV 块以 FP8 发送时,FP8 可将此通信量减小 2 倍。在自回归推理中,KV 缓存随序列长度线性增长,通常成为显存瓶颈——以 FP8 存储 KV 缓存可将缓存显存减半,这对长上下文模型(128K+ 词元)至关重要。

对于超长序列(32K–1M 词元),FP8 注意力机制面临更大挑战:(1)注意力 logit 中出现更极端的异常值,需要更激进的缩放;(2)softmax 概率可能极小,在 E4M3 中存在下溢风险;(3)量化误差随序列长度累积增长。Megatron 的 CP 实现使用可配置重叠的环形注意力——CP rank 之间发送的 KV 张量可以是 FP8 格式,但由于序列不同部分的激活值量级不同,各块的缩放因子可能有所差异。

针对长上下文 FP8 注意力中的 softmax 风险,常见缓解策略是把最脆弱的数值部分保留在更高精度:指数化前先减去逐行最大值,用 FP32 累积 log-sum-exp,当序列长度超过已验证范围时把注意力反向保留在 BF16,并在必要时对极端 attention logit 做裁剪或重新缩放。生产中常见的折中方案是:QKV 和 V 投影用 FP8,softmax 统计量保留 FP32,注意力反向使用 BF16。

序列长度 注意力 logit 范围 Softmax 最小值 FP8 注意力风险 建议
2K–4K 适中 ~1e-3 FP8 注意力安全
8K–32K 较宽 ~1e-6 仔细监控 amax
64K–128K 很宽 ~1e-10 建议 BF16 反向传播
256K–1M 极端 ~1e-15 极高 考虑完全使用 BF16 注意力

长序列中 FP8 的优势

KV 缓存显存减半(128K+ 时至关重要),CP 通信带宽减半(环形注意力 KV 交换),MLP 层不受序列长度影响(逐词元计算)。

长序列中 FP8 的风险

注意力 logit 范围随序列长度增大,softmax 精度降低(概率值极小),长因果链中误差累积,极长序列可能需要回退到 BF16 注意力。

长序列(128K 词元)
分割到 CP rank(各 32K)
环形注意力:发送 FP8 KV 块
本地 FP8 注意力计算
聚合
MLP(FP8)逐词元,不受序列长度影响
# Megatron Context Parallelism + FP8
# --context-parallel-size 4        # CP=4, sequence split across 4 GPUs
# --seq-length 131072              # 128K token sequence
# --fp8-format hybrid
#
# For very long sequences, consider:
# NVTE_FP8_DPA_BWD=0              # BF16 backward attention (safer)
#
# KV cache FP8 for inference (TensorRT-LLM style):
# KV cache per layer per token: 2 * hidden_dim * num_kv_heads / num_heads
# BF16: 2 bytes per element -> FP8: 1 byte per element
# For LLaMA-70B at 128K context:
#   BF16 KV cache: ~40 GB
#   FP8 KV cache:  ~20 GB  -> fits on single GPU
FP8 在长上下文模型中最重要的应用不是注意力计算——而是 KV 缓存。一个 70B 模型在 128K 上下文下需要约 40 GB 的 BF16 KV 缓存,仅此一项就占用了 H100 80 GB 显存的一半。FP8 KV 缓存将其压缩至约 20 GB,有效将显存内能容纳的最大上下文长度翻倍。对于推理服务,这直接转化为每个 GPU 能并发处理的长上下文请求数量翻倍。极长序列下的注意力精度权衡可通过 BF16 回退来管理;而显存节省则是不可或缺的。

短上下文(≤4K):FP8 特性

注意力显存占用很小,MLP/QKV 投影的 FP8 GEMM 主导总计算量,线性加速直接可得。各步 amax 稳定,fp8_mha=False 默认安全。这是 FP8 的最佳采用点:风险最低,加速最大。

长上下文(≥16K):FP8 特性

注意力显存现占主导,fp8_mha=True 对控制显存至关重要。amax 方差增大,历史窗口需扩至 1024+。考虑序列并行降低单 rank amax 方差。64K+ 时需 BF16 反向回退。

§27 Hopper FP8 Tensor Core 微架构

H100 SM 包含第四代张量核。每个张量核每周期可用 FP8 输入执行 16×8×32 MMA,累加器为 FP32。FP32 累加防止单 tile 内误差累积,实现 BF16 MMA 吞吐量 2 倍。

前向使用 E4M3,反向使用 E5M2。FP8→FP32→输出流水线确保计算密集阶段全精度。下转在完整点积后发生。

维度 / 属性 建议对齐 重要原因 不满足时的后果
M 维 16 的倍数 匹配张量核 tile 高度 需要更多填充或标量尾处理
N 维 8 的倍数 匹配 MMA 输出宽度 内核可能走较慢的残块路径
K 维 FP8 下建议为 32 的倍数 直接匹配 16×8×32 FP8 MMA 张量核效率下降或触发回退
内存地址对齐 至少 16 字节对齐 支持向量化全局或共享内存加载 会产生额外内存事务
GPU FP8 TFLOPS BF16 TFLOPS FP8/BF16 显存带宽 (TB/s)
H100 SXM 19799892.0x3.35
H100 PCIe 15137562.0x2.0
H200 19799892.0x4.8
B100 ~3500~17502.0x8.0
B200 ~4500~22502.0x8.0

张量核做什么

每周期 16×8×32 MMA,FP8 输入,FP32 累加,warp 协同执行,吞吐量是 BF16 的 2 倍。

软件必须做什么

缩放管理、amax 追踪、格式选择、累加器下转精度、权重缓存。

从共享内存加载 FP8 tile
Warp 级 MMA (16×8×32)
FP32 部分积
跨 K 维累积 (FP32)
下转输出 (BF16/FP8)
写入全局内存
# Pseudocode: FP8 Tensor Core MMA (single cycle per warp)
# A_tile: [16, 32] E4M3 | B_tile: [32, 8] E4M3 | C_accum: [16, 8] FP32
for i in range(16):
    for j in range(8):
        for k in range(32):
            C_accum[i,j] += float32(A_tile[i,k]) * float32(B_tile[k,j])
# Individual products promoted to FP32 before addition
# After all K tiles accumulated:
output[i,j] = bfloat16(C_accum[i,j])  # downcast for output
FP8 训练中被低估的英雄是 FP32 累加器,而非 FP8 格式本身。如果累加器以 FP8 或 FP16 工作,K=4096 的 GEMM 中复合舍入误差将为灾难性的。通过将累加器保持在 FP32,硬件确保量化误差在逐元素层面保持有界。FP8 在大型 GEMM 中“开箱即用”的原因就在于此。
特性 A100 (Ampere) H100 (Hopper) B200 (Blackwell)
FP8 张量核 有(4代) 有(5代)
FP8 格式 N/A E4M3 + E5M2 MXFP8 (E4M3 + E5M2)
缩放方式 N/A 逐张量(软件) 逐块 / MXFP8(硬件)
FP8 峰値 TFLOPS N/A 3958(SXM5,稀疏性) ~9000
HBM 带宽 2.0 TB/s 3.35 TB/s 8.0 TB/s
TDP 400 W 700 W 1000 W

§28 FP8 大规模训练:生产实战经验

某些 FP8 问题只在大规模时才会显现。跨节点 amax 不一致导致 DP rank 间梯度幅度失衡。罕见词元异常值可导致损失突刺。

生产环境热身配方:BF16 热身 500–1000 步,保守初始 margin:0,长历史窗口:1024–2048,梯度裁剪 1.0,定期监控 amax。

故障模式 规模阈值 症状 根因 缓解措施
DP rank 间 amax 发散 >64 GPU 损失缓慢漂移 各 rank 独立维护 amax 历史 定期同步 amax(每 1000 步)
罕见词元异常值 >10T 词元 损失突然突刺后恢复 激活值幅度为正常值的 100 倍 更长历史窗口(2048+),梯度裁剪
FP8 内核版本不匹配 异构集群 部分节点出现 NaN 不同节点使用不同版本的 TE/cuDNN 锁定 TE + cuDNN 版本,容器化部署
流水线气泡放大 PP>8 阶段 吞吐量低于预期 流水线气泡期间进行 FP8 缩放计算 将缩放计算与通信重叠
检查点恢复时 amax 不匹配 任意规模 恢复后损失突刺 amax 历史未保存,冷启动 恢复后延长热身(200+ 步)

生产规模下有效的做法

保守的 FP8 配方,延长 BF16 热身,梯度裁剪 1.0,定期 amax 监控,容器化部署。

生产规模下容易失效的做法

激进的 margin 遇罕见异常值,短历史窗口遇突刺数据,节点间 TE 版本混用,无监控,不做热身直接恢复。

开发阶段单节点,快速迭代
验证阶段8 节点,完整配方调优
生产阶段100+ 节点,保守配方
后训练FP8->BF16,INT4 量化,部署
# Production FP8 monitoring hook for Megatron
class FP8MonitorHook:
    """Log FP8 health metrics every N steps"""
    def __init__(self, model, log_interval=100):
        self.model = model
        self.log_interval = log_interval
        self.prev_amax = {}
    def check(self, step):
        if step % self.log_interval != 0: return
        for name, mod in self.model.named_modules():
            if not hasattr(mod, 'fp8_meta'): continue
            fwd = mod.fp8_meta['scaling_fwd']
            amax = fwd.amax_history[-1].item()
            if name in self.prev_amax:
                ratio = amax / (self.prev_amax[name] + 1e-7)
                if ratio > 5.0 or ratio < 0.2:
                    print(f"[FP8-ALERT] {name}: amax changed {ratio:.1f}x")
            self.prev_amax[name] = amax
# Periodically synchronize amax across DP ranks
import torch.distributed as dist

def sync_amax_every(model, dp_group, step, interval=1000):
    if step % interval != 0:
        return
    for _, mod in model.named_modules():
        if not hasattr(mod, "fp8_meta"):
            continue
        hist = mod.fp8_meta["scaling_fwd"].amax_history
        current = hist[-1].clone()
        dist.all_reduce(current, op=dist.ReduceOp.MAX, group=dp_group)
        hist[-1].copy_(current)

# Run after optimizer.step() or at the end of the training step.
生产 FP8 训练中最昂贵的教训:256 GPU 上第 50000 步出现的 0.3% 损失退化,代价远超 FP8 加速收益。生产原则:先匹配 BF16 损失曲线,再优化吞吐量。从最保守的配方开始,逐一验证并启用激进优化。
第 1 步:BF16 基线训练——运行 500+ 步,确认损失稳定并记录吸吐量基线
第 2 步:开启 FP8 hybrid + 延迟缩放,保守配置:margin=0history_len=1024、梯度裁剪 1.0
第 3 步:监控每层 amax(500 步)——每 100 步记录;变化 >5× 则告警
第 4 步:如果 amax 啑峰,调整 margin→2、history_len→2048,确保每 1000 步跨 DP rank 同步 amax
第 5 步:对比 BF16 vs FP8 评估指标——期望困惑度及下游基准降低 <0.1%
第 6 步:长序列工作负载(≥16K),开启 FP8 注意力(fp8_mha=True)并验证 KV 缓存显存减小 2×