强化学习 · 策略梯度 · 约 10 分钟阅读
GAE
广义优势估计:从 Critic 到偏差-方差权衡
📑 目录
1. Critic 是什么?
在强化学习里,智能体(Agent)要在环境中做决策。Actor-Critic 框架把智能体拆成两个部分:
🎭 Actor(演员)
策略网络 \( \pi_\theta(a \mid s) \),给定状态 \(s\),输出各动作的概率分布。它是真正做决策的那个。
🧠 Critic(评论家)
价值网络 \( V_\phi(s) \),给定状态 \(s\),估计从该状态出发能拿到的期望折扣回报。它不做决策,只负责"打分"。
Critic 估计的价值函数定义为:
即:从状态 \(s\) 出发,按照策略 \(\pi\) 执行动作,未来所有奖励的折扣之和的期望值。\(\gamma \in [0,1)\) 是折扣因子,控制未来奖励相对于即时奖励的重要程度。
为什么需要 Critic?在纯策略梯度(REINFORCE)中,我们用整个轨迹的回报 \(G_t\) 来估计梯度,方差极高(噪声大,训练不稳定)。Critic 提供一个基准值(baseline),大幅降低方差,同时训练更快更稳定。
2. 优势函数 A(s, a)
光有 \(V(s)\) 还不够。我们真正关心的是:某个具体动作 \(a\) 比"平均水平"好多少?这正是优势函数的定义:
其中 \(Q^\pi(s,a)\) 是动作价值函数(Q 函数):在状态 \(s\) 执行动作 \(a\),然后继续按策略 \(\pi\) 行动的期望总回报:
策略梯度定理可以用优势函数表达(这是 Actor-Critic 的核心):
问题来了:\(A^\pi(s,a)\) 无法直接计算(需要知道真实的 \(Q\) 和 \(V\))。我们只能用 Critic 来估计它——这就引出了下面的内容。
假设智能体在一个网格世界中,当前状态 \(s\) 有三个动作可选:
| 动作 a | \(Q^\pi(s, a)\) | \(V^\pi(s)\) | \(A^\pi(s, a) = Q - V\) | 含义 |
|---|---|---|---|---|
| → 右移 | 8.5 | 7.0 | +1.5 | 比平均好,增大概率 |
| ↑ 上移 | 7.0 | 7.0 | 0.0 | 刚好等于平均 |
| ← 左移 | 5.5 | 7.0 | −1.5 | 比平均差,减小概率 |
3. TD 残差:单步优势估计
最简单的优势估计:只走一步,然后用 Critic 估计剩余价值。这就是时序差分(TD)残差:
解读:\(r_t + \gamma V(s_{t+1})\) 是对 \(Q(s_t, a_t)\) 的一步估计(TD 目标),减去 \(V(s_t)\) 就得到了优势估计 \(\hat{A}_t \approx \delta_t\)。
假设一条轨迹的前 3 步,\(\gamma = 0.99\):
| 时刻 t | \(r_t\) | \(V(s_t)\) | \(V(s_{t+1})\) | TD 目标 \(r_t + \gamma V(s_{t+1})\) | \(\delta_t\) |
|---|---|---|---|---|---|
| t=0 | +1.0 | 5.0 | 5.2 | 1.0 + 0.99×5.2 = 6.148 | 6.148 − 5.0 = +1.148 |
| t=1 | −0.5 | 5.2 | 4.8 | −0.5 + 0.99×4.8 = 4.252 | 4.252 − 5.2 = −0.948 |
| t=2 | +0.3 | 4.8 | 5.1 | 0.3 + 0.99×5.1 = 5.349 | 5.349 − 4.8 = +0.549 |
✅ 优点
- 方差低(只依赖一步奖励)
- 计算简单高效
- 可以在线更新
❌ 缺点
- 偏差高(依赖 Critic 的精度)
- 如果 V 不准,估计会很差
- 信号传播慢(短视)
4. 蒙特卡洛回报:N 步优势
走多步再用 Critic 截断,得到 N 步回报:
对应的 N 步优势估计:
当 \(n \to \infty\) 时,就退化为纯蒙特卡洛(不再用 Critic 截断),此时偏差最小,但方差最大。
| 方法 | 步数 n | 偏差 | 方差 | 依赖 Critic? |
|---|---|---|---|---|
| TD(0) | n=1 | 高 | 低 | 是(强) |
| N-step | n=2~10 | 中 | 中 | 是(中) |
| 蒙特卡洛 | n=∞ | 低 | 高 | 否 |
这个表格揭示了一个两难困境:步数越多,偏差越小,但方差越大。GAE 用一个优雅的方式解决这个问题。
沿用上一节的轨迹数据,\(\gamma=0.99\),比较不同步数对 \(t=0\) 的优势估计:
| 方法 | 计算过程 | \(\hat{A}_0\) |
|---|---|---|
| 1-step (TD) | \(r_0 + \gamma V(s_1) - V(s_0) = 1.0 + 0.99 \times 5.2 - 5.0\) | +1.148 |
| 3-step | \(r_0 + \gamma r_1 + \gamma^2 r_2 + \gamma^3 V(s_3) - V(s_0)\) = 1.0 + 0.99×(−0.5) + 0.98×0.3 + 0.97×5.1 − 5.0 |
+0.791 |
| 蒙特卡洛 | \(G_0 - V(s_0)\),假设真实回报 \(G_0 = 5.7\) | +0.700 |
5. GAE:广义优势估计
GAE(Schulman et al., 2016)的核心思想:不选一个固定的步数 n,而是对所有 n 步估计做指数加权平均!
其中 \(\delta_{t+l} = r_{t+l} + \gamma V(s_{t+l+1}) - V(s_{t+l})\) 是第 \(t+l\) 步的 TD 残差,\(\lambda \in [0,1]\) 是 GAE 的超参数(类似 TD(λ) 中的 λ)。
λ 的两个极端情形
λ = 0
退化为 TD(0) 单步估计。只有 \(l=0\) 项有贡献。低方差,但依赖 Critic 精度(高偏差)。
λ = 1
退化为蒙特卡洛估计(用真实轨迹减去基线)。低偏差,但方差高。
未来第 \(l\) 步的 TD 残差 \(\delta_{t+l}\) 的权重为 \((\gamma\lambda)^l\),随 \(l\) 增大指数衰减。实践中常用 \(\lambda \in [0.9, 0.99]\)(如 PPO 中 λ=0.95)。
逆向递推计算(重要!)
实际实现时,GAE 从轨迹末尾往回递推,复杂度 O(T):
这个递推公式是 GAE 代码实现的核心,非常优雅且高效。
用一条长度 T=4 的轨迹,\(\gamma=0.99, \lambda=0.95\),完整展示递推过程:
| t | \(r_t\) | \(V(s_t)\) | \(V(s_{t+1})\) | \(\delta_t\) |
|---|---|---|---|---|
| 0 | +1.0 | 5.0 | 5.2 | 1.0 + 0.99×5.2 − 5.0 = +1.148 |
| 1 | −0.5 | 5.2 | 4.8 | −0.5 + 0.99×4.8 − 5.2 = −0.948 |
| 2 | +0.3 | 4.8 | 5.1 | 0.3 + 0.99×5.1 − 4.8 = +0.549 |
| 3 | +0.6 | 5.1 | 4.9 | 0.6 + 0.99×4.9 − 5.1 = +0.351 |
Â₃ = δ₃ = +0.351
# Step 2: t=2
Â₂ = δ₂ + γλ·Â₃ = 0.549 + 0.99×0.95×0.351 = 0.549 + 0.330 = +0.879
# Step 3: t=1
Â₁ = δ₁ + γλ·Â₂ = −0.948 + 0.9405×0.879 = −0.948 + 0.827 = −0.121
# Step 4: t=0
Â₀ = δ₀ + γλ·Â₁ = 1.148 + 0.9405×(−0.121) = 1.148 − 0.114 = +1.034
6. 交互可视化:GAE 权重如何分布
下方可视化展示了一条长度为 6 的轨迹。拖动 λ 滑块,观察权重 \((\gamma\lambda)^l\) 如何随距离 l 衰减,以及 GAE 是如何从轨迹末尾逆向累积 TD 残差的。
7. λ 调参实战指南
\(\lambda\) 是 GAE 最重要的超参数。下面通过具体场景帮你建立直觉:
| 场景 | 推荐 λ | 原因 |
|---|---|---|
| Critic 还没训好(训练早期) | λ → 1 (0.97–0.99) | Critic 不准,少依赖它,多用真实轨迹 |
| Critic 已经很准(训练后期) | λ → 0 (0.9–0.95) | Critic 可信,多依赖它,减少方差 |
| 奖励稀疏(如机器人抵达目标) | λ ≈ 0.97 | 需要看得更远才能拿到奖励信号 |
| 奖励密集(如 Atari 游戏) | λ ≈ 0.95 | 每步都有奖励反馈,不需要看太远 |
| 环境噪声很大 | λ ≈ 0.9 | 远处信号被噪声淘没,少用 |
- 从 PPO 默认值 \(\lambda=0.95, \gamma=0.99\) 开始。
- 先调 \(\gamma\),再调 \(\lambda\)。
- \(\gamma\lambda\) 才是真正的衰减率。
- 不稳定时降 \(\lambda\),学不到远处奖励时升 \(\lambda\)。
8. 常见坑点与 FAQ
当 episode 结束时(done=True),必须截断 GAE 累积。否则奖励会跨 episode 泄漏。
next_val = values[t+1] * (1.0 - float(dones[t]))
delta = rewards[t] + gamma * next_val - values[t]
last_gae = delta + gamma * lam * last_gae * (1.0 - float(dones[t]))
计算完 GAE 后必须归一化。不归一化会导致训练崩溃。
\(\gamma\) 是 MDP 的折扣因子(任务时间尺度),\(\lambda\) 是 GAE 的偏差-方差权衡。它们以 \(\gamma\lambda\) 的形式共同作用。
几乎一样!GAE 可以看作 TD(λ) 在策略梯度背景下的表达。数学形式完全相同。
可以,但 Q 网络在连续动作空间难训练。V 网络 + GAE 更稳定,这是 PPO 的标准做法。
9. GAE 在 PPO 训练循环中的位置
GAE 是 PPO(Proximal Policy Optimization)的重要组成部分。完整的 PPO 训练循环如下:
(rollout)
计算 GAE
更新 Actor
更新 Critic
下面是 GAE 的完整 Python 实现(与 PyTorch 兼容):
import torch
def compute_gae(
rewards, # shape [T] — 每步奖励 r_t
values, # shape [T+1] — Critic 输出 V(s_0)...V(s_T)
dones, # shape [T] — episode 是否结束 (bool)
gamma=0.99, # 折扣因子
lam=0.95, # GAE lambda
):
T = len(rewards)
advantages = torch.zeros(T)
last_gae = 0.0
for t in reversed(range(T)):
# 如果 episode 结束,next_value = 0
next_value = values[t + 1] * (1.0 - float(dones[t]))
# TD 残差 δ_t
delta = rewards[t] + gamma * next_value - values[t]
# GAE 逆向递推: Â_t = δ_t + γλ * Â_{t+1}
last_gae = delta + gamma * lam * last_gae * (1.0 - float(dones[t]))
advantages[t] = last_gae
# 回报 = 优势 + Critic 基线 (用于更新 Critic)
returns = advantages + values[:-1]
return advantages, returns
# ---- PPO 训练循环核心 (伪代码) ----
for iteration in range(num_iterations):
# 1. 采样轨迹
obs, acts, rews, dones, vals = rollout(actor, critic, env)
# 2. 计算 GAE 优势
advantages, returns = compute_gae(rews, vals, dones, gamma, lam)
# 3. 归一化优势 (重要!减少训练不稳定)
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# 4. PPO Clip 更新 Actor
ratio = (new_log_prob - old_log_prob).exp()
clip_loss = -torch.min(
ratio * advantages,
torch.clamp(ratio, 1-eps, 1+eps) * advantages
).mean()
# 5. MSE 更新 Critic
value_loss = ((critic(obs) - returns) ** 2).mean()
- 优势归一化(减均值除标准差)非常重要,能显著稳定训练。
- PPO 论文推荐 γ=0.99, λ=0.95,是个很好的起点。
- λ 接近 1 时对 Critic 精度要求更低(更少依赖 Critic),但需要更长的轨迹。
- episode 结束时(done=True)必须截断 GAE 累积,否则跨 episode 的信号会污染估计。