使用Flask将pytorch模型部署在本地服务器 |
您所在的位置:网站首页 › flask部署pytorch模型 › 使用Flask将pytorch模型部署在本地服务器 |
整个项目的思路 训练模型 使用pytorch对resnet18 进行迁移学习,实现对自己的数据进行图像分类。需要将最后一个全连接层中的输出节点数目修改,因为我的数据集中包含有5中图像,所以这里的输出节点数目修改成5 DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # 使用GPU # 加载预训练的resnet18模型 net = torchvision.models.resnet18(pretrained=True) # 冻结原网络参数,仅训练最后新替换的全连接层 for param in net.parameters(): param.requires_grad = False num_ftrs = net.fc.in_features # 原网络最后一层的输入维度 net.fc = nn.Linear(num_ftrs, 5) # 替换新的连接层,输出改为5,预测5个类别 net = net.to(DEVICE) 然后resnet18模型训练好之后,保存训练过程中准确率最高的模型。 # 保存模型参数net.state_dict() torch.save(net.state_dict(), 'net_dict.pt') # 保存完整模型 torch.save(net, 'net.pt') 然后可以随便找一张image,用保存的训练好的模型进行预测 Flask—python服务器 Flask和Django都是web框架。可以将模型发布在服务器(这里使用的是本地服务器)。在对应的URL中实现对模型的调用。 app = flask.Flask(__name__) ... ... @app.route("/predict", methods=["POST"]) def predict(): ... ...出现下图,说明flask服务开启成功 这个地方报错: KeyError: tensor(162) |
今日新闻 |
推荐新闻 |
CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3 |