Scenario: while training a classification model, the training loss keeps going down, but validation accuracy swings up and down; the same model fluctuates by 5 to 15 percentage points between two evaluations. The root cause turns out to be that the evaluation phase never properly switched the model to eval mode, or it was switched to eval and then flipped back to train by some other module, so BatchNorm normalized with the current batch's statistics and Dropout kept randomly zeroing activations, badly distorting the metrics.
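The mechanism is easy to see in isolation. Below is a minimal, standalone sketch (not part of the reproduction script that follows): push the same batch through a BatchNorm2d + Dropout pair twice in train mode and twice in eval mode, and compare the outputs.

import torch, torch.nn as nn
torch.manual_seed(0)

layer = nn.Sequential(nn.BatchNorm2d(8), nn.Dropout(p=0.5))
x = torch.randn(4, 8, 14, 14)

layer.train()
out_train_1 = layer(x)
out_train_2 = layer(x)   # differs: Dropout re-samples its mask, BN normalizes with batch stats

layer.eval()
out_eval_1 = layer(x)
out_eval_2 = layer(x)    # identical: Dropout is a no-op, BN uses its running stats

print("train mode repeatable:", torch.allclose(out_train_1, out_train_2))  # usually False
print("eval  mode repeatable:", torch.allclose(out_eval_1, out_eval_2))    # True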
Save the script below as eval_mode_bug.py; it runs directly on CPU.
# eval_mode_bug.py
import torch, torch.nn as nn, torch.nn.functional as F

torch.manual_seed(0)

class TinyCNN(nn.Module):
    def __init__(self, n_class=5):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1),
            nn.BatchNorm2d(16), nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.BatchNorm2d(32), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.fc = nn.Linear(32, n_class)

    def forward(self, x):
        h = self.net(x).flatten(1)
        return self.fc(h)
def make_data(n=2048, n_class=5):
    # Random inputs and random labels are enough here: the point is to
    # exercise BatchNorm/Dropout behavior, not to learn a real task.
    X = torch.randn(n, 1, 28, 28)
    y = torch.randint(0, n_class, (n,))
    return X, y
def evaluate_wrong(model, X, y, bs=128):
    # Anti-pattern: model.eval() is never called and gradients are not disabled.
    # BatchNorm and Dropout stay in training behavior, and BN running stats
    # are silently updated as a side effect of "evaluation".
    model.train()  # the trap, left in on purpose
    accs = []
    for i in range(0, len(X), bs):
        logits = model(X[i:i+bs])
        pred = logits.argmax(1)
        accs.append((pred == y[i:i+bs]).float().mean().item())
    return sum(accs) / len(accs)
@torch.no_grad()
def evaluate_right(model, X, y, bs=128):
    was_training = model.training
    model.eval()
    try:
        accs = []
        for i in range(0, len(X), bs):
            logits = model(X[i:i+bs])
            pred = logits.argmax(1)
            accs.append((pred == y[i:i+bs]).float().mean().item())
        return sum(accs) / len(accs)
    finally:
        model.train(was_training)  # restore whatever mode the caller was in
def main():
    X, y = make_data()
    model = TinyCNN()
    opt = torch.optim.AdamW(model.parameters(), lr=3e-3)

    # Train for a few steps so the BN layers accumulate some running statistics.
    model.train()
    for step in range(50):
        idx = torch.randint(0, len(X), (128,))
        logits = model(X[idx])
        loss = F.cross_entropy(logits, y[idx])
        opt.zero_grad(set_to_none=True); loss.backward(); opt.step()

    # Wrong evaluation: two consecutive runs give noticeably different numbers
    # (Dropout and BatchNorm are still in train mode).
    acc1 = evaluate_wrong(model, X, y)
    acc2 = evaluate_wrong(model, X, y)
    print(f"[WRONG] acc1={acc1:.3f} acc2={acc2:.3f} (should match, but fluctuates)")

    # Correct evaluation: repeated runs are identical, and usually higher and more stable.
    acc3 = evaluate_right(model, X, y)
    acc4 = evaluate_right(model, X, y)
    print(f"[RIGHT] acc1={acc3:.3f} acc2={acc4:.3f} (stable and repeatable)")

if __name__ == "__main__":
    main()
Running the script, you will see two different behaviors: the two [WRONG] accuracies disagree with each other (every forward pass re-samples the Dropout mask and normalizes with the current batch's statistics), while the two [RIGHT] accuracies are identical. When you hit this in a real project, first confirm which mode the model is actually in at evaluation time:
print("model.training:", model.training)
for m in model.modules():
if isinstance(m, (nn.BatchNorm2d, nn.Dropout)):
print(type(m).__name__, "training=", m.training); break
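The scenario above also mentions the sneakier variant: you did call eval(), but some other piece of code (an EMA update, a callback, a submodule calling train() on itself) flips the model back before the forward pass. A cheap defensive check, sketched here with a made-up helper name assert_eval_mode, is to assert the flags right before evaluating:

def assert_eval_mode(model):
    # Fail fast if any submodule is still in training mode during evaluation.
    offenders = [name for name, m in model.named_modules() if m.training]
    assert not offenders, f"modules still in train mode: {offenders[:5]}"

model.eval()
assert_eval_mode(model)  # raises immediately instead of silently skewing the metrics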
Better still, make the correct switch the default: wrap evaluation in a context manager that saves and restores model.training and turns off autograd in one place.

from contextlib import contextmanager
import torch

@contextmanager
def evaluating(model):
    # Temporarily put the model into eval mode and disable autograd,
    # restoring the previous mode even if an exception is raised.
    was_training = model.training
    model.eval()
    try:
        with torch.inference_mode():
            yield
    finally:
        model.train(was_training)

# usage
with evaluating(model):
    for x, y in val_loader:
        logits = model(x)
        ...
A related trap: even with eval() in place, the BN running statistics themselves may be stale or noisy (for example, accumulated while the weights were still changing fast). In that case you can recompute them over a few hundred training batches before evaluating, the "precise BN" trick:

@torch.no_grad()
def precise_bn(model, loader, num_batches=200):
    # Recompute the BatchNorm running statistics with the current (final) weights.
    was_training = model.training
    model.train()
    bn_layers = [m for m in model.modules()
                 if isinstance(m, nn.modules.batchnorm._BatchNorm)]
    saved_momentum = [m.momentum for m in bn_layers]
    for m in bn_layers:
        # Reset the statistics; momentum=None makes BN accumulate a true
        # cumulative average over the batches seen below, not an EMA.
        m.running_mean.zero_(); m.running_var.fill_(1); m.num_batches_tracked.zero_()
        m.momentum = None
    it = iter(loader)
    for _ in range(num_batches):
        try:
            x, _ = next(it)
        except StopIteration:
            it = iter(loader); x, _ = next(it)
        model(x)  # forward only; BN stats update as a side effect
    for m, mom in zip(bn_layers, saved_momentum):
        m.momentum = mom
    model.train(was_training)
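Usage is a one-liner before evaluation; train_loader below is an assumed DataLoader over the training set, and the prints are just a sanity check that the statistics actually moved:

# assumed: train_loader is a DataLoader over the training set
bn0 = next(m for m in model.modules() if isinstance(m, nn.BatchNorm2d))
print("running_mean before:", bn0.running_mean[:3])
precise_bn(model, train_loader, num_batches=200)
print("running_mean after: ", bn0.running_mean[:3])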
Finally, the pattern worth standardizing on in a real training script: every phase sets its own mode at the top and never assumes the caller left the model in the right state.

def train_one_epoch(model, loader, optimizer):
    model.train()
    for x, y in loader:
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        optimizer.zero_grad(set_to_none=True); loss.backward(); optimizer.step()

@torch.no_grad()
def evaluate(model, loader):
    was_training = model.training
    model.eval()
    correct = total = 0
    try:
        for x, y in loader:
            logits = model(x)
            pred = logits.argmax(1)
            correct += (pred == y).sum().item()
            total += y.numel()
    finally:
        model.train(was_training)
    return correct / total
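Wiring it together, a minimal sketch of the outer loop; train_loader, val_loader, optimizer and the epoch count are assumptions of this example:

for epoch in range(10):
    train_one_epoch(model, train_loader, optimizer)  # sets train() itself
    acc = evaluate(model, val_loader)                # sets eval() and restores the previous mode
    print(f"epoch {epoch}: val_acc={acc:.3f}")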
A single missing eval() at evaluation time is enough to skew a whole week of curves and conclusions. Wrap the mode switch in an evaluating context manager, or save and restore model.training inside every evaluation function, and this class of bug disappears for good.