libtorch的c++,加载*.pth

news/2025/2/14 1:51:21/

一、转换模型为TorchScript

前提:python只保存了参数,没存结构

要在C++中使用libtorch(PyTorch的C++接口),读取和加载通过torch.save保存的模型(    torch.save(pdn.state_dict()这种方式,只保存了参数,没存结构),需要转换模型为TorchScript。在python下实现。

def get_pdn_small(out_channels=384, padding=False):pad_mult = 1 if padding else 0return nn.Sequential(nn.Conv2d(in_channels=3, out_channels=128, kernel_size=4,padding=3 * pad_mult),nn.ReLU(inplace=True),nn.AvgPool2d(kernel_size=2, stride=2, padding=1 * pad_mult),nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4,padding=3 * pad_mult),nn.ReLU(inplace=True),nn.AvgPool2d(kernel_size=2, stride=2, padding=1 * pad_mult),nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3,padding=1 * pad_mult),nn.ReLU(inplace=True),nn.Conv2d(in_channels=256, out_channels=out_channels, kernel_size=4))def get_pdn_medium(out_channels=384, padding=False):pad_mult = 1 if padding else 0return nn.Sequential(nn.Conv2d(in_channels=3, out_channels=256, kernel_size=4,padding=3 * pad_mult),nn.ReLU(inplace=True),nn.AvgPool2d(kernel_size=2, stride=2, padding=1 * pad_mult),nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4,padding=3 * pad_mult),nn.ReLU(inplace=True),nn.AvgPool2d(kernel_size=2, stride=2, padding=1 * pad_mult),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3,padding=1 * pad_mult),nn.ReLU(inplace=True),nn.Conv2d(in_channels=512, out_channels=out_channels, kernel_size=4),nn.ReLU(inplace=True),nn.Conv2d(in_channels=out_channels, out_channels=out_channels,kernel_size=1))
import torch# 假设你有一个已训练的模型
model = get_pdn_small()# 加载模型的state_dict
model.load_state_dict(torch.load('teacher_small.pth'))
model.eval()  # 设置模型为评估模式# 将模型转化为TorchScript
scripted_model = torch.jit.script(model)
scripted_model.save('teacher_small.pt')

二、在C++中加载TorchScript模型

在C++中,你可以使用torch::jit::load来加载.pt文件,如下所示:

#include <torch/script.h>  // One-stop header for loading TorchScript models
#include <iostream>
#include <memory>int main() {// 加载TorchScript模型try {// 加载模型std::shared_ptr<torch::jit::Module> model = std::make_shared<torch::jit::Module>(torch::jit::load("teacher_small.pt"));std::cout << "Model loaded successfully!" << std::endl;// 你可以在这里使用模型进行推理,比如输入一个张量// 例如,如果输入是一个3x224x224的图像,你需要创建一个相应的Tensortorch::Tensor input = torch::randn({1, 3, 224, 224});  // 示例输入std::vector<torch::jit::IValue> inputs;inputs.push_back(input);// 执行模型推理at::Tensor output = model->forward(inputs).toTensor();std::cout << "Output tensor: " << output << std::endl;}catch (const c10::Error& e) {std::cerr << "Error loading the model: " << e.what() << std::endl;return -1;}
}


http://www.ppmy.cn/news/1570651.html

相关文章

音频知识基础

音频知识基础 声音属性声音度量人耳特性通道数音频数字化传输接口 声音属性 响度 响度是人耳对声音强弱的主观感受&#xff1b; 主要和声波的振幅相关&#xff0c;同时也和频率有一定关系&#xff1b; 音调 音调是人耳对声音高低的主观感受&#xff1b; 主要与频率相关&#…

物联网实训室解决方案(2025年最新版)

一、专业定位与人才培养体系 &#xff08;一&#xff09;专业战略定位 本专业聚焦物联网产业链关键环节&#xff0c;致力于培养适应未来智能时代需求的复合型技术人才。我们的培养目标是帮助学生掌握物联网全产业链核心技能&#xff0c;包括智能感知、网络通信、数据处理、系…

部署自动化的重要性之骑士资本案例研读

骑士资本&#xff08;Knight Capital&#xff09;是一家证券交易所的金融服务公司&#xff0c;也是美国市场上最大的交易商之一。其在纽约证券交易所的市场份额为 17.3%&#xff0c;在纳斯达克的市场份额为 16.9%。 该公司有一项零售流动性计划&#xff0c;打算用新的 RLP 代码…

鸿蒙harmony 手势密码

1.效果图 2.设置手势页面代码 /*** 手势密码设置页面*/ Entry Component struct SettingGesturePage {/*** PatternLock组件控制器*/private patternLockController: PatternLockController new PatternLockController()/*** 用来保存提示文本信息*/State message: string …

Academy Sports + Outdoors EDI:体育零售巨头的供应链“中枢神经”

Academy Sports Outdoors 是美国领先的体育用品及户外装备零售商&#xff0c;拥有250线下门店及电商平台&#xff0c;年营收超60亿美元。作为全渠道零售商&#xff0c;其供应链面临独特挑战&#xff1a; 海量SKU管理&#xff1a;超50万SKU&#xff08;从健身器材到露营装备&a…

如何设置爬虫的IP代理?

在爬虫开发中&#xff0c;设置IP代理是避免被目标网站封禁、提升爬取效率和保护隐私的重要手段。以下是设置爬虫IP代理的详细方法和注意事项&#xff1a; 一、获取代理IP 免费代理IP&#xff1a; 可以通过一些免费的代理IP网站获取代理IP&#xff0c;但这些IP的稳定性和速度通…

c/c++蓝桥杯经典编程题100道(9)数组排序

数组排序 ->返回c/c蓝桥杯经典编程题100道-目录 目录 数组排序 一、题型解释 二、例题问题描述 三、C语言实现 解法1&#xff1a;冒泡排序&#xff08;难度★&#xff09; 解法2&#xff1a;选择排序&#xff08;难度★&#xff09; 解法3&#xff1a;快速排序&#…

AtCoder Beginner Contest 392(A-G)题解

A-B&#xff1a;略 C&#xff1a;可能题意比较绕&#xff0c;第i个答案就是穿着i这个号码&#xff08;也就是Q[j] i,这个时候j这个位置&#xff09;&#xff0c;看向的那个人的号码&#xff08;也就是P[j]) 代码&#xff1a; void solve() {int n;cin >> n;vi p(n 1…