【pytorch】固定(freeze)住部分网络

您所在的位置:网站首页 freezing和frozen的词性 【pytorch】固定(freeze)住部分网络

【pytorch】固定(freeze)住部分网络

2023-06-27 10:02| 来源: 网络整理| 查看: 265

前言

最好、最高效、最简洁的,是 “方案一” 。

方案一 步骤一、固定基本网络

代码模板:

# 获取要固定部分的state_dict: pre_state_dict = torch.load(model_path, map_location=torch.device('cpu') # 导入之(记得strict=False): model.load_state_dict(pre_state_dict, strict=False) print('Load model from %s success.' % model_path) # 固定基本网络: model = freeze_model(model=model, to_freeze_dict=pre_state_dict) 其中 freeze_model 函数如下: def freeze_model(model, to_freeze_dict, keep_step=None): for (name, param) in model.named_parameters(): if name in to_freeze_dict: param.requires_grad = False else: pass # # 打印当前的固定情况(可忽略): # freezed_num, pass_num = 0, 0 # for (name, param) in model.named_parameters(): # if param.requires_grad == False: # freezed_num += 1 # else: # pass_num += 1 # print('\n Total {} params, miss {} \n'.format(freezed_num + pass_num, pass_num)) return model

Note:

如果预加载模型是在 model = nn.DataParallel(model) 模式下训练出来的分布式模型,那么每个参数名称会默认加上 .module 前缀。相应地,会导致无法对号导入单机模型。此时需要将如下语句: # 获取要固定部分的state_dict: pre_state_dict = torch.load(model_path, map_location=torch.device('cpu') 改为: # 获取要固定部分的state_dict: pre_state_dict = torch.load(model_path, map_location=torch.device('cpu') pre_state_dict = {k.replace('module.', ''): v for k, v in pre_state_dict.items()} 步骤二、让optimizer回避要freeze的参数

代码模板:

optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1, momentum=0.9, weight_decay=1e-4) 步骤三、train时通过.eval()来freeze

因为:即使对bn设置了 requires_grad = False ,一旦 model.train() ,bn还是会偷偷开启update( model.eval()模式下就又停止update )。(详见【pytorch】bn) 所以:train每个epoch之前都要统一重新定义一下这块,否则容易出问题。

model.eval() model.stage4_xx.train() model.pred_xx.train() 方案二

pytorch下进行freeze操作,一般需要经过以下四步。

步骤一、固定基本网络

代码模板:

# 获取要固定部分的state_dict: pre_state_dict = torch.load(model_path, map_location=torch.device('cpu') # 导入之(记得strict=False): model.load_state_dict(pre_state_dict, strict=False) print('Load model from %s success.' % model_path) # 固定基本网络: model = freeze_model(model=model, to_freeze_dict=pre_state_dict) 其中 freeze_model 函数如下: def freeze_model(model, to_freeze_dict, keep_step=None): for (name, param) in model.named_parameters(): if name in to_freeze_dict: param.requires_grad = False else: pass # # 打印当前的固定情况(可忽略): # freezed_num, pass_num = 0, 0 # for (name, param) in model.named_parameters(): # if param.requires_grad == False: # freezed_num += 1 # else: # pass_num += 1 # print('\n Total {} params, miss {} \n'.format(freezed_num + pass_num, pass_num)) return model 步骤二、让optimizer回避要freeze的参数

代码模板:

optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1, momentum=0.9, weight_decay=1e-4) 步骤三、固定bn

即使通过步骤一对bn设置了 requires_grad = False ,一旦 model.train() ,bn还是会偷偷开启update( model.eval()模式下就又停止update )。 所以还需要额外地深入固定bn:

固定 momentum :momentum=0.0掐灭 track_running_stats :track_running_stats=False

举例:

self.bn = nn.BatchNorm2d(out_channels, eps=1e-5)

修改为:

self.bn = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.0, track_running_stats=False)

但是 track_running_stats=False 会带来副作用:受波及的每个bn都会在state_dict中丢失三个对应的键值对(每组对应的key都为xx.xx.bn.running_mean、xx.xx.bn.running_var 和 xx.xx.bn.num_batches_tracked)

步骤四、正常训练

训练过程中,记得定时check一下被固定部分是否恒定不变:

比如每次eval的时候,顺便check一下被固定部分的预测精度。 步骤五、后处理 4.1 重启track_running_stats

举例:

self.bn = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.0, track_running_stats=False)

修改为:

self.bn = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.0)

此时,之前受波及的每个bn,都会在state_dict中恢复所丢失三个对应的键值对(但是value为空,待填充)。

Note:

线上训练虽然用freeze过的网络,但线下测试时,还是要老老实实换回未被freeze的网络。否则结果不仅会对不齐,被freeze和未被freeze的task都会表现更差! 4.2 复原缺失的value

为了克服 track_running_stats=False 带来的副作用,最终模型需要依赖 “原始state_dict” 和 “训好的state_dict” 合并。前者为后者补充缺失的value。

# 原始state_dict: origin_state_dict = torch.load(origin_model_path, map_location=torch.device('cpu')) # 训好的state_dict: new_state_dict = torch.load(new_model_path, map_location=torch.device('cpu')) # 后者从前者中补充缺失的键值对: final_dict = new_state_dict.copy() for (key, val) in origin_state_dict.items(): if key not in final_dict: final_dict[key] = val # 载入合并好的 state_dict,这时候一定是可以通过 strict=True 的: model.load_state_dict(final_dict, strict=True) 这时重新再save一遍model,就是可最终直接用的model文件了。


【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3