The best, most efficient, and most concise approach is "Option 1", whose template follows.
Code template:
# Get the state_dict of the part to be frozen:
pre_state_dict = torch.load(model_path, map_location=torch.device('cpu'))
# Load it into the model (remember strict=False):
model.load_state_dict(pre_state_dict, strict=False)
print('Load model from %s success.' % model_path)
# Freeze the base network:
model = freeze_model(model=model, to_freeze_dict=pre_state_dict)
The freeze_model function is as follows:
def freeze_model(model, to_freeze_dict, keep_step=None):
    for (name, param) in model.named_parameters():
        # Freeze every parameter whose name appears in the pretrained state_dict:
        if name in to_freeze_dict:
            param.requires_grad = False
        else:
            pass
    # # Print the current freeze status (optional):
    # freezed_num, pass_num = 0, 0
    # for (name, param) in model.named_parameters():
    #     if param.requires_grad == False:
    #         freezed_num += 1
    #     else:
    #         pass_num += 1
    # print('\n Total {} params, miss {} \n'.format(freezed_num + pass_num, pass_num))
    return model
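As an optional sanity check (not part of the original template): when strict=False, load_state_dict returns the lists of missing and unexpected keys, which you can print to confirm what actually matched:
# load_state_dict returns a named tuple with .missing_keys / .unexpected_keys:
result = model.load_state_dict(pre_state_dict, strict=False)
print('missing keys (stay trainable, randomly initialized):', result.missing_keys)
print('unexpected keys (ignored):', result.unexpected_keys)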
Note: if the checkpoint was saved from a model wrapped in nn.DataParallel / DistributedDataParallel, every key carries a 'module.' prefix. In that case, change:
# Get the state_dict of the part to be frozen:
pre_state_dict = torch.load(model_path, map_location=torch.device('cpu'))
to:
# Get the state_dict of the part to be frozen, stripping the 'module.' prefix:
pre_state_dict = torch.load(model_path, map_location=torch.device('cpu'))
pre_state_dict = {k.replace('module.', ''): v for k, v in pre_state_dict.items()}
Next, hand the optimizer only the parameters that still require gradients. Code template:
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1, momentum=0.9, weight_decay=1e-4)
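To double-check that the filter really excludes the frozen weights, you can count tensors (a minimal sketch, not in the original):
total = sum(1 for _ in model.parameters())
trainable = sum(1 for p in model.parameters() if p.requires_grad)
print('passing %d of %d parameter tensors to the optimizer' % (trainable, total))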
One more caveat concerns BN: even after requires_grad = False is set on BN, as soon as model.train() is called, BN still silently resumes updating its running statistics (under model.eval() the updates stop again). (See the separate 【pytorch】bn note for details.)
So: this block has to be re-applied uniformly before every training epoch, otherwise problems creep in easily:
model.eval()             # put the whole model in eval mode so every BN stops updating
model.stage4_xx.train()  # then switch only the sub-modules actually being trained
model.pred_xx.train()    # back to train mode (stage4_xx / pred_xx are placeholder names)
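A minimal sketch of where this sits in the epoch loop (num_epochs, train_loader, and criterion are assumed placeholders, not from the original; stage4_xx / pred_xx are the placeholder sub-module names from above):
for epoch in range(num_epochs):
    # Re-apply the BN freeze at the start of every epoch:
    model.eval()
    model.stage4_xx.train()
    model.pred_xx.train()
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()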
Freezing in PyTorch generally takes the following four steps.
代码模板:
# 获取要固定部分的state_dict:
pre_state_dict = torch.load(model_path, map_location=torch.device('cpu')
# 导入之(记得strict=False):
model.load_state_dict(pre_state_dict, strict=False)
print('Load model from %s success.' % model_path)
# 固定基本网络:
model = freeze_model(model=model, to_freeze_dict=pre_state_dict)
其中 freeze_model 函数如下:
def freeze_model(model, to_freeze_dict, keep_step=None):
for (name, param) in model.named_parameters():
if name in to_freeze_dict:
param.requires_grad = False
else:
pass
# # 打印当前的固定情况(可忽略):
# freezed_num, pass_num = 0, 0
# for (name, param) in model.named_parameters():
# if param.requires_grad == False:
# freezed_num += 1
# else:
# pass_num += 1
# print('\n Total {} params, miss {} \n'.format(freezed_num + pass_num, pass_num))
return model
Step 2: hand the optimizer only the parameters that still require gradients, using the same filter(lambda p: p.requires_grad, model.parameters()) template as above.
Step 3: freeze BN more deeply. (See the separate bn note.) Even though step 1 set requires_grad = False on BN, as soon as model.train() is called, BN still silently updates its running statistics (and stops again under model.eval()).
So BN needs an additional, deeper freeze:
For example, change:
self.bn = nn.BatchNorm2d(out_channels, eps=1e-5)
to:
self.bn = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.0, track_running_stats=False)
But track_running_stats=False has a side effect: every affected BN loses three key-value pairs from the state_dict (the corresponding keys are xx.xx.bn.running_mean, xx.xx.bn.running_var, and xx.xx.bn.num_batches_tracked).
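To see exactly which keys were lost, you can diff the key sets of the pretrained checkpoint and the modified model (a minimal sketch, assuming pre_state_dict from step 1 is still in scope):
lost = set(pre_state_dict.keys()) - set(model.state_dict().keys())
for key in sorted(lost):
    # expect only *.running_mean, *.running_var, *.num_batches_tracked entries:
    print('lost from state_dict:', key)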
During training, remember to periodically check that the frozen part really does stay constant.
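One way to implement that check (a minimal sketch; the snapshot-and-compare logic is my assumption, not from the original):
# Right after freezing, snapshot every frozen tensor:
frozen_snapshot = {name: param.detach().clone()
                   for name, param in model.named_parameters()
                   if not param.requires_grad}
# Then, e.g. at the end of each epoch, verify nothing drifted:
for name, param in model.named_parameters():
    if name in frozen_snapshot:
        assert torch.equal(param, frozen_snapshot[name]), '%s changed!' % name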
Step 4: once training is finished, restore the original BN definition. For example, change:
self.bn = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.0, track_running_stats=False)
back to:
self.bn = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.0)
At this point, every previously affected BN regains the three lost key-value pairs in its state_dict (but their values are only freshly initialized defaults, still waiting to be filled in).
Note:
To overcome the side effect of track_running_stats=False, the final model has to be assembled by merging the "original state_dict" with the "trained state_dict": the former supplies the values that are missing from the latter.
# The original state_dict:
origin_state_dict = torch.load(origin_model_path, map_location=torch.device('cpu'))
# The trained state_dict:
new_state_dict = torch.load(new_model_path, map_location=torch.device('cpu'))
# Fill the latter's missing key-value pairs from the former:
final_dict = new_state_dict.copy()
for (key, val) in origin_state_dict.items():
    if key not in final_dict:
        final_dict[key] = val
# Load the merged state_dict -- at this point strict=True is guaranteed to pass:
model.load_state_dict(final_dict, strict=True)
Now just save the model again, and you have the final model file that can be used directly.
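A minimal sketch of that final save (final_model_path is a placeholder name):
# Persist the merged weights; this file can now be loaded with strict=True directly:
torch.save(model.state_dict(), final_model_path)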