
Scenario: you train DQN on CartPole/LunarLander and it "sort of learns" locally, but the return swings wildly, training is highly unstable, and nudging the learning rate up makes it diverge outright. Post-mortems keep turning up the same three mistakes:

1. No separate target network — the online network is used as its own target, so the optimizer is chasing a moving target.
2. Missing detach/no_grad on the TD target — gradients flow into the next-step Q.
3. Bootstrapping on terminal states (and conflating time-limit truncation with true termination).

Below is a minimal reproduction plus a one-flag fix template. Save it as dqn_pitfalls.py and compare the two modes with two commands:
```bash
# trigger all three pitfalls
python dqn_pitfalls.py --bug on --env CartPole-v1
# fixed version
python dqn_pitfalls.py --bug off --env CartPole-v1
```
```python
# dqn_pitfalls.py
import argparse, random, collections, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
import gymnasium as gym

torch.manual_seed(0); np.random.seed(0); random.seed(0)

Transition = collections.namedtuple("T", "s a r s2 terminated truncated")

class QNet(nn.Module):
    def __init__(self, s_dim, a_dim, h=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(s_dim, h), nn.ReLU(),
            nn.Linear(h, h), nn.ReLU(),
            nn.Linear(h, a_dim)
        )
    def forward(self, x):
        return self.net(x)

class Replay:
    def __init__(self, cap=50000):
        self.cap, self.buf, self.pos = cap, [None] * cap, 0
    def push(self, *args):
        self.buf[self.pos] = Transition(*args)
        self.pos = (self.pos + 1) % self.cap
    def sample(self, bs):
        batch = random.sample([b for b in self.buf if b is not None], bs)
        t = Transition(*zip(*batch))
        s = torch.tensor(np.stack(t.s), dtype=torch.float32)
        a = torch.tensor(t.a, dtype=torch.int64).unsqueeze(-1)
        r = torch.tensor(t.r, dtype=torch.float32).unsqueeze(-1)
        s2 = torch.tensor(np.stack(t.s2), dtype=torch.float32)
        term = torch.tensor(t.terminated, dtype=torch.float32).unsqueeze(-1)
        trunc = torch.tensor(t.truncated, dtype=torch.float32).unsqueeze(-1)
        return s, a, r, s2, term, trunc
    def __len__(self):
        return sum(b is not None for b in self.buf)

def epsilon_greedy(q, eps, a_dim):
    # q: 1-D tensor of Q-values for the current state
    if random.random() < eps:
        return random.randrange(a_dim)
    return int(q.argmax().item())

def train(args):
    env = gym.make(args.env)
    s_dim, a_dim = env.observation_space.shape[0], env.action_space.n
    online = QNet(s_dim, a_dim)
    # BUG mode: the "target" is literally the online network (a moving target)
    target = online if args.bug else QNet(s_dim, a_dim)
    if not args.bug:
        target.load_state_dict(online.state_dict())
        target.eval()
    opt = torch.optim.AdamW(online.parameters(), lr=1e-3)
    buf = Replay(50000)
    gamma = 0.99
    huber = nn.SmoothL1Loss()  # robust loss used by the fixed branch
    eps_start, eps_end, eps_steps = 1.0, 0.05, 20000
    def eps_at(t):
        return max(eps_end, eps_start - (eps_start - eps_end) * t / eps_steps)
    s, _ = env.reset(seed=int(np.random.randint(10000)))
    ret, epi = 0.0, 0
    metrics = {"ret": [], "tdm": [], "tds": [], "qabs": [], "eps": []}
    total_steps = 40000
    warmup, batch = 1000, 64
    target_sync = 1000
    for t in range(1, total_steps + 1):
        eps = eps_at(t)
        with torch.no_grad():
            q_s = online(torch.tensor(s, dtype=torch.float32))
        a = epsilon_greedy(q_s, eps, a_dim)
        s2, r, terminated, truncated, _ = env.step(a)
        buf.push(s, a, np.clip(r, -1, 1), s2, terminated, truncated)  # reward clipping for stability
        ret += r
        if terminated or truncated:
            metrics["ret"].append(ret); epi += 1
            s, _ = env.reset()
            ret = 0.0
        else:
            s = s2
        if len(buf) >= warmup:
            s_b, a_b, r_b, s2_b, term_b, trunc_b = buf.sample(batch)
            q = online(s_b).gather(1, a_b)  # [B,1]
            if args.bug:
                # Pitfalls: 1) online net as target; 2) no detach; 3) bootstrapping on terminal states
                q_next = online(s2_b).max(1, keepdim=True).values
                y = r_b + gamma * q_next          # ❌ all three pitfalls at once
                loss = F.mse_loss(q, y)
            else:
                with torch.no_grad():             # fix: freeze the target branch
                    # time-limit truncation is not a real terminal state:
                    # only terminated=1 stops bootstrapping
                    nonterminal = (1.0 - term_b)  # [B,1]
                    q_next = target(s2_b).max(1, keepdim=True).values
                    y = r_b + gamma * nonterminal * q_next  # ✅ correct bootstrap
                loss = huber(q, y)                # ✅ Huber is more robust
            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(online.parameters(), 10.0)  # ✅ clip gradients to avoid blow-ups
            opt.step()
            if (not args.bug) and (t % target_sync == 0):
                target.load_state_dict(online.state_dict())
            with torch.no_grad():
                td = (q - y).detach()
                metrics["tdm"].append(float(td.mean()))
                metrics["tds"].append(float(td.std()))
                metrics["qabs"].append(float(online(s_b).abs().mean()))
                metrics["eps"].append(eps)
        if t % 1000 == 0 and len(metrics["ret"]) > 0:
            print(f"[{'BUG' if args.bug else 'FIX'}] step={t:05d} "
                  f"ep={epi:04d} avg_ret={np.mean(metrics['ret'][-10:]):6.2f} "
                  f"td(mean/std)={np.mean(metrics['tdm'][-100:]):+.3f}/{np.mean(metrics['tds'][-100:]):.3f} "
                  f"|Q|≈{np.mean(metrics['qabs'][-100:]):.2f} eps={eps:.3f}")

if __name__ == "__main__":
    ap = argparse.ArgumentParser()
    ap.add_argument("--env", default="CartPole-v1")
    ap.add_argument("--bug", choices=["on", "off"], default="on")
    args = ap.parse_args()
    args.bug = (args.bug == "on")  # convert the flag to the boolean used throughout train()
    train(args)
```

Typical differences you'll observe
- With `--bug on`: |Q| keeps growing, the TD-error mean/variance blows up regularly, avg_ret oscillates around 50~150, and a slightly larger learning rate diverges immediately.
- With `--bug off`: |Q| stays bounded, the TD-error shrinks steadily, and avg_ret quickly passes 180 and settles near the maximum score (500 for CartPole-v1).
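To sanity-check the fixed run independently of the exploration schedule, you can roll out a purely greedy (argmax) policy for a few episodes. This is a minimal sketch, assuming a trained `online` QNet from `dqn_pitfalls.py`; the `evaluate` helper name is ours, not part of the original script:

```python
import numpy as np, torch
import gymnasium as gym

def evaluate(online, env_id="CartPole-v1", episodes=5):
    """Greedy (argmax) rollouts: measures the learned policy without exploration noise."""
    env = gym.make(env_id)
    returns = []
    for _ in range(episodes):
        s, _ = env.reset()
        done, ep_ret = False, 0.0
        while not done:
            with torch.no_grad():
                a = int(online(torch.tensor(s, dtype=torch.float32)).argmax().item())
            s, r, terminated, truncated, _ = env.step(a)
            ep_ret += r
            done = terminated or truncated
        returns.append(ep_ret)
    return float(np.mean(returns))
```

Calling this every few thousand steps gives a cleaner signal than the epsilon-noisy training return; with the fixes in place the greedy return should climb toward the environment's cap.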
How to localize each pitfall quickly:

- detach / frozen target: print `y.requires_grad`, or enable autograd anomaly detection with `torch.autograd.set_detect_anomaly(True)`, and inspect the graph. If `y` requires grad, gradients are leaking through the target branch (see the sketch after these snippets).
- terminated vs. truncated: gymnasium's `step` returns `(terminated, truncated)`. Only `terminated=True` is a real terminal state where you skip bootstrapping; `truncated=True` (time limit) should still bootstrap from Q(s').
- If |Q| trends upward for a long time, the target is most likely not frozen/detached; if the TD-error variance is abnormally large, stabilize things first with Huber loss plus gradient clipping.

Fix templates to copy in:

```python
# separate target network, hard-synced every target_sync steps
target = QNet(s_dim, a_dim).eval()
target.load_state_dict(online.state_dict())

# inside the training loop:
if step % target_sync == 0:
    target.load_state_dict(online.state_dict())
```

```python
# no_grad + termination mask
with torch.no_grad():
    nonterminal = (1.0 - terminated.float().unsqueeze(-1))
    q_next = target(s2).max(1, keepdim=True).values
    y = r + gamma * nonterminal * q_next
loss = F.smooth_l1_loss(q_selected, y)
```

```python
# gradient clipping
torch.nn.utils.clip_grad_norm_(online.parameters(), max_norm=10.0)
```
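The first checklist item mentions autograd anomaly detection. A minimal, self-contained sketch of how it can be wired around one update step; the `q`/`y` tensors here are stand-ins with the same shapes as in `dqn_pitfalls.py`, purely for illustration:

```python
import torch
import torch.nn.functional as F

torch.autograd.set_detect_anomaly(True)  # debugging aid only: slows training noticeably

# stand-in tensors for one update step (shapes as in dqn_pitfalls.py)
q = torch.randn(64, 1, requires_grad=True)  # Q(s,a) from the online net
y = torch.randn(64, 1)                      # TD target; must NOT require grad

assert not y.requires_grad, "TD target y requires grad: missing detach / unfrozen target"
loss = F.smooth_l1_loss(q, y)
loss.backward()  # with anomaly mode on, PyTorch pinpoints the op that produced NaN/Inf gradients
```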
```python
# reward clipping (classic DQN habit)
r = np.clip(r, -1, 1)
```

Two small health-check helpers:

```python
def assert_target_frozen(y):
    assert (not y.requires_grad), "TD target y requires grad: likely missing detach or an unfrozen target net"

def log_health(step, td, q_abs):
    print(f"[dbg] step={step} td_mean={float(td.mean()):+.3f} td_std={float(td.std()):.3f} |Q|={float(q_abs):.2f}")
```

Watch td_mean/td_std and the mean of |Q|. If td_std stays above roughly 2~5 and does not come down, check the three pitfalls above first.

A few extra habits that help:

- Evaluate with the greedy policy (argmax), act epsilon-greedy during training, and always print the current epsilon in the logs.
- Soft target updates, θ_target ← τ·θ_online + (1-τ)·θ_target with τ ≈ 0.005, are smoother than periodic hard syncs (see the sketch below).
- Double DQN: replacing max_a' Q_target(s', a') with Q_target(s', argmax_a Q_online(s', a)) noticeably reduces overestimation bias (sketch after the soft-update one).
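And a sketch of the Double DQN target, again assuming the batch tensors and networks from the fixed branch of `dqn_pitfalls.py` (a drop-in replacement for the `q_next` computation there):

```python
with torch.no_grad():
    # Double DQN: the online net selects the action, the target net evaluates it
    next_a = online(s2_b).argmax(dim=1, keepdim=True)   # argmax_a Q_online(s', a)
    q_next = target(s2_b).gather(1, next_a)             # Q_target(s', argmax_a Q_online(s', a))
    y = r_b + gamma * (1.0 - term_b) * q_next
```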