使用PyTorch冻结模型参数的方法

您所在的位置:网站首页 6246cpu参数 使用PyTorch冻结模型参数的方法

使用PyTorch冻结模型参数的方法

2023-09-16 13:01| 来源: 网络整理| 查看: 265

前言

在深度学习领域,经常需要使用其他人已训练好的模型进行改进或微调,这个时候我们会加载已有的预训练模型文件的参数,如果网络结构不变,希望使用新数据微调部分网络参数。这时我们则需要冻结部分参数,禁止其更新。

在这里插入图片描述

方法

(1)通过遍历网络结构,设置梯度更新requires_grad = False。

# 冻结network1的全部参数和network2的部分参数 for name, parameter in network1.named_parameters(): parameter.requires_grad = False for name, parameter in network2.named_parameters(): if 'key' in name: parameter.requires_grad = False

(2)优化器中过滤filter冻结的参数

optimizer_network2 = torch.optim.Adam(filter(lambda p: p.requires_grad, network2.parameters()), lr=0.005, betas=(0.5, 0.999)) 其他

结合加载模型部分参数的情况,优化器需要按如下设置:

optimizer_network2 = torch.optim.Adam([{'params': filter(lambda p: p.requires_grad, network2.parameters()), 'initial_lr': 0.0002}], lr=0.005, betas=(0.5, 0.999))

在这里插入图片描述

参考资料

[1] csdn - 使用PyTorch加载模型部分参数方法 [2] 知乎 - Pytorch自由载入部分模型参数并冻结



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3