使用Flask将pytorch模型部署在本地服务器

您所在的位置:网站首页 flask部署pytorch模型 使用Flask将pytorch模型部署在本地服务器

使用Flask将pytorch模型部署在本地服务器

#使用Flask将pytorch模型部署在本地服务器| 来源: 网络整理| 查看: 265

整个项目的思路

训练模型

使用pytorch对resnet18 进行迁移学习,实现对自己的数据进行图像分类。需要将最后一个全连接层中的输出节点数目修改,因为我的数据集中包含有5中图像,所以这里的输出节点数目修改成5 DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # 使用GPU # 加载预训练的resnet18模型 net = torchvision.models.resnet18(pretrained=True) # 冻结原网络参数,仅训练最后新替换的全连接层 for param in net.parameters(): param.requires_grad = False num_ftrs = net.fc.in_features # 原网络最后一层的输入维度 net.fc = nn.Linear(num_ftrs, 5) # 替换新的连接层,输出改为5,预测5个类别 net = net.to(DEVICE) 然后resnet18模型训练好之后,保存训练过程中准确率最高的模型。 # 保存模型参数net.state_dict() torch.save(net.state_dict(), 'net_dict.pt') # 保存完整模型 torch.save(net, 'net.pt') 然后可以随便找一张image,用保存的训练好的模型进行预测 Flask—python服务器 Flask和Django都是web框架。可以将模型发布在服务器(这里使用的是本地服务器)。在对应的URL中实现对模型的调用。 app = flask.Flask(__name__) ... ... @app.route("/predict", methods=["POST"]) def predict(): ... ...

出现下图,说明flask服务开启成功 在这里插入图片描述 向浏览器中输入该网址,然后可以在终端向服务器,以POST方式传过去待识别的图像。并接收从服务器传过来的识别结果。这里的终端暂时使用的是anaconda虚拟环境中的python.exe,来执行.py文件中的代码。后续将Android作为终端。 遇到的问题及解决方法 错误一: 在启动flask服务程序的下段代码中,

preds = F.softmax(model(image), dim=1) results = torch.topk(preds.cpu().data, k=3, dim=1) # Loop over the results and add them to the list of returned predictions for prob, label in zip(results[0][0], results[1][0]): print(label) # tensor(162) label_name = idx2label[label] r = {"label": label_name, "probability": float(prob)} data['predictions'].append(r)

这个地方报错: KeyError: tensor(162) 在这里插入图片描述 错误原因: label_name = idx2label[label] idx2label是一个字典{key0: value0, key1: value1, key2: value2…}, 比如{0: ‘cardboard’, 1: ‘glass’, 2: ‘metal’, 3: ‘paper’, 4: ‘plastic’}。可是输出label,发现label并不是一个数,而是tensor。 所以需要将tensor转换为数值。 修改:将label_name = idx2label[label] ---->label_name = idx2label[int(label)] 错误二: 在用anaconda虚拟环境中的python.exe执行simple_request.py文件时,动态对函数参数赋值传入待预测图像的文件路径时,报错 image = open(image_path, 'rb').read() OSError: [Errno 22] Invalid argument: "'e:/PROJECT/PycharmProjects/pt/test_images/spoon.jpg'" 在这里插入图片描述 修改 将>python E:/PROJECT/PycharmProjects/pt/simple_request.py --file='e:/PROJECT/PycharmProjects/pt/test_images/spoon.jpg中的文件路径改为>python E:/PROJECT/PycharmProjects/pt/simple_request.py --file=e:/PROJECT/PycharmProjects/pt/test_images/spoon.jpg 即去掉图片路径中的单引号,光是这个小小的错误让我头疼了一整天。。 完整代码有时间会传到github上的。。



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3