C++: Using LibTorch to Load a Saved PyTorch Model (.pt) and Run Forward Inference (Prediction)


2024-07-11 06:36 | Source: compiled from the web

I. Generating the .pt model with PyTorch (Python)

Python script

# Convert the .pth weights (state_dict) into a TorchScript .pt file
import os

import torch
from PIL import Image
from torchvision import transforms

from model import AlexNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # create model
    model = AlexNet(num_classes=5).to(device)

    # load and preprocess a test image
    img_path = r'D:\项目\人脸识别\deep-learning-for-image-processing-master\data_set\flower_data\rose.jpg'
    image = Image.open(img_path).convert('RGB')
    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    img = data_transform(image)
    img = img.unsqueeze(dim=0)  # add the batch dimension: [1, 3, 224, 224]
    print(img.shape)

    # load model weights
    weights_path = "./AlexNet.pth"
    assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)

    testsize = 224
    if torch.cuda.is_available():
        modelState = torch.load(weights_path, map_location='cuda')
        # strict=False ignores missing/unexpected keys, so a mismatched checkpoint will not raise
        model.load_state_dict(modelState, strict=False)
        model = model.cuda()
        model = model.eval()
        # An example input you would normally provide to your model's forward() method.
        example = torch.rand(1, 3, testsize, testsize)
        example = example.cuda()
        # trace the model with the dummy input to produce a TorchScript module
        traced_script_module = torch.jit.trace(model, example)
        output = traced_script_module(img.cuda())
        print(output.shape)
        pred = torch.argmax(output, dim=1)
        print(pred)
        traced_script_module.save('model_cuda.pt')
    else:
        modelState = torch.load(weights_path, map_location='cpu')
        model.load_state_dict(modelState, strict=False)
        # switch to eval mode before tracing so Dropout is disabled in the traced graph
        model = model.eval()
        example = torch.rand(1, 3, testsize, testsize)
        example = example.cpu()
        traced_script_module = torch.jit.trace(model, example)
        output = traced_script_module(img.cpu())
        print(output.shape)
        pred = torch.argmax(output, dim=1)
        print(pred)
        traced_script_module.save('AlexNet.pt')


if __name__ == '__main__':
    main()

II. Loading the .pt model in C++

1. Key code

Include the headers

#include <torch/script.h>  // provides torch::jit::load
#include "torch/torch.h"

Bring the Module type into scope

using torch::jit::script::Module;

Create a LibTorch module object

std::shared_ptr<Module> module = std::make_shared<Module>();  // smart pointer (<memory>): no manual delete needed

Load the .pt model

*module = torch::jit::load("D:/项目/人脸识别/deep-learning-for-image-processing-master/pytorch_classification/Test2_alexnet/AlexNet.pt");

Run the forward pass

at::Tensor output = module->forward({ tensor_image }).toTensor();  // tensor_image: preprocessed input of shape [1, 3, 224, 224]

Complete code:

#include <iostream>
#include <string>

#include <torch/script.h>
#include "torch/torch.h"

#include "opencv2/core.hpp"
#include "opencv2/imgproc.hpp"
#include "opencv2/highgui.hpp"
#include "opencv2/imgcodecs.hpp"

std::string classList[5] = { "daisy", "dandelion", "rose", "sunflower", "tulip" };
std::string image_path = "D:/项目/人脸识别/deep-learning-for-image-processing-master/data_set/flower_data/tulip.jpg";

int main(int argc, const char* argv[])
{
    // Deserialize the ScriptModule from a file using torch::jit::load().
    // LibTorch releases before 1.2 returned a std::shared_ptr instead:
    //std::shared_ptr<torch::jit::script::Module> module = torch::jit::load("../../model_resnet_jit.pt");
    using torch::jit::script::Module;
    Module module = torch::jit::load("D:/项目/人脸识别/deep-learning-for-image-processing-master/pytorch_classification/Test2_alexnet/AlexNet.pt");
    std::cout

