PyTorch中的模型保存:一键保存、两种选择/保存整个模型和保存模型参数

您所在的位置:网站首页 保存和另保存的区别 PyTorch中的模型保存:一键保存、两种选择/保存整个模型和保存模型参数

PyTorch中的模型保存:一键保存、两种选择/保存整个模型和保存模型参数

2024-06-16 21:05| 来源: 网络整理| 查看: 265

探索PyTorch中的模型保存:一键保存、两种选择 目录 一键保存整个模型:保留全貌只保存模型参数:轻装上阵转换的奇妙之处 保存整个模型转换为保存模型参数保存模型参数转换为保存整个模型 结语 一键保存整个模型:保留全貌

当我们使用一键保存功能时,PyTorch会把整个模型连同它的结构和参数一起保存下来。这意味着我们可以完整地保存模型的状态,随时随地加载它并开始预测或继续训练。

python import torch import torchvision.models as models # 创建模型并保存整个模型 model = models.resnet18(pretrained=True) torch.save(model, 'whole_model.pth') # 加载整个模型 loaded_model = torch.load('whole_model.pth')

通过这种方式,我们一举保存了模型的全貌,文件通常以.pth或.pt为后缀。

只保存模型参数:轻装上阵

与保存整个模型相比,有时我们只需要保存模型的参数而不是结构。这种方式会生成更小的文件,更适用于共享参数或迁移学习等场景。

import torch import torchvision.models as models # 创建模型并加载预训练参数 model = models.resnet18() model.load_state_dict(torch.load('model_params.pth')) # 保存模型参数 torch.save(model.state_dict(), 'model_params.pth')

通过这种方式,我们轻装上阵,只携带了模型的参数而不是整个结构。

转换的奇妙之处

有时候,我们需要在保存整个模型和保存模型参数之间自由转换。这时候,我们只需一行代码就可以实现。

保存整个模型转换为保存模型参数: import torch # 加载整个模型 loaded_model = torch.load('whole_model.pth') # 保存模型参数 torch.save(loaded_model.state_dict(), 'model_params.pth') 保存整个模型转换为保存模型参数: import torch import torchvision.models as models # 创建模型并加载模型参数 model = models.resnet18() model.load_state_dict(torch.load('model_params.pth')) # 保存整个模型 torch.save(model, 'whole_model_from_params.pth')

通过这种转换的方式,我们可以随心所欲地在保存模型整体结构和仅保存参数之间切换,让模型保存变得更加灵活便捷。



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3