在框架升级过程中,经常会出现老版本模型无法调用的问题,其中一个重要的报错经常是:
module.norm1.norm_func.running_mean” and “module.norm1.norm_func.running_var”
for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved
before 0.4.0, this may be expected because InstanceNorm2d does not track running stats
by default since 0.4.0. Please remove these keys from state_dict. If the running stats
are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable
them. See the documentation of InstanceNorm2d for details.
从上面可以看出,模型加载的时候,提醒了老版本的问题。
为了解决这一个问题,可以进行模型中将某些模型进行删除。如下所示:
model_dict = torch.load(args.test_weight_path)
model_dict_clone = model_dict.copy()
for key, value in model_dict_clone.items():
if key.endswith(('running_mean', 'running_var')):
del model_dict[key]
Gnet.load_state_dict(model_dict,False)
而再仔细观察这个问题,发现本质上是一个函数InstanceNorm2d 的关系,因此可以找到该函数,进行修订使其可以支持老版本,即不会出现该问题,解决办法如下:即将track_running_stats=True这个配置新增进去,即不会报错!
norm_layer = functools.partial(
nn.InstanceNorm2d, affine=False, track_running_stats=True)
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。