首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >【Debug日志 | 指标忽高忽低】

【Debug日志 | 指标忽高忽低】

原创
作者头像
九年义务漏网鲨鱼
发布2025-09-28 17:09:49
发布2025-09-28 17:09:49
10100
代码可运行
举报
文章被收录于专栏:tencent cloudtencent cloud
运行总次数:0
代码可运行

指标忽高忽低?评估时忘了切 eval,BatchNorm/Dropout 状态混乱导致的虚假回退(含可复现实验与修复模板)

场景:训练分类模型时,训练集 loss 在降,但验证集准确率时好时坏;同一模型在两次评估间波动 5~15 个百分点。排查后发现评估阶段没有正确切换到 eval 模式,或者切了 eval 又被其他模块改回了 train,导致 BatchNorm 用了“当前批次统计”、Dropout 仍在随机屏蔽,指标严重失真。

❓ Bug 现象

  • 验证集准确率在相邻两次评估间大幅波动,且与 batch 大小高度相关。
  • 将同一批验证数据跑两遍,输出结果不一致(说明仍有随机性参与)。
  • 统计 BatchNorm running_mean/var 变化,评估时还在被更新。
  • 打印 model.training,显示为 True 或在评估函数内部被改动过。

📽️ 场景复现

保存为 eval_mode_bug.py,CPU 可直接运行。

代码语言:javascript
代码运行次数:0
运行
复制
# eval_mode_bug.py
import torch, torch.nn as nn, torch.nn.functional as F
torch.manual_seed(0)

class TinyCNN(nn.Module):
    def __init__(self, n_class=5):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1),
            nn.BatchNorm2d(16), nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.BatchNorm2d(32), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.fc = nn.Linear(32, n_class)
    def forward(self, x):
        h = self.net(x).flatten(1)
        return self.fc(h)

def make_data(n=2048, n_class=5):
    X = torch.randn(n, 1, 28, 28)
    W = torch.randn(32*1*1, n_class)  # 随机线性可分映射
    y = torch.randint(0, n_class, (n,))
    return X, y

def evaluate_wrong(model, X, y, bs=128):
    # 错误示范:忘记 model.eval(),也没关闭梯度;BN/Dropout 仍在“训练态”
    model.train()  # 故意留下坑
    accs = []
    for i in range(0, len(X), bs):
        logits = model(X[i:i+bs])
        pred = logits.argmax(1)
        accs.append((pred == y[i:i+bs]).float().mean().item())
    return sum(accs)/len(accs)

@torch.no_grad()
def evaluate_right(model, X, y, bs=128):
    was_training = model.training
    model.eval()
    try:
        accs = []
        for i in range(0, len(X), bs):
            logits = model(X[i:i+bs])
            pred = logits.argmax(1)
            accs.append((pred == y[i:i+bs]).float().mean().item())
        return sum(accs)/len(accs)
    finally:
        model.train(was_training)

def main():
    X, y = make_data()
    model = TinyCNN()
    opt = torch.optim.AdamW(model.parameters(), lr=3e-3)

    # 训练几个 step,让 BN 有一些 running 统计
    model.train()
    for step in range(50):
        idx = torch.randint(0, len(X), (128,))
        logits = model(X[idx])
        loss = F.cross_entropy(logits, y[idx])
        opt.zero_grad(set_to_none=True); loss.backward(); opt.step()

    # 错误评估:两次评估结果显著不同(Dropout+BN train 态)
    acc1 = evaluate_wrong(model, X, y)
    acc2 = evaluate_wrong(model, X, y)
    print(f"[WRONG] acc1={acc1:.3f} acc2={acc2:.3f} (应当一致却波动)")

    # 正确评估:多次评估一致,且通常更高更稳定
    acc3 = evaluate_right(model, X, y)
    acc4 = evaluate_right(model, X, y)
    print(f"[RIGHT] acc1={acc3:.3f} acc2={acc4:.3f} (应当稳定一致)")

if __name__ == "__main__":
    main()

你将看到以下两种情况:

  • 错误评估下,两次 acc 差异明显,甚至会随 batch 大小变化而剧烈波动。
  • 正确评估下,acc 基本一致且略高。 原因:评估时仍在训练态,BN 使用当前 batch 的均值方差、Dropout 随机屏蔽单元,导致不确定性和系统性偏差。

Debug 过程

  1. 在评估函数开头打印 model.training 与任意一层的模式
代码语言:javascript
代码运行次数:0
运行
复制
print("model.training:", model.training)
for m in model.modules():
    if isinstance(m, (nn.BatchNorm2d, nn.Dropout)):
        print(type(m).__name__, "training=", m.training); break
  1. 检查是否在某处意外调用了 model.train() 常见于外层循环或回调里无脑 set train,或者复用同一函数既训又评时忘了模式切换。
  2. 统计 BN 的 running_mean/var 是否在评估期间变化 若在评估 loop 前后数值不同,说明你的 BN 仍在更新。
  3. 确认评估关闭梯度 评估应使用 no_grad()/inference_mode(),避免无谓的 autograd 构图与显存消耗,也能降低出错概率。

代码修改

  1. 评估模式的标准模板
代码语言:javascript
代码运行次数:0
运行
复制
from contextlib import contextmanager
import torch

@contextmanager
def evaluating(model):
    was_training = model.training
    model.eval()
    try:
        with torch.inference_mode():
            yield
    finally:
        model.train(was_training)

# 用法
with evaluating(model):
    for x, y in val_loader:
        logits = model(x)
        ...
  1. 训练与评估解耦 不要在同一个函数里混用 train/eval 逻辑;分别写 train_one_epoch 与 evaluate,并明确模式切换。
  2. 精准 BN(可选) 若训练 batch 很小或使用强增广,可在每个 epoch 后用 PreciseBN 重估 running stats,但这仍需在 eval 逻辑之外、且显式控制:
代码语言:javascript
代码运行次数:0
运行
复制
@torch.no_grad()
def precise_bn(model, loader, num_batches=200):
    was_training = model.training
    model.train()
    # 清空统计
    for m in model.modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm):
            m.running_mean.zero_(); m.running_var.fill_(1); m.num_batches_tracked.zero_()
    it = iter(loader)
    for _ in range(num_batches):
        try: x, _ = next(it)
        except StopIteration: it = iter(loader); x, _ = next(it)
        model(x)
    model.train(was_training)
  1. 多卡场景 DDP 下评估仍应在 eval() + no_grad();若使用 SyncBatchNorm,eval 不再跨卡聚合统计,只读取 running stats,保持一致性。

代码修改

代码语言:javascript
代码运行次数:0
运行
复制
def train_one_epoch(model, loader, optimizer):
    model.train()
    for x, y in loader:
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        optimizer.zero_grad(set_to_none=True); loss.backward(); optimizer.step()

@torch.no_grad()
def evaluate(model, loader):
    was_training = model.training
    model.eval()
    correct = total = 0
    try:
        for x, y in loader:
            logits = model(x)
            pred = logits.argmax(1)
            correct += (pred == y).sum().item()
            total += y.numel()
    finally:
        model.train(was_training)
    return correct / total

Q & A

  • 使用 torch.no_grad 和 torch.inference_mode 的差别 inference_mode 还会跳过版本计数,进一步减少开销;评估期推荐 inference_mode。两者都不改变 train/eval 模式,只是关闭梯度与状态记录。
  • 验证集也要跑强增广吗 除非论文配置要求,一般关闭强增广,仅保留必要的标准化与几何对齐,避免评估不稳定。
  • 精度依赖 batch 大小正常吗 少量依赖正常(数值与 BN 统计),但不应出现大幅漂移;若漂移明显,优先排查 eval 模式与 BN 设置。
  • JIT/torch.compile 会影响 eval 吗 不会自动改模式,但可能缓存图。切换模式前后尽量避免频繁交替编译;将训练与评估分开运行更稳。

结语

评估阶段的一个小小的 eval() 遗漏,就足以把你一整周的曲线和结论带偏。把模式切换包成 evaluating 上下文

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 指标忽高忽低?评估时忘了切 eval,BatchNorm/Dropout 状态混乱导致的虚假回退(含可复现实验与修复模板)
    • ❓ Bug 现象
    • 📽️ 场景复现
    • Debug 过程
    • 代码修改
    • 代码修改
    • Q & A
    • 结语
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档