强化学习 · 策略梯度 · 约 10 分钟阅读

GAE

广义优势估计:从 Critic 到偏差-方差权衡

📑 目录

  1. Critic 是什么?
  2. 优势函数 A(s,a)
  3. TD 残差:单步优势估计
  4. 蒙特卡洛回报:N 步优势
  5. GAE:广义优势估计
  6. 交互可视化
  7. λ 调参实战指南
  8. 常见坑点与 FAQ
  9. GAE 在 PPO 中的位置
λ = 0高偏差 / 低方差
λ ∈ (0,1)偏差-方差折中
λ = 1低偏差 / 高方差

1. Critic 是什么?

在强化学习里,智能体(Agent)要在环境中做决策。Actor-Critic 框架把智能体拆成两个部分:

🎭 Actor(演员)

策略网络 \( \pi_\theta(a \mid s) \),给定状态 \(s\),输出各动作的概率分布。它是真正做决策的那个。

🧠 Critic(评论家)

价值网络 \( V_\phi(s) \),给定状态 \(s\),估计从该状态出发能拿到的期望折扣回报。它不做决策,只负责"打分"。

Critic 估计的价值函数定义为:

\[ V^\pi(s) = \mathbb{E}_\pi\!\left[\sum_{k=0}^{\infty} \gamma^k r_{t+k} \;\Bigg|\; s_t = s\right] \]

即:从状态 \(s\) 出发,按照策略 \(\pi\) 执行动作,未来所有奖励的折扣之和的期望值。\(\gamma \in [0,1)\) 是折扣因子,控制未来奖励相对于即时奖励的重要程度。

🍜 类比: 你在一家餐厅点菜(Actor 选动作)。Critic 就像一位资深食评人,在你点菜之前就告诉你"从这桌的当前情况看,你大概能吃到 8 分满意度"。有了这个基准,你就能判断某道菜是否比平均水平更好。

为什么需要 Critic?在纯策略梯度(REINFORCE)中,我们用整个轨迹的回报 \(G_t\) 来估计梯度,方差极高(噪声大,训练不稳定)。Critic 提供一个基准值(baseline),大幅降低方差,同时训练更快更稳定。

2. 优势函数 A(s, a)

光有 \(V(s)\) 还不够。我们真正关心的是:某个具体动作 \(a\) 比"平均水平"好多少?这正是优势函数的定义:

\[ A^\pi(s, a) = Q^\pi(s, a) - V^\pi(s) \]

其中 \(Q^\pi(s,a)\) 是动作价值函数(Q 函数):在状态 \(s\) 执行动作 \(a\),然后继续按策略 \(\pi\) 行动的期望总回报:

\[ Q^\pi(s, a) = \mathbb{E}_\pi\!\left[\sum_{k=0}^{\infty} \gamma^k r_{t+k} \;\Bigg|\; s_t=s,\, a_t=a\right] \]
直觉解读: \(A(s,a) > 0\):这个动作比平均水平更好,应该增大它的概率。  \(A(s,a) < 0\):这个动作比平均水平差,应该减小它的概率。  \(A(s,a) = 0\):刚好等于平均水平。

策略梯度定理可以用优势函数表达(这是 Actor-Critic 的核心):

\[ \nabla_\theta J(\theta) = \mathbb{E}\!\left[\nabla_\theta \log \pi_\theta(a_t \mid s_t) \cdot A^\pi(s_t, a_t)\right] \]

问题来了:\(A^\pi(s,a)\) 无法直接计算(需要知道真实的 \(Q\) 和 \(V\))。我们只能用 Critic 来估计它——这就引出了下面的内容。

📊 数值例子:网格世界
假设智能体在一个网格世界中,当前状态 \(s\) 有三个动作可选:
动作 a \(Q^\pi(s, a)\) \(V^\pi(s)\) \(A^\pi(s, a) = Q - V\) 含义
→ 右移8.57.0+1.5比平均好,增大概率
↑ 上移7.07.00.0刚好等于平均
← 左移5.57.0−1.5比平均差,减小概率
这里 \(V^\pi(s)=7.0\) 是所有动作的期望价值。优势函数精确地告诉策略梯度:把概率从 ← 移到 → 上!

3. TD 残差:单步优势估计

最简单的优势估计:只走一步,然后用 Critic 估计剩余价值。这就是时序差分(TD)残差:

\[ \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t) \]

解读:\(r_t + \gamma V(s_{t+1})\) 是对 \(Q(s_t, a_t)\) 的一步估计(TD 目标),减去 \(V(s_t)\) 就得到了优势估计 \(\hat{A}_t \approx \delta_t\)。

📊 数值例子:计算 TD 残差
假设一条轨迹的前 3 步,\(\gamma = 0.99\):
时刻 t \(r_t\) \(V(s_t)\) \(V(s_{t+1})\) TD 目标 \(r_t + \gamma V(s_{t+1})\) \(\delta_t\)
t=0+1.05.05.21.0 + 0.99×5.2 = 6.1486.148 − 5.0 = +1.148
t=1−0.55.24.8−0.5 + 0.99×4.8 = 4.2524.252 − 5.2 = −0.948
t=2+0.34.85.10.3 + 0.99×5.1 = 5.3495.349 − 4.8 = +0.549
解读:t=0 时 \(\delta_0 = +1.148\),说明这一步的实际回报比 Critic 预期的好得多——这个动作很棒!t=1 时 \(\delta_1 = -0.948\),说明这一步表现很差。

✅ 优点

  • 方差低(只依赖一步奖励)
  • 计算简单高效
  • 可以在线更新

❌ 缺点

  • 偏差高(依赖 Critic 的精度)
  • 如果 V 不准,估计会很差
  • 信号传播慢(短视)
🔭 比喻: TD 残差就像只看今天的天气来预测这个月的气候——快,但容易出错(偏差大)。

4. 蒙特卡洛回报:N 步优势

走多步再用 Critic 截断,得到 N 步回报:

\[ G_t^{(n)} = r_t + \gamma r_{t+1} + \cdots + \gamma^{n-1} r_{t+n-1} + \gamma^n V(s_{t+n}) \]

对应的 N 步优势估计:

\[ \hat{A}_t^{(n)} = G_t^{(n)} - V(s_t) \]

当 \(n \to \infty\) 时,就退化为纯蒙特卡洛(不再用 Critic 截断),此时偏差最小,但方差最大。

方法 步数 n 偏差 方差 依赖 Critic?
TD(0)n=1 是(强)
N-stepn=2~10 是(中)
蒙特卡洛n=∞

这个表格揭示了一个两难困境:步数越多,偏差越小,但方差越大。GAE 用一个优雅的方式解决这个问题。

📊 数值例子:1 步 vs 3 步 vs 蒙特卡洛
沿用上一节的轨迹数据,\(\gamma=0.99\),比较不同步数对 \(t=0\) 的优势估计:
方法 计算过程 \(\hat{A}_0\)
1-step (TD) \(r_0 + \gamma V(s_1) - V(s_0) = 1.0 + 0.99 \times 5.2 - 5.0\) +1.148
3-step \(r_0 + \gamma r_1 + \gamma^2 r_2 + \gamma^3 V(s_3) - V(s_0)\)
= 1.0 + 0.99×(−0.5) + 0.98×0.3 + 0.97×5.1 − 5.0
+0.791
蒙特卡洛 \(G_0 - V(s_0)\),假设真实回报 \(G_0 = 5.7\) +0.700
注意三种估计给出了不同的值!1-step 最依赖 Critic,偏差最大;蒙特卡洛最接近真实,但波动很大。这就是为什么我们需要 GAE。

5. GAE:广义优势估计

GAE(Schulman et al., 2016)的核心思想:不选一个固定的步数 n,而是对所有 n 步估计做指数加权平均!

\[ \hat{A}_t^{\text{GAE}(\gamma,\lambda)} = \sum_{l=0}^{T-t-1} (\gamma\lambda)^l \,\delta_{t+l} \]

其中 \(\delta_{t+l} = r_{t+l} + \gamma V(s_{t+l+1}) - V(s_{t+l})\) 是第 \(t+l\) 步的 TD 残差,\(\lambda \in [0,1]\) 是 GAE 的超参数(类似 TD(λ) 中的 λ)。

λ 的两个极端情形

λ = 0

\[ \hat{A}_t = \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t) \]

退化为 TD(0) 单步估计。只有 \(l=0\) 项有贡献。低方差,但依赖 Critic 精度(高偏差)。

λ = 1

\[ \hat{A}_t = \sum_{l=0}^{T-t-1} \gamma^l \delta_{t+l} = G_t - V(s_t) \]

退化为蒙特卡洛估计(用真实轨迹减去基线)。低偏差,但方差高。

🔑 关键洞察:GAE = TD 残差的指数移动平均
未来第 \(l\) 步的 TD 残差 \(\delta_{t+l}\) 的权重为 \((\gamma\lambda)^l\),随 \(l\) 增大指数衰减。实践中常用 \(\lambda \in [0.9, 0.99]\)(如 PPO 中 λ=0.95)。

逆向递推计算(重要!)

实际实现时,GAE 从轨迹末尾往回递推,复杂度 O(T):

\[ \hat{A}_T = 0, \qquad \hat{A}_t = \delta_t + \gamma\lambda \cdot \hat{A}_{t+1} \]

这个递推公式是 GAE 代码实现的核心,非常优雅且高效。

📊 数值例子:手算 GAE 逆向递推
用一条长度 T=4 的轨迹,\(\gamma=0.99, \lambda=0.95\),完整展示递推过程:
t\(r_t\)\(V(s_t)\)\(V(s_{t+1})\)\(\delta_t\)
0+1.05.05.21.0 + 0.99×5.2 − 5.0 = +1.148
1−0.55.24.8−0.5 + 0.99×4.8 − 5.2 = −0.948
2+0.34.85.10.3 + 0.99×5.1 − 4.8 = +0.549
3+0.65.14.90.6 + 0.99×4.9 − 5.1 = +0.351
现在从末尾往前递推 \(\hat{A}_t = \delta_t + \gamma\lambda \cdot \hat{A}_{t+1}\):
# Step 1: t=3 (末尾,无后续)
Â₃ = δ₃ = +0.351

# Step 2: t=2
Â₂ = δ₂ + γλ·Â₃ = 0.549 + 0.99×0.95×0.351 = 0.549 + 0.330 = +0.879

# Step 3: t=1
Â₁ = δ₁ + γλ·Â₂ = −0.948 + 0.9405×0.879 = −0.948 + 0.827 = −0.121

# Step 4: t=0
Â₀ = δ₀ + γλ·Â₁ = 1.148 + 0.9405×(−0.121) = 1.148 − 0.114 = +1.034
对比纯 TD:\(\hat{A}_0^{\text{TD}} = +1.148\),纯 MC:\(\hat{A}_0^{\text{MC}} \approx +0.700\)。GAE 给出 \(+1.034\),在两者之间做了权衡——这就是 \(\lambda\) 的魔力!

6. 交互可视化:GAE 权重如何分布

下方可视化展示了一条长度为 6 的轨迹。拖动 λ 滑块,观察权重 \((\gamma\lambda)^l\) 如何随距离 l 衰减,以及 GAE 是如何从轨迹末尾逆向累积 TD 残差的。

λ = 0.95

7. λ 调参实战指南

\(\lambda\) 是 GAE 最重要的超参数。下面通过具体场景帮你建立直觉:

场景 推荐 λ 原因
Critic 还没训好(训练早期) λ → 1 (0.97–0.99) Critic 不准,少依赖它,多用真实轨迹
Critic 已经很准(训练后期) λ → 0 (0.9–0.95) Critic 可信,多依赖它,减少方差
奖励稀疏(如机器人抵达目标) λ ≈ 0.97 需要看得更远才能拿到奖励信号
奖励密集(如 Atari 游戏) λ ≈ 0.95 每步都有奖励反馈,不需要看太远
环境噪声很大 λ ≈ 0.9 远处信号被噪声淘没,少用
🎯 类比: \(\lambda\) 就像望远镜的变焦。\(\lambda \to 0\) 像放大镜;\(\lambda \to 1\) 像望远镜;\(\lambda = 0.95\) 像一副好眼镜。
🔧 调参小贴士
  • 从 PPO 默认值 \(\lambda=0.95, \gamma=0.99\) 开始。
  • 先调 \(\gamma\),再调 \(\lambda\)。
  • \(\gamma\lambda\) 才是真正的衰减率。
  • 不稳定时降 \(\lambda\),学不到远处奖励时升 \(\lambda\)。

8. 常见坑点与 FAQ

❗ 坑 1:忽略 episode 边界
当 episode 结束时(done=True),必须截断 GAE 累积。否则奖励会跨 episode 泄漏。
# ✅ 正确写法
next_val = values[t+1] * (1.0 - float(dones[t]))
delta = rewards[t] + gamma * next_val - values[t]
last_gae = delta + gamma * lam * last_gae * (1.0 - float(dones[t]))
❗ 坑 2:忘记优势归一化
计算完 GAE 后必须归一化。不归一化会导致训练崩溃。
\[ \hat{A}_t \leftarrow \frac{\hat{A}_t - \text{mean}(\hat{A})}{\text{std}(\hat{A}) + \epsilon} \]
❓ FAQ:\(\gamma\) 和 \(\lambda\) 有什么区别?
\(\gamma\) 是 MDP 的折扣因子(任务时间尺度),\(\lambda\) 是 GAE 的偏差-方差权衡。它们以 \(\gamma\lambda\) 的形式共同作用。
❓ FAQ:GAE 跟 TD(λ) 是一回事吗?
几乎一样!GAE 可以看作 TD(λ) 在策略梯度背景下的表达。数学形式完全相同。
❓ FAQ:为什么不直接用网络预测 Q(s,a)?
可以,但 Q 网络在连续动作空间难训练。V 网络 + GAE 更稳定,这是 PPO 的标准做法。

9. GAE 在 PPO 训练循环中的位置

GAE 是 PPO(Proximal Policy Optimization)的重要组成部分。完整的 PPO 训练循环如下:

采样轨迹
(rollout)
计算 δ_t
计算 GAE
PPO Clip
更新 Actor
MSE Loss
更新 Critic
重复

下面是 GAE 的完整 Python 实现(与 PyTorch 兼容):

import torch

def compute_gae(
    rewards,       # shape [T]   — 每步奖励 r_t
    values,        # shape [T+1] — Critic 输出 V(s_0)...V(s_T)
    dones,         # shape [T]   — episode 是否结束 (bool)
    gamma=0.99,    # 折扣因子
    lam=0.95,      # GAE lambda
):
    T = len(rewards)
    advantages = torch.zeros(T)
    last_gae = 0.0

    for t in reversed(range(T)):
        # 如果 episode 结束,next_value = 0
        next_value = values[t + 1] * (1.0 - float(dones[t]))

        # TD 残差 δ_t
        delta = rewards[t] + gamma * next_value - values[t]

        # GAE 逆向递推: Â_t = δ_t + γλ * Â_{t+1}
        last_gae = delta + gamma * lam * last_gae * (1.0 - float(dones[t]))
        advantages[t] = last_gae

    # 回报 = 优势 + Critic 基线 (用于更新 Critic)
    returns = advantages + values[:-1]
    return advantages, returns


# ---- PPO 训练循环核心 (伪代码) ----
for iteration in range(num_iterations):
    # 1. 采样轨迹
    obs, acts, rews, dones, vals = rollout(actor, critic, env)

    # 2. 计算 GAE 优势
    advantages, returns = compute_gae(rews, vals, dones, gamma, lam)

    # 3. 归一化优势 (重要!减少训练不稳定)
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # 4. PPO Clip 更新 Actor
    ratio = (new_log_prob - old_log_prob).exp()
    clip_loss = -torch.min(
        ratio * advantages,
        torch.clamp(ratio, 1-eps, 1+eps) * advantages
    ).mean()

    # 5. MSE 更新 Critic
    value_loss = ((critic(obs) - returns) ** 2).mean()
💡 实践小贴士:
  • 优势归一化(减均值除标准差)非常重要,能显著稳定训练。
  • PPO 论文推荐 γ=0.99, λ=0.95,是个很好的起点。
  • λ 接近 1 时对 Critic 精度要求更低(更少依赖 Critic),但需要更长的轨迹。
  • episode 结束时(done=True)必须截断 GAE 累积,否则跨 episode 的信号会污染估计。