首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >DQN 学不稳?未冻结 target、忘记 detach、终止状态仍 bootstrap 的三连坑

DQN 学不稳?未冻结 target、忘记 detach、终止状态仍 bootstrap 的三连坑

原创
作者头像
九年义务漏网鲨鱼
发布2025-12-18 13:48:56
发布2025-12-18 13:48:56
170
举报

DQN 学不稳?未冻结 target、忘记 detach、终止状态仍 bootstrap 的三连坑

场景:用 DQN 训练 CartPole/LunarLander,本地跑起来“能学”,但回报忽高忽低、训练极不稳定,稍微调大学习率就直接发散。复盘最常见三件事:

  1. 目标网络未分离(用在线网络自己给自己当 target);
  2. 计算 TD 目标时没有 detach(梯度穿到下一时刻 Q 上);
  3. 把真正终止和 time-limit 截断都当成可 bootstrap,或干脆对终止状态也做 bootstrapping。

下面给出最小复现实验与一键修复模板。


Bug 现象

  • 平均回报卡在 50~150,波动大;学习率/epsilon 再怎么调也不稳。
  • TD-error 的均值和方差周期性爆高;近似 Q 值量级持续增大。
  • 打印计算图会发现目标 y 的分支也在反传(说明未 detach),优化器在“追逐移动靶”。

场景复现(CPU 可跑)

保存为 dqn_pitfalls.py,用两行命令对比:

代码语言:python
复制
# 触发三连坑
python dqn_pitfalls.py --bug on  --env CartPole-v1

# 修复版本
python dqn_pitfalls.py --bug off --env CartPole-v1
# dqn_pitfalls.py
import argparse, random, collections, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
import gymnasium as gym
torch.manual_seed(0); np.random.seed(0); random.seed(0)

Transition = collections.namedtuple("T", "s a r s2 terminated truncated")

class QNet(nn.Module):
    def __init__(self, s_dim, a_dim, h=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(s_dim, h), nn.ReLU(),
            nn.Linear(h, h), nn.ReLU(),
            nn.Linear(h, a_dim)
        )
    def forward(self, x): return self.net(x)

class Replay:
    def __init__(self, cap=50000): self.cap, self.buf, self.pos = cap, [None]*cap, 0
    def push(self, *args):
        self.buf[self.pos] = Transition(*args); self.pos = (self.pos+1) % self.cap
    def sample(self, bs):
        batch = random.sample([b for b in self.buf if b is not None], bs)
        t = Transition(*zip(*batch))
        s  = torch.tensor(np.stack(t.s),  dtype=torch.float32)
        a  = torch.tensor(t.a,           dtype=torch.int64).unsqueeze(-1)
        r  = torch.tensor(t.r,           dtype=torch.float32).unsqueeze(-1)
        s2 = torch.tensor(np.stack(t.s2), dtype=torch.float32)
        term = torch.tensor(t.terminated, dtype=torch.float32).unsqueeze(-1)
        trunc = torch.tensor(t.truncated, dtype=torch.float32).unsqueeze(-1)
        return s, a, r, s2, term, trunc
    def __len__(self): return sum(b is not None for b in self.buf)

def epsilon_greedy(q, eps, a_dim):
    if random.random() < eps:
        return random.randrange(a_dim), None
    with torch.no_grad():
        return int(q.argmax().item()), None

def train(args):
    env = gym.make(args.env)
    s_dim, a_dim = env.observation_space.shape[0], env.action_space.n
    online = QNet(s_dim, a_dim)
    target = online if args.bug else QNet(s_dim, a_dim)
    if not args.bug:
        target.load_state_dict(online.state_dict())
        target.eval()

    opt = torch.optim.AdamW(online.parameters(), lr=1e-3)
    buf = Replay(50000)
    gamma = 0.99
    huber = nn.SmoothL1Loss()  # 稳定版损失

    eps_start, eps_end, eps_steps = 1.0, 0.05, 20000
    def eps_at(t): 
        return max(eps_end, eps_start - (eps_start-eps_end)*t/eps_steps)

    s, _ = env.reset(seed=np.random.randint(10000))
    ret, epi = 0.0, 0
    metrics = {"ret":[], "tdm":[], "tds":[], "qabs":[], "eps":[]}

    total_steps = 40000
    warmup, batch = 1000, 64
    target_sync = 1000

    for t in range(1, total_steps+1):
        eps = eps_at(t)
        a, _ = epsilon_greedy(online(torch.tensor(s, dtype=torch.float32)), eps, a_dim)
        s2, r, terminated, truncated, _ = env.step(a)
        buf.push(s, a, np.clip(r, -1, 1), s2, terminated, truncated)  # 奖励裁剪便于稳态
        s, ret = (env.reset()[0] if (terminated or truncated) else s2), (0.0 if (terminated or truncated) else ret+r)
        if terminated or truncated: 
            metrics["ret"].append(ret); epi += 1

        if len(buf) >= warmup:
            s_b, a_b, r_b, s2_b, term_b, trunc_b = buf.sample(batch)
            q = online(s_b).gather(1, a_b)  # [B,1]

            if args.bug:
                # 错误:1) 用在线网络做 target;2) 未 detach;3) 终止也 bootstrap
                q_next = online(s2_b).max(1, keepdim=True).values
                y = r_b + gamma * q_next  # ❌ 三连坑
                loss = F.mse_loss(q, y)
            else:
                with torch.no_grad():                       # 修复:冻结 target 分支
                    # time-limit 截断不算真正终止,只有 terminated=1 才不 bootstrap
                    nonterminal = (1.0 - term_b)            # [B,1]
                    q_next = target(s2_b).max(1, keepdim=True).values
                    y = r_b + gamma * nonterminal * q_next  # ✅ 正确 bootstrap
                loss = huber(q, y)                          # ✅ Huber 更稳

            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(online.parameters(), 10.0)  # ✅ 梯度裁剪避免爆炸
            opt.step()

            if (not args.bug) and (t % target_sync == 0):
                target.load_state_dict(online.state_dict())

            with torch.no_grad():
                td = (q - y).detach()
                metrics["tdm"].append(float(td.mean()))
                metrics["tds"].append(float(td.std()))
                metrics["qabs"].append(float(online(s_b).abs().mean()))
                metrics["eps"].append(eps)

        if t % 1000 == 0 and len(metrics["ret"])>0:
            print(f"[{'BUG' if args.bug else 'FIX'}] step={t:05d} "
                  f"ep={epi:04d} avg_ret={np.mean(metrics['ret'][-10:]):6.2f} "
                  f"td(mean/std)={np.mean(metrics['tdm'][-100:]):+.3f}/{np.mean(metrics['tds'][-100:]):.3f} "
                  f"|Q|≈{np.mean(metrics['qabs'][-100:]):.2f} eps={eps:.3f}")

if __name__ == "__main__":
    ap = argparse.ArgumentParser()
    ap.add_argument("--env", default="CartPole-v1")
    ap.add_argument("--bug", choices=["on","off"], default="on")
    args = ap.parse_args()
    train(args)

你会观察到的典型差异

  • bug 版:|Q| 持续增大,TD-error 的均值/方差经常爆,avg_ret 在 50~150 波动;学习率稍大就直接发散。
  • 修复版:|Q| 有界,TD-error 逐步收敛,avg_ret 很快超过 180,稳定到满分附近。

Debug 过程

  1. 检查目标网络与 detach 打印 y.requires_grad 或用 torch.autograd.is_anomaly_enabled() 观察图结构。若 y 可导,说明梯度穿过了 target 分支。
  2. 区分真正终止与 time-limit 截断 Gymnasium 的 step 返回 (terminated, truncated)。只有 terminated=True 才是不做 bootstrapping 的“终止”。truncated=True(时限)应继续用 Q(s') bootstrap。
  3. 监控 Q 值与 TD-error |Q| 长期上扬多半是目标未冻结/未 detach;TD-error 方差异常大时,用 Huber + 裁剪先稳住。
  4. 复现实验 A/B 用上面的脚本在 CPU 跑 40k 步即可看到差异。若你的环境不同(Atari 等),现象会更明显。

代码修改要点(模板)

  1. 分离 target 网络并周期性同步
代码语言:python
复制
target = QNet(s_dim, a_dim).eval()
target.load_state_dict(online.state_dict())
if step % target_sync == 0:
    target.load_state_dict(online.state_dict())
  1. 目标分支 no_grad + 终止掩码
代码语言:python
复制
with torch.no_grad():
    nonterminal = (1.0 - terminated.float().unsqueeze(-1))
    q_next = target(s2).max(1, keepdim=True).values
    y = r + gamma * nonterminal * q_next
loss = F.smooth_l1_loss(q_selected, y)
  1. 稳定性常用护栏
代码语言:python
复制
torch.nn.utils.clip_grad_norm_(online.parameters(), max_norm=10.0)
r = np.clip(r, -1, 1)  # reward clip(经典 DQN 习惯)

监控与护栏

代码语言:python
复制
def assert_target_frozen(y):
    assert (not y.requires_grad), "TD 目标 y 可导,疑似未 detach 或 target 未冻结"

def log_health(step, td, q_abs):
    print(f"[dbg] step={step} td_mean={float(td.mean()):+.3f} td_std={float(td.std()):.3f} |Q|={float(q_abs):.2f}")
  • 每 1k 步打印最近 100 个 batch 的 td_mean/td_std|Q| 均值。
  • td_std 长期高于 2~5 且不降,优先自检上述三点。
  • 评估时用贪心策略(argmax),训练时 epsilon-greedy,日志里始终打印当前 epsilon

常见问答

  • 一定要 Huber 吗 不是必须,但在 TD-error 尾部厚、反馈尺度抖动大的任务上更稳。MSE 也能收敛,但更容易被极端样本拖走。
  • target 更新用硬同步还是软更新 两者皆可。硬同步实现简单(每 N 步复制)。软更新 θ_target ← τ θ_online + (1-τ) θ_target(τ≈0.005)更平滑。
  • Double DQN 要不要上 建议。把 max_a' Q_target(s', a') 改为 Q_target(s', argmax_a Q_online(s', a)) 可以显著减轻过估计偏差。
  • 优先经验回放(PER)是否必要 非必要。先把基础版本稳住再考虑 PER/NoisyNet/Dueling 架构。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • DQN 学不稳?未冻结 target、忘记 detach、终止状态仍 bootstrap 的三连坑
    • Bug 现象
    • 场景复现(CPU 可跑)
    • Debug 过程
    • 代码修改要点(模板)
    • 监控与护栏
    • 常见问答
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档