Pytorch

您所在的位置:网站首页 训练好的模型参数有哪些 Pytorch

Pytorch

2024-07-03 20:57| 来源: 网络整理| 查看: 265

文章目录 1.前言2.torch.save(保存模型)3.torch.load整个网络4.torch.load网络参数(只提取参数)5.调用三个函数

1.前言

训练好了一个模型, 我们当然想要保存它, 留到下次要用的时候直接提取直接用,下面我将来讲如何存储训练好的模型参数

2.torch.save(保存模型)

首先,先搭建一个神经网络

import torch from torch import nn import matplotlib.pyplot as plt torch.manual_seed(11) # 使每次得到的随机数是固定的。但是如果不加上torch.manual_seed这个函数调用的话,打印出来的随机数每次都不一样 x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # [100] -> [100,1] y = x.pow(2) + 0.5*torch.rand(x.size()) # y的形状与x一样 def make_and_save_model(): network = torch.nn.Sequential( torch.nn.Linear(1, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1) ) optimizer = torch.optim.SGD(network.parameters(), lr=0.3) #优化器 criterion = torch.nn.MSELoss() #损失函数 # 训练 for i in range(200): prediction = network(x) #数据放入模型后得到预测值 loss = criterion(prediction, y) #计算预测值与真实值之间的误差 optimizer.zero_grad() #清空梯度 loss.backward() #误差反向传播 optimizer.step() #更新参数 torch.save(network, 'network.pth') # 保存整个网络 torch.save(network.state_dict(), 'network_params.pth') # 只保存网络中的参数 plt.figure(1, figsize = (10,3)) plt.subplot(131) plt.title('network') plt.scatter(x.data.numpy(), y.data.numpy()) plt.plot(x.data.numpy(), prediction.data.numpy(), 'yo' , lw = 5) plt.pause(1) 3.torch.load整个网络

这种方式将会提取整个神经网络, 网络大的时候可能会比较慢.

def load_whole_model(): network_whole = torch.load('network.pth') prediction = network_whole(x) plt.figure(1, figsize = (10,3)) plt.subplot(132) plt.title('network_whole') plt.scatter(x.data.numpy(), y.data.numpy()) plt.plot(x.data.numpy(), prediction.data.numpy(), 'yo' , lw = 5) plt.pause(1) 4.torch.load网络参数(只提取参数)

这种方式将会提取所有的参数, 然后再放到你的新建网络中

def load_only_params(): network_params = torch.nn.Sequential( torch.nn.Linear(1, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1) ) network_params.load_state_dict(torch.load('network_params.pth')) prediction = network_params(x) plt.figure(1, figsize = (10,3)) plt.subplot(133) plt.title('network_params') plt.scatter(x.data.numpy(), y.data.numpy()) plt.plot(x.data.numpy(), prediction.data.numpy(), 'yo' , lw = 5) 5.调用三个函数

会看到加载后的模型画出的图是一样的,说明模型的参数正确加载了。

make_and_save_model() load_whole_model() load_only_params()

在这里插入图片描述



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3