Implementing an LSTM Prediction Model in PyTorch and Running Inference on the Exported ONNX Model in C++

2025/1/16 0:04:54

Implementing the LSTM (RNN) model in PyTorch

Code

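import torch
import torch.nn as nn
import onnx


class LSTM(nn.Module):
    def __init__(self, input_size, output_size, out_channels, num_layers, device):
        super(LSTM, self).__init__()
        self.device = device
        self.input_size = input_size
        # Note: the hidden size is set equal to the input size
        self.hidden_size = input_size
        self.num_layers = num_layers
        self.output_size = output_size
        self.lstm = nn.LSTM(input_size=self.input_size,
                            hidden_size=self.hidden_size,
                            num_layers=self.num_layers,
                            batch_first=True)
        # out_channels == 1 means "return only the last time step"
        self.out_channels = out_channels
        # Defined but never applied in forward(), so the model returns the raw
        # LSTM hidden state (hidden_size features), not the output of fc
        self.fc = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, x):
        # Zero initial states of shape (num_layers, batch, hidden_size); taking the batch
        # size from x.size(0) keeps that dimension dynamic in the exported graph
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(self.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(self.device)
        out, _ = self.lstm(x, (h0, c0))  # out: (batch, seq_len, hidden_size)
        if self.out_channels == 1:
            out = out[:, -1, :]  # last time step only: (batch, hidden_size)
        return out


batch_size = 20
seq_len = 10
input_size = 10
output_size = 10
num_layers = 2
out_channels = 1

model = LSTM(input_size, output_size, out_channels, num_layers, "cpu")
model.eval()

input_names = ["input"]
output_names = ["output"]

# batch_first=True, so the input layout is (batch, seq_len, input_size)
x = torch.randn((batch_size, seq_len, input_size))
print(x.shape)
y = model(x)
print(y.shape)

# Mark dimension 0 (batch) of input and output as dynamic. Named axes, e.g.
# {'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}, would avoid the
# "No names were found" warnings below.
torch.onnx.export(model, x, 'LSTM.onnx', verbose=True,
                  input_names=input_names, output_names=output_names,
                  dynamic_axes={'input': [0], 'output': [0]})

onnx_model = onnx.load("LSTM.onnx")
print("load model done.")
onnx.checker.check_model(onnx_model)
print(onnx.helper.printable_graph(onnx_model.graph))
print("check model done.")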
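To make the printed shapes easier to follow, here is a minimal shape-bookkeeping sketch in plain Python (no tensors; the helper name `lstm_shapes` is hypothetical). It assumes the `batch_first=True` convention used above, with `hidden_size` equal to `input_size` and `fc` unused, which is why `y` comes out as `[20, 10]`. The zero state `(2, 20, 10)` also matches the `strides=[200, 10, 1]` of the `/ConstantOfShape` node in the exported graph.

```python
def lstm_shapes(batch_size, seq_len, input_size, hidden_size, num_layers, last_step_only=True):
    """Hypothetical helper: tensor shapes flowing through the batch_first LSTM wrapper above."""
    x = (batch_size, seq_len, input_size)          # model input
    state = (num_layers, batch_size, hidden_size)  # h0 and c0: (layers, batch, hidden)
    out = (batch_size, seq_len, hidden_size)       # nn.LSTM output with batch_first=True
    # out[:, -1, :] keeps only the last time step when out_channels == 1
    y = (batch_size, hidden_size) if last_step_only else out
    return {"x": x, "h0": state, "c0": state, "out": out, "y": y}


shapes = lstm_shapes(batch_size=20, seq_len=10, input_size=10, hidden_size=10, num_layers=2)
for name, shape in shapes.items():
    print(name, shape)
```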

Output

torch.Size([20, 10, 10])
torch.Size([20, 10])
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/utils.py:2041: UserWarning: No names were found for specified dynamic axes of provided input.Automatically generated names will be applied to each dynamic axes of input input
  "No names were found for specified dynamic axes of provided input."
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/utils.py:2041: UserWarning: No names were found for specified dynamic axes of provided input.Automatically generated names will be applied to each dynamic axes of input output
  "No names were found for specified dynamic axes of provided input."
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/symbolic_opset9.py:4322: UserWarning: Exporting a model to ONNX with a batch_size other than 1, with a variable length with LSTM can cause an error when running the ONNX model with a different batch size. Make sure to save the model with a batch size of 1, or define the initial states (h0/c0) as inputs of the model. 
  + "or define the initial states (h0/c0) as inputs of the model. "
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/_internal/jit_utils.py:258: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/utils.py:688: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
  graph, params_dict, GLOBALS.export_onnx_opset_version
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/utils.py:1179: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
  graph, params_dict, GLOBALS.export_onnx_opset_version
Exported graph: graph(%input : Float(*, 10, 10, strides=[100, 10, 1], requires_grad=0, device=cpu),
      %onnx::LSTM_193 : Float(1, 40, 10, strides=[400, 10, 1], requires_grad=0, device=cpu),
      %onnx::LSTM_194 : Float(1, 40, 10, strides=[400, 10, 1], requires_grad=0, device=cpu),
      %onnx::LSTM_195 : Float(1, 80, strides=[80, 1], requires_grad=0, device=cpu),
      %onnx::LSTM_213 : Float(1, 40, 10, strides=[400, 10, 1], requires_grad=0, device=cpu),
      %onnx::LSTM_214 : Float(1, 40, 10, strides=[400, 10, 1], requires_grad=0, device=cpu),
      %onnx::LSTM_215 : Float(1, 80, strides=[80, 1], requires_grad=0, device=cpu)):
  %/Shape_output_0 : Long(3, strides=[1], device=cpu) = onnx::Shape[onnx_name="/Shape"](%input), scope: __main__.LSTM:: # /zengli/20230320/ao/test/test_onnx_lstm.py:23:0
  %/Constant_output_0 : Long(device=cpu) = onnx::Constant[value={0}, onnx_name="/Constant"](), scope: __main__.LSTM:: # /zengli/20230320/ao/test/test_onnx_lstm.py:23:0
  %/Gather_output_0 : Long(device=cpu) = onnx::Gather[axis=0, onnx_name="/Gather"](%/Shape_output_0, %/Constant_output_0), scope: __main__.LSTM:: # /zengli/20230320/ao/test/test_onnx_lstm.py:23:0
  %/Constant_1_output_0 : Long(1, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value={2}, onnx_name="/Constant_1"](), scope: __main__.LSTM::
  %onnx::Unsqueeze_18 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}]()
  %/Unsqueeze_output_0 : Long(1, strides=[1], device=cpu) = onnx::Unsqueeze[onnx_name="/Unsqueeze"](%/Gather_output_0, %onnx::Unsqueeze_18), scope: __main__.LSTM::
  %/Constant_2_output_0 : Long(1, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value={10}, onnx_name="/Constant_2"](), scope: __main__.LSTM::
  %/Concat_output_0 : Long(3, strides=[1], device=cpu) = onnx::Concat[axis=0, onnx_name="/Concat"](%/Constant_1_output_0, %/Unsqueeze_output_0, %/Constant_2_output_0), scope: __main__.LSTM:: # /zengli/20230320/ao/test/test_onnx_lstm.py:23:0
  %/ConstantOfShape_output_0 : Float(*, *, *, strides=[200, 10, 1], requires_grad=0, device=cpu) = onnx::ConstantOfShape[value={0}, onnx_name="/ConstantOfShape"](%/Concat_output_0), scope: __main__.LSTM:: # /zengli/20230320/ao/test/test_onnx_lstm.py:23:0
  %/Cast_output_0 : Float(*, *, *, strides=[200, 10, 1], requires_grad=0, device=cpu) = onnx::Cast[to=1, onnx_name="/Cast"](%/ConstantOfShape_output_0), scope: __main__.LSTM:: # /zengli/20230320/ao/test/test_onnx_lstm.py:23:0
  %/lstm/Transpose_output_0 : Float(10, *, 10, device=cpu) = onnx::Transpose[perm=[1, 0, 2], onnx_name="/lstm/Transpose"](%input), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %onnx::LSTM_26 : Tensor? = prim::Constant(), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/lstm/Constant"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_1_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/lstm/Constant_1"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_2_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/lstm/Constant_2"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Slice_output_0 : Float(*, *, *, device=cpu) = onnx::Slice[onnx_name="/lstm/Slice"](%/Cast_output_0, %/lstm/Constant_1_output_0, %/lstm/Constant_2_output_0, %/lstm/Constant_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_3_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/lstm/Constant_3"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_4_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/lstm/Constant_4"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_5_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/lstm/Constant_5"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Slice_1_output_0 : Float(*, *, *, device=cpu) = onnx::Slice[onnx_name="/lstm/Slice_1"](%/Cast_output_0, %/lstm/Constant_4_output_0, %/lstm/Constant_5_output_0, %/lstm/Constant_3_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/LSTM_output_0 : Float(10, 1, *, 10, device=cpu), %/lstm/LSTM_output_1 : Float(1, *, 10, device=cpu), %/lstm/LSTM_output_2 : Float(1, *, 10, device=cpu) = onnx::LSTM[hidden_size=10, onnx_name="/lstm/LSTM"](%/lstm/Transpose_output_0, %onnx::LSTM_193, %onnx::LSTM_194, %onnx::LSTM_195, %onnx::LSTM_26, %/lstm/Slice_output_0, %/lstm/Slice_1_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_6_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/lstm/Constant_6"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Squeeze_output_0 : Float(10, *, 10, device=cpu) = onnx::Squeeze[onnx_name="/lstm/Squeeze"](%/lstm/LSTM_output_0, %/lstm/Constant_6_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_7_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/lstm/Constant_7"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_8_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/lstm/Constant_8"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_9_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={2}, onnx_name="/lstm/Constant_9"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Slice_2_output_0 : Float(*, *, *, device=cpu) = onnx::Slice[onnx_name="/lstm/Slice_2"](%/Cast_output_0, %/lstm/Constant_8_output_0, %/lstm/Constant_9_output_0, %/lstm/Constant_7_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_10_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/lstm/Constant_10"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_11_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/lstm/Constant_11"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_12_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={2}, onnx_name="/lstm/Constant_12"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Slice_3_output_0 : Float(*, *, *, device=cpu) = onnx::Slice[onnx_name="/lstm/Slice_3"](%/Cast_output_0, %/lstm/Constant_11_output_0, %/lstm/Constant_12_output_0, %/lstm/Constant_10_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/LSTM_1_output_0 : Float(10, 1, *, 10, device=cpu), %/lstm/LSTM_1_output_1 : Float(1, *, 10, device=cpu), %/lstm/LSTM_1_output_2 : Float(1, *, 10, device=cpu) = onnx::LSTM[hidden_size=10, onnx_name="/lstm/LSTM_1"](%/lstm/Squeeze_output_0, %onnx::LSTM_213, %onnx::LSTM_214, %onnx::LSTM_215, %onnx::LSTM_26, %/lstm/Slice_2_output_0, %/lstm/Slice_3_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_13_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/lstm/Constant_13"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Squeeze_1_output_0 : Float(10, *, 10, device=cpu) = onnx::Squeeze[onnx_name="/lstm/Squeeze_1"](%/lstm/LSTM_1_output_0, %/lstm/Constant_13_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Transpose_1_output_0 : Float(*, 10, 10, strides=[10, 200, 1], requires_grad=1, device=cpu) = onnx::Transpose[perm=[1, 0, 2], onnx_name="/lstm/Transpose_1"](%/lstm/Squeeze_1_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/Constant_3_output_0 : Long(device=cpu) = onnx::Constant[value={-1}, onnx_name="/Constant_3"](), scope: __main__.LSTM::
  %output : Float(*, 10, strides=[10, 1], requires_grad=1, device=cpu) = onnx::Gather[axis=1, onnx_name="/Gather_1"](%/lstm/Transpose_1_output_0, %/Constant_3_output_0), scope: __main__.LSTM:: # /zengli/20230320/ao/test/test_onnx_lstm.py:29:0
  return (%output)

load model done.
graph torch_jit (
  %input[FLOAT, input_dynamic_axes_1x10x10]
) initializers (
  %onnx::LSTM_193[FLOAT, 1x40x10]
  %onnx::LSTM_194[FLOAT, 1x40x10]
  %onnx::LSTM_195[FLOAT, 1x80]
  %onnx::LSTM_213[FLOAT, 1x40x10]
  %onnx::LSTM_214[FLOAT, 1x40x10]
  %onnx::LSTM_215[FLOAT, 1x80]
) {
  %/Shape_output_0 = Shape(%input)
  %/Constant_output_0 = Constant[value = <Scalar Tensor []>]()
  %/Gather_output_0 = Gather[axis = 0](%/Shape_output_0, %/Constant_output_0)
  %/Constant_1_output_0 = Constant[value = <Tensor>]()
  %onnx::Unsqueeze_18 = Constant[value = <Tensor>]()
  %/Unsqueeze_output_0 = Unsqueeze(%/Gather_output_0, %onnx::Unsqueeze_18)
  %/Constant_2_output_0 = Constant[value = <Tensor>]()
  %/Concat_output_0 = Concat[axis = 0](%/Constant_1_output_0, %/Unsqueeze_output_0, %/Constant_2_output_0)
  %/ConstantOfShape_output_0 = ConstantOfShape[value = <Tensor>](%/Concat_output_0)
  %/Cast_output_0 = Cast[to = 1](%/ConstantOfShape_output_0)
  %/lstm/Transpose_output_0 = Transpose[perm = [1, 0, 2]](%input)
  %/lstm/Constant_output_0 = Constant[value = <Tensor>]()
  %/lstm/Constant_1_output_0 = Constant[value = <Tensor>]()
  %/lstm/Constant_2_output_0 = Constant[value = <Tensor>]()
  %/lstm/Slice_output_0 = Slice(%/Cast_output_0, %/lstm/Constant_1_output_0, %/lstm/Constant_2_output_0, %/lstm/Constant_output_0)
  %/lstm/Constant_3_output_0 = Constant[value = <Tensor>]()
  %/lstm/Constant_4_output_0 = Constant[value = <Tensor>]()
  %/lstm/Constant_5_output_0 = Constant[value = <Tensor>]()
  %/lstm/Slice_1_output_0 = Slice(%/Cast_output_0, %/lstm/Constant_4_output_0, %/lstm/Constant_5_output_0, %/lstm/Constant_3_output_0)
  %/lstm/LSTM_output_0, %/lstm/LSTM_output_1, %/lstm/LSTM_output_2 = LSTM[hidden_size = 10](%/lstm/Transpose_output_0, %onnx::LSTM_193, %onnx::LSTM_194, %onnx::LSTM_195, %, %/lstm/Slice_output_0, %/lstm/Slice_1_output_0)
  %/lstm/Constant_6_output_0 = Constant[value = <Tensor>]()
  %/lstm/Squeeze_output_0 = Squeeze(%/lstm/LSTM_output_0, %/lstm/Constant_6_output_0)
  %/lstm/Constant_7_output_0 = Constant[value = <Tensor>]()
  %/lstm/Constant_8_output_0 = Constant[value = <Tensor>]()
  %/lstm/Constant_9_output_0 = Constant[value = <Tensor>]()
  %/lstm/Slice_2_output_0 = Slice(%/Cast_output_0, %/lstm/Constant_8_output_0, %/lstm/Constant_9_output_0, %/lstm/Constant_7_output_0)
  %/lstm/Constant_10_output_0 = Constant[value = <Tensor>]()
  %/lstm/Constant_11_output_0 = Constant[value = <Tensor>]()
  %/lstm/Constant_12_output_0 = Constant[value = <Tensor>]()
  %/lstm/Slice_3_output_0 = Slice(%/Cast_output_0, %/lstm/Constant_11_output_0, %/lstm/Constant_12_output_0, %/lstm/Constant_10_output_0)
  %/lstm/LSTM_1_output_0, %/lstm/LSTM_1_output_1, %/lstm/LSTM_1_output_2 = LSTM[hidden_size = 10](%/lstm/Squeeze_output_0, %onnx::LSTM_213, %onnx::LSTM_214, %onnx::LSTM_215, %, %/lstm/Slice_2_output_0, %/lstm/Slice_3_output_0)
  %/lstm/Constant_13_output_0 = Constant[value = <Tensor>]()
  %/lstm/Squeeze_1_output_0 = Squeeze(%/lstm/LSTM_1_output_0, %/lstm/Constant_13_output_0)
  %/lstm/Transpose_1_output_0 = Transpose[perm = [1, 0, 2]](%/lstm/Squeeze_1_output_0)
  %/Constant_3_output_0 = Constant[value = <Scalar Tensor []>]()
  %output = Gather[axis = 1](%/lstm/Transpose_1_output_0, %/Constant_3_output_0)
  return %output
}
check model done.
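The initializer shapes in the graph follow from the ONNX `LSTM` operator specification: `W` is `[num_directions, 4*hidden_size, input_size]`, `R` is `[num_directions, 4*hidden_size, hidden_size]`, and `B` is `[num_directions, 8*hidden_size]` (input and recurrence biases concatenated). The exporter emitted one `LSTM` node per PyTorch layer, so `LSTM_193/194/195` belong to layer 0 and `LSTM_213/214/215` to layer 1; the empty fifth input (`%`) is the optional `sequence_lens`, and the output `Y` has shape `[seq_length, num_directions, batch_size, hidden_size]`, hence the `Squeeze` on axis 1. A minimal sketch reproducing those shapes (the helper `onnx_lstm_weight_shapes` is hypothetical and covers only the unidirectional case):

```python
def onnx_lstm_weight_shapes(input_size, hidden_size, num_layers):
    """Hypothetical helper: (W, R, B) shapes of each unidirectional ONNX LSTM node."""
    num_directions = 1
    shapes = []
    for layer in range(num_layers):
        # Layer 0 reads the model input; deeper layers read the previous layer's hidden state
        layer_input = input_size if layer == 0 else hidden_size
        w = (num_directions, 4 * hidden_size, layer_input)  # input weights for the 4 gates
        r = (num_directions, 4 * hidden_size, hidden_size)  # recurrence weights for the 4 gates
        b = (num_directions, 8 * hidden_size)               # input and recurrence biases
        shapes.append((w, r, b))
    return shapes


for layer, (w, r, b) in enumerate(onnx_lstm_weight_shapes(input_size=10, hidden_size=10, num_layers=2)):
    print(f"layer {layer}: W={w} R={r} B={b}")
```

Because `hidden_size` equals `input_size` here, both layers end up with identical `1x40x10` / `1x80` initializers; with different sizes only layer 0's `W` would differ.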

Calling the ONNX model from C++

Implementation

#include <onnxruntime_cxx_api.h>

#include <array>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<float> testOnnxLSTM(const std::vector<std::vector<std::vector<float>>>& inputs)
{
    // Switch to ORT_LOGGING_LEVEL_VERBOSE to see in the console whether the CPU or the GPU runs the graph
    //Ort::Env env(ORT_LOGGING_LEVEL_VERBOSE, "test");
    Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "Default");
    Ort::SessionOptions session_options;
    session_options.SetIntraOpNumThreads(5);  // run ops on five threads to speed things up
    // The second argument is the GPU device_id (0); keep this line commented out to run on the CPU
    //OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0);
    session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);

    auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);

#ifdef _WIN32
    // On Windows, Ort::Session takes a wide-character model path
    const wchar_t* model_path = L"C:\\Users\\xxx\\Desktop\\LSTM.onnx";
    std::wcout << model_path << std::endl;
#else
    // Elsewhere the path is a narrow string; replace this Windows path with a local one
    const char* model_path = "C:\\Users\\xxx\\Desktop\\LSTM.onnx";
    std::cout << model_path << std::endl;
#endif
    Ort::Session session(env, model_path, session_options);

    const char* input_names[] = { "input" };
    const char* output_names[] = { "output" };

    const int input_size = 10;
    const int output_size = 10;  // equals hidden_size, since fc is not applied in the model
    const int batch_size = 1;
    const int seq_len = 10;

    std::array<float, batch_size * seq_len * input_size> input_matrix;
    std::array<float, batch_size * output_size> output_matrix{};  // zero-initialized
    std::array<int64_t, 3> input_shape{ batch_size, seq_len, input_size };
    std::array<int64_t, 2> output_shape{ batch_size, output_size };

    // Copy the nested input into a contiguous row-major buffer of shape (batch, seq_len, input_size);
    // inputs must be at least batch_size x seq_len x input_size
    for (int i = 0; i < batch_size; i++)
        for (int j = 0; j < seq_len; j++)
            for (int k = 0; k < input_size; k++)
                input_matrix[i * seq_len * input_size + j * input_size + k] = inputs[i][j][k];

    // The tensors wrap the arrays' memory without copying
    Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_matrix.data(), input_matrix.size(), input_shape.data(), input_shape.size());
    try
    {
        Ort::Value output_tensor = Ort::Value::CreateTensor<float>(memory_info, output_matrix.data(), output_matrix.size(), output_shape.data(), output_shape.size());
        // The result is written directly into output_matrix
        session.Run(Ort::RunOptions{ nullptr }, input_names, &input_tensor, 1, output_names, &output_tensor, 1);
    }
    catch (const std::exception& e)
    {
        std::cout << e.what() << std::endl;
    }

    std::cout << "get data from LSTM onnx: \n";
    std::vector<float> ret;
    for (int i = 0; i < output_size; i++) {
        ret.emplace_back(output_matrix[i]);
        std::cout << ret[i] << "\t";
    }
    std::cout << "\n";
    return ret;
}
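The triple loop in `testOnnxLSTM` writes element `[i][j][k]` to offset `i * seq_len * input_size + j * input_size + k`, i.e. a contiguous row-major (C-order) buffer, which is the layout `CreateTensor` is handed together with the shape `{batch_size, seq_len, input_size}`. The C++ function assumes `inputs` is at least 1 x 10 x 10 and would read out of bounds otherwise. A minimal Python sketch of the same index arithmetic with that shape check added (the helper `flatten_row_major` is hypothetical, not part of ONNX Runtime):

```python
def flatten_row_major(inputs, batch_size, seq_len, input_size):
    """Hypothetical helper: flatten [batch][seq_len][input_size] the way testOnnxLSTM does."""
    # Reject inputs smaller than the declared shape instead of indexing out of range
    ok = len(inputs) >= batch_size and all(
        len(seq) >= seq_len and all(len(step) >= input_size for step in seq[:seq_len])
        for seq in inputs[:batch_size]
    )
    if not ok:
        raise ValueError("inputs is smaller than (batch_size, seq_len, input_size)")
    flat = [0.0] * (batch_size * seq_len * input_size)
    for i in range(batch_size):
        for j in range(seq_len):
            for k in range(input_size):
                flat[i * seq_len * input_size + j * input_size + k] = inputs[i][j][k]
    return flat


# Same test data as the calling code in this post: value k * j / 20 at row j, column k
data = [[[1.0 * k * j / 20 for k in range(10)] for j in range(10)]]
flat = flatten_row_major(data, batch_size=1, seq_len=10, input_size=10)
print(len(flat), flat[10:13])  # 100 [0.0, 0.05, 0.1]
```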

Calling code

    std::vector<std::vector<std::vector<float>>> data;
    // One test sample of shape (1, 10, 10): value k * j / 20 at row j, column k
    for (int i = 0; i < 1; i++) {
        std::vector<std::vector<float>> t1;
        for (int j = 0; j < 10; j++) {
            std::vector<float> t2;
            for (int k = 0; k < 10; k++) {
                t2.push_back(1.0 * k * j / 20);
            }
            t1.push_back(t2);
        }
        data.push_back(t1);
    }
    // Print the input matrix
    for (auto& i : data) {
        for (auto& j : i) {
            for (auto& k : j) {
                std::cout << k << "\t";
            }
            std::cout << "\n";
        }
        std::cout << "\n";
    }
    auto ret = testOnnxLSTM(data);

Test results

0       0       0       0       0       0       0       0       0       0
0       0.05    0.1     0.15    0.2     0.25    0.3     0.35    0.4     0.45
0       0.1     0.2     0.3     0.4     0.5     0.6     0.7     0.8     0.9
0       0.15    0.3     0.45    0.6     0.75    0.9     1.05    1.2     1.35
0       0.2     0.4     0.6     0.8     1       1.2     1.4     1.6     1.8
0       0.25    0.5     0.75    1       1.25    1.5     1.75    2       2.25
0       0.3     0.6     0.9     1.2     1.5     1.8     2.1     2.4     2.7
0       0.35    0.7     1.05    1.4     1.75    2.1     2.45    2.8     3.15
0       0.4     0.8     1.2     1.6     2       2.4     2.8     3.2     3.6
0       0.45    0.9     1.35    1.8     2.25    2.7     3.15    3.6     4.05

C:\Users\xxx\Desktop\LSTM.onnx
get data from LSTM onnx:
0.000401703 0.00102207 0.0011015 -0.000503412 -0.000911839 -0.0011367 -0.000309185 0.000591398 -0.000362981 -4.81475e-05

