Qt配置Libtorch并简单测试

news/2024/11/29 2:51:14/

文章目录

  • 软件版本
  • 一、下载Libtorch
  • 二、配置Qt Creator
  • 三、测试项目
  • 参考:

纯小白初次学习深度学习,根据目前所学、所查资料总结该配置文章,如有更好的方法将深度学习模型通过C++部署到实际工程中,希望学习交流。

软件版本

  • Python3.10.9
  • Pytorch2.0.0
  • Qt Creator 5.0.1 (Community)
  • Qt 1.15.2(MSVC 2019, 64bit)

一、下载Libtorch

  • 下载位置(可以按照参考的方法来下载):Pytorch

二、配置Qt Creator

  • Qt Creator配置:
    • 创建.pri文件;
    • 内部写入如下内容:
INCLUDEPATH += C:\soft_install\libtorch-win-shared-with-deps-2.0.0+cpu\libtorch\include
INCLUDEPATH += C:\soft_install\libtorch-win-shared-with-deps-2.0.0+cpu\libtorch\include\torch\csrc\api\includeLIBS += -LC:\soft_install\libtorch-win-shared-with-deps-2.0.0+cpu\libtorch\lib\-lasmjit\-lc10\-lclog\-lcpuinfo\-ldnnl\-lfbgemm\-lfbjni\-lkineto\-llibprotobuf\-llibprotobuf-lite\-llibprotoc\-lpthreadpool\-lpytorch_jni\-ltorch\-ltorch_cpu\-lXNNPACK

三、测试项目

  • 模型来源
    • B站:我是土堆;课程:PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】
  • 创建项目
    • .pro文件中引入Libtorch和Opencv的.pri
DISTFILES += \pri/opencv.pri \pri/libtorch.priinclude(./pri/opencv.pri)
include(./pri/libtorch.pri)
  • c++调用Libtorch接口时头文件

    Libtorch的头文件应该放在cpp或.h中最头的位置(忘记在哪里看到的了);否则需要添加

#undef slots
#include <torch/script.h>
#include <ATen/ATen.h>
#include <torch/torch.h>
#define slots Q_SLOTS
  • .pth 转 .pt
import torch
import torchvision.transforms
from PIL import Image
from torch import nntest = TestSeq()
model = torch.load("../StuPytorch/testSeq_9.pth", map_location=torch.device("cpu"))# 设置一个示例输入
example = torch.rand(1, 3, 32, 32)
# 使用 torch.jit.trace 生成 torch.jit.ScriptModule
model.eval()
traced_script_module = torch.jit.trace(model, example)
# 保存转换后的模型
traced_script_module.save("../StuPytorch/testSeq_9.pt")
  • 简单实现代码(见参考)
    • Pytorch实现
import torch
import torchvision.transforms
from PIL import Image
from torch import nn
from model import *img_path = "../imgs/dog.jpg"
img = Image.open(img_path)
img = img.convert("RGB")
print(img)
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),torchvision.transforms.ToTensor()])
image = transform(img)
print(image.shape)model = torch.load("../StuPytorch/testSeq_29.pth", map_location=torch.device("cpu"))
print(model)image = torch.reshape(image, (1, 3, 32, 32))
print(image)
model.eval()
with torch.no_grad():output = model(image)class_id = {0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer',5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}
print(output)
print(output.argmax(1))
print(class_id[output.argmax(1).item()])---输出
TestSeq((model1): Sequential((0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))(1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))(3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(4): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))(5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(6): Flatten(start_dim=1, end_dim=-1)(7): Linear(in_features=1024, out_features=64, bias=True)(8): Linear(in_features=64, out_features=10, bias=True))
)
tensor([[-1.0712, -2.4485,  1.6878,  0.9057,  1.1824,  1.8202,  0.0850,  0.5854,-2.2806, -2.1651]])
tensor([5])
dog
- C++实现
cv::Mat img = cv::imread("D:\\projects\\PyTorchStud\\imgs\\dog.png");
cv::cvtColor(img, img, cv::COLOR_BGR2RGB);
cv::Mat resizedImg;
cv::resize(img, resizedImg, cv::Size(32, 32));//#停止autograd模块的工作,以起到加速和节省显存的作用
torch::NoGradGuard no_grad;
torch::jit::script::Module module;
module = torch::jit::load("D:\\projects\\PyTorchStud\\StuPytorch\\testSeq_29.pt");
std::cout << "Succeed in loading model" << std::endl;//标签名
std::vector<std::string> labels;
labels.push_back("airplane");
labels.push_back("automobile");
labels.push_back("bird");
labels.push_back("cat");
labels.push_back("deer");
labels.push_back("dog");
labels.push_back("frog");
labels.push_back("horse");
labels.push_back("ship");
labels.push_back("truck");torch::Tensor inputTensor = torch::from_blob(resizedImg.data, {1, 32, 32, 3}, torch::kByte);
//将张量的参数顺序转化为torch输入的格式[B,C,H,W]
inputTensor = inputTensor.permute({0, 3, 1, 2});
inputTensor = inputTensor.toType(torch::kFloat);
clock_t start_time = clock();module.eval();
torch::Tensor outTensor = module.forward({inputTensor}).toTensor();
auto prediction = outTensor.argmax(1);
std::cout << "prediction:" << prediction << std::endl;auto results = outTensor.sort(-1, true);
auto softmaxs = std::get<0>(results)[0].softmax(0);
auto indexs = std::get<1>(results)[0];//输出运行时间
clock_t end_time = clock();
std::cout << "Running time is: "<< static_cast<double>(end_time - start_time) / CLOCKS_PER_SEC * 1000 << "ms"<< std::endl;for (int i = 0; i < 1; ++i) {auto idx = indexs[i].item<int>();std::cout << "    ============= Top-" << i + 1 << " =============" << std::endl;std::cout << "    idx:  " << idx << std::endl;std::cout << "    Label:  " << labels[idx] << std::endl;std::cout << "    With Probability:  " << softmaxs[i].item<float>() * 100.0f << "%"<< std::endl;
}---输出结果
Succeed in loading model
prediction: 5
[ CPULongType{1} ]
Running time is: 118ms============= Top-1 =============idx:  5Label:  dogWith Probability:  100%
  • 结果说明
    • 由于训练的模型问题,有的图片分类错误。

参考:

Libtorch:pytorch分类和语义分割模型在C++工程上的应用


http://www.ppmy.cn/news/39217.html

相关文章

推箱子小游戏

文章目录一、 介绍二、 制作墙壁、地面三、 制作箱子四、 制作终点五、 制作人物移动六、 推箱子关键触发机制七、 终点设置八、 关卡切换设置九、 协程十、 下载一、 介绍 2D推箱子游戏是一种益智类游戏&#xff0c;玩家需要控制角色将箱子推到指定的位置&#xff0c;以完成关…

【WebGIS实例】(7)MapboxGL绘制不同颜色的Symbol图标

前言 在上一篇实例博客中&#xff08;MapboxGL绘制简易气泡图&#xff09;我们绘制了一个简易的单色气泡图&#xff0c;现在需求升级了。我们需要为气泡加载不同的颜色。 而要实现这个效果&#xff0c;其实相当简单&#xff0c;直接利用Mapbox提供的SDF渲染方法。 官网教程参考…

Spring Web MVC 知识点汇总(2)—官方原版

一、异步请求 Spring MVC与Servlet异步请求 处理 有广泛的集成&#xff1a; controller 方法中的 DeferredResult 和 Callable 返回值为单个异步返回值提供了基本支持。controller 可以 流转&#xff08;stream&#xff09; 多个数值&#xff0c;包括 SSE 和 原始数据。contr…

【内网安全】横向移动非约束委派约束委派资源约束委派数据库攻防

文章目录章节点redteam.red 靶场委派攻击分类&#xff1a;关于约束委派与非约束委派横向移动-原理利用-约束委派&非约束委派非约束委派复现配置如何利用&#xff1f;klist purge 与 mimikatz sekurlsa::tickets purge 的区别约束委派(不需要与与域控建立连接)复现配置判断查…

子串判断问题

目录 子串判断 程序设计 程序分析 子串判断 【问题描述】设s、t 为两个字符串,两个字符串分为两行输出,判断t 是否为s 的子串。如果是,输出子串所在位置(第一个字符,字符串的起始位置从0开始),否则输出-1 【输入形式】两行字符串,第一行字符串是s;第二行是字符串t …

闲来无事,写个脚本爬一下快递信息

多线程爬取&#xff1a;可以使用Python中的多线程或异步IO技术来加速爬取速度&#xff0c;提高效率。自动识别快递公司&#xff1a;可以通过输入的快递单号自动识别快递公司&#xff0c;然后根据不同公司的网站结构来爬取相应的信息。数据存储&#xff1a;可以将爬取的数据存储…

2021蓝桥杯真题大写 C语言/C++

题目描述 给定一个只包含大写字母和小写字母的字符串&#xff0c;请将其中所有的小写字母转换成大写字母后将字符串输出。 输入描述 输入一行包含一个字符串。 输出描述 输出转换成大写后的字符串。 输入输出样例 示例 输入 LanQiao 输出 LANQIAO 评测用例规模与约定 对于…

如何自动填充creatTime和updateTime两种字段

1.mysql自带功能 首先是较为常见的&#xff0c;在mysql数据库里设置&#xff0c;但是我的mysql版本不支持该方法&#xff0c;如果尝试了后报错了请直接看方法二 sql语句预览 createTime timestamp not null default CURRENT_TIMESTAMP comment "创建时间", upd…