使用PyTorch冻结模型参数的方法 |
您所在的位置:网站首页 › 6246cpu参数 › 使用PyTorch冻结模型参数的方法 |
前言
在深度学习领域,经常需要使用其他人已训练好的模型进行改进或微调,这个时候我们会加载已有的预训练模型文件的参数,如果网络结构不变,希望使用新数据微调部分网络参数。这时我们则需要冻结部分参数,禁止其更新。 (1)通过遍历网络结构,设置梯度更新requires_grad = False。 # 冻结network1的全部参数和network2的部分参数 for name, parameter in network1.named_parameters(): parameter.requires_grad = False for name, parameter in network2.named_parameters(): if 'key' in name: parameter.requires_grad = False(2)优化器中过滤filter冻结的参数 optimizer_network2 = torch.optim.Adam(filter(lambda p: p.requires_grad, network2.parameters()), lr=0.005, betas=(0.5, 0.999)) 其他结合加载模型部分参数的情况,优化器需要按如下设置: optimizer_network2 = torch.optim.Adam([{'params': filter(lambda p: p.requires_grad, network2.parameters()), 'initial_lr': 0.0002}], lr=0.005, betas=(0.5, 0.999))[1] csdn - 使用PyTorch加载模型部分参数方法 [2] 知乎 - Pytorch自由载入部分模型参数并冻结 |
CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3 |