Canmv k230 C++案例1.2——image classify项目 C++代码分析(待完成)

ops/2024/10/21 0:58:41/

这部分为初学,所以手头最好有本工具书便于查阅

01 代码初步注释

// 这里是一些定义配置
// 时间的标准库
#include <chrono>
// 写入或读取文件的标准库
#include <fstream>
// 文件输入输出的标准库,流模型
#include <iostream>
// k230的头文件,用于AI模型推断
#include <nncase/runtime/interpreter.h>
#include <nncase/runtime/runtime_op_utility.h>
// opencv开关  及  预处理过程的开关
#define USE_OPENCV 1
#define preprocess 1// opencv 文件
#if USE_OPENCV
// 显示、编码、处理?待查手册验证
#include <opencv2/highgui.hpp>
#include <opencv2/imgcodecs.hpp>
#include <opencv2/imgproc.hpp>
#endif// nncase的命名空间
using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::runtime::detail;// 数据的输入尺寸定义 224*224*3
#define INTPUT_HEIGHT 224
#define INTPUT_WIDTH 224
#define INTPUT_CHANNELS 3

02

// 定理的类模版
template <class T>
// 读取二进制文件  输入参数为文件名
std::vector<T> read_binary_file(const std::string &file_name)
{// 打开数据流  构筑一个ifstream并打开给定文件 以二进制方式进行IOstd::ifstream ifs(file_name, std::ios::binary);// 从文件末尾开始搜寻ifs.seekg(0, ifs.end);// 获取矢量数据的长度size_t len = ifs.tellg();// 定义矢量,依赖与机器std::vector<T> vec(len / sizeof(T), 0);// 从文件开始的地方进行搜寻ifs.seekg(0, ifs.beg);// 强制类型转换,非常危险,但只有这个vector是正确的ifs.read(reinterpret_cast<char *>(vec.data()), len);// 关闭文件ifs.close();return vec;
}
// 读入文件,打开文件,判断文件长度,写入数据
void read_binary_file(const char *file_name, char *buffer)
{std::ifstream ifs(file_name, std::ios::binary);ifs.seekg(0, ifs.end);// sizeof 返回类型size_t len = ifs.tellg();ifs.seekg(0, ifs.beg);ifs.read(buffer, len);ifs.close();
}static std::vector<std::string> read_txt_file(const char *file_name)
{// 定义字符串vector变量vec,不含任何元素std::vector<std::string> vec;// 分类至少能容纳n个元素的内容空间vec.reserve(1024);// 打开file_name的文件名变量fpstd::ifstream fp(file_name);// 定义字符串变量labelstd::string label;// 从fp中读取一行赋给label,返回fp// 每次读如一整行,直至到达文件末尾while (getline(fp, label)){// 矢量后面增加数据vec.push_back(label);}return vec;
}

softmax函数,函数公式

template<typename T>
static int softmax(const T* src, T* dst, int length)
{// 算法函数 寻找数组的最大值 const赋值const T alpha = *std::max_element(src, src + length);// 分母T denominator{ 0 };for (int i = 0; i < length; ++i) {dst[i] = std::exp(src[i] - alpha);denominator += dst[i];}// 指数输出for (int i = 0; i < length; ++i) {dst[i] /= denominator;}return 0;
}

数据格式转换

// 调用OPENCV的函数转换
#if USE_OPENCV
// hwc转chw 转换以适应tensorflow?
std::vector<uint8_t> hwc2chw(cv::Mat &img)
{std::vector<uint8_t> vec;std::vector<cv::Mat> rgbChannels(3);cv::split(img, rgbChannels);for (auto i = 0; i < rgbChannels.size(); i++){std::vector<uint8_t> data = std::vector<uint8_t>(rgbChannels[i].reshape(1, 1));vec.insert(vec.end(), data.begin(), data.end());}return vec;
}
#endif

模型推断代码

// 可以看到推断需要执行文件和三个参数模型、图片、标签
static int inference(const char *kmodel_file, const char *image_file, const char *label_file)
{// load kmodelinterpreter interp;// 模型也被保存为二进制文件?但格式未知std::ifstream ifs(kmodel_file, std::ios::binary);// 判断是否载入正常interp.load_model(ifs).expect("load_model failed");// create input tensor 创建输入?auto input_desc = interp.input_desc(0);// create input tensor 创建输入尺寸?auto input_shape = interp.input_shape(0);auto input_tensor = host_runtime_tensor::create(input_desc.datatype, input_shape, hrt::pool_shared).expect("cannot create input tensor");interp.input_tensor(0, input_tensor).expect("cannot set input tensor");// create output tensor// auto output_desc = interp.output_desc(0);// auto output_shape = interp.output_shape(0);// auto output_tensor = host_runtime_tensor::create(output_desc.datatype, output_shape, hrt::pool_shared).expect("cannot create output tensor");// interp.output_tensor(0, output_tensor).expect("cannot set output tensor");// set input dataauto dst = input_tensor.impl()->to_host().unwrap()->buffer().as_host().unwrap().map(map_access_::map_write).unwrap().buffer();
#if USE_OPENCVcv::Mat img = cv::imread(image_file);cv::resize(img, img, cv::Size(INTPUT_WIDTH, INTPUT_HEIGHT), cv::INTER_NEAREST);auto input_vec = hwc2chw(img);memcpy(reinterpret_cast<char *>(dst.data()), input_vec.data(), input_vec.size());
#elseread_binary_file(image_file, reinterpret_cast<char *>(dst.data()));
#endifhrt::sync(input_tensor, sync_op_t::sync_write_back, true).expect("sync write_back failed");// runsize_t counter = 1;auto start = std::chrono::steady_clock::now();for (size_t c = 0; c < counter; c++){interp.run().expect("error occurred in running model");}auto stop = std::chrono::steady_clock::now();double duration = std::chrono::duration<double, std::milli>(stop - start).count();std::cout << "interp.run() took: " << duration / counter << " ms" << std::endl;// get output dataauto output_tensor = interp.output_tensor(0).expect("cannot set output tensor");dst = output_tensor.impl()->to_host().unwrap()->buffer().as_host().unwrap().map(map_access_::map_read).unwrap().buffer();float *output_data = reinterpret_cast<float *>(dst.data());auto out_shape = interp.output_shape(0);auto size = compute_size(out_shape);// postprogress softmax by cpustd::vector<float> softmax_vec(size, 0);auto buf = softmax_vec.data();softmax(output_data, buf, size);auto it = std::max_element(buf, buf + size);size_t idx = it - buf;// load labelauto labels = read_txt_file(label_file);std::cout << "image classify result: " << labels[idx] << "(" << *it << ")" << std::endl;return 0;
}

主函数

// 主函数 要判断各个环节是否正确输出
int main(int argc, char *argv[])
{// 输出argv[0]一般是文件名称std::cout << "case " << argv[0] << " built at " << __DATE__ << " " << __TIME__ << std::endl;if (argc != 4){// 判断输入argc个数std::cerr << "Usage: " << argv[0] << " <kmodel> <image> <label>" << std::endl;return -1;}int ret = inference(argv[1], argv[2], argv[3]);if (ret){std::cerr << "inference failed: ret = " << ret << std::endl;return -2;}return 0;
}

02 代码的整体结构

03 部分代码说明

04 附录 一些代码资料说明


http://www.ppmy.cn/ops/127133.html

相关文章

在Flask中记录用户端的完整访问记录,包括请求和响应信息以及用户访问IP

在Flask中记录用户端的完整访问记录&#xff0c;包括请求和响应信息以及用户访问IP&#xff0c;可以通过自定义中间件&#xff08;或称为请求预处理和后处理函数&#xff09;来实现。Flask本身提供了装饰器和信号机制来帮助我们实现这一功能。 以下是一个基本的实现步骤&#…

搭建LeNet-5神经网络,并搭建自己的图像分类训练和测试的模板,模板通用!!!均有详细注释。

本文任务&#xff1a; 1、构建LeNet神经网络。 2、搭建图像分类训练和测试的通用模板。 3、训练出自己的模型。 4、验证模型效果。 LeNet论文地址&#xff1a;原文地址http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks…

爬虫逆向学习(十二):一个案例入门补环境

此分享只用于学习用途&#xff0c;不作商业用途&#xff0c;若有冒犯&#xff0c;请联系处理 反爬前置信息 站点&#xff1a;aHR0cDovLzEyMC4yMTEuMTExLjIwNjo4MDkwL3hqendkdC94anp3ZHQvcGFnZXMvaW5mby9wb2xpY3k 接口&#xff1a;/xjzwdt/rest/xmzInfoDeliveryRest/getInfoDe…

git clone 鉴权失败

git clone 鉴权失败问题 1. 问题描述2. 解决方法 1. 问题描述 使用git clone自己的代码报如下错误&#xff1a; 正克隆到 xxx... Username for https://github.com: Password for https://xxxgithub.com: remote: Support for password authentication was removed on Augu…

【Flutter】页面布局:线性布局(Row 和 Column)

在 Flutter 中&#xff0c;布局&#xff08;Layout&#xff09;是应用开发的核心之一。通过布局组件&#xff0c;开发者可以定义应用中的控件如何在屏幕上排列。Row 和 Column 是 Flutter 中最常用的两种线性布局方式&#xff0c;用于水平和垂直排列子组件。在本教程中&#xf…

Unity中通过给定的顶点数组生成凸面体的方法参考

这里我们使用了Quickhull for Unity插件&#xff0c;其实就是一个ConvexHullCalculator.cs文件&#xff0c;代码如下&#xff1a; /*** Copyright 2019 Oskar Sigvardsson** Permission is hereby granted, free of charge, to any person obtaining a copy* of this software…

部署服务dockerfile失败小记

首先是碰到在ENTRYPOINT 中添加参数问题&#xff0c;一时懵逼直接添加到最后 ENTRYPOINT [“java”,“-jar”,“xx-xx.jar”,“–spring.profiles.active${envType}”,“–add-opens java.base/java.langALL-UNNAMED”,“–add-opens java.base/java.utilALL-UNNAMED”, “–a…

PHP政务招商系统——高效连接共筑发展蓝图

政务招商系统——高效连接&#xff0c;共筑发展蓝图 &#x1f3db;️ 一、政务招商系统&#xff1a;开启智慧招商新篇章 在当今经济全球化的背景下&#xff0c;政务招商成为了推动地方经济发展的重要引擎。而政务招商系统的出现&#xff0c;更是为这一进程注入了新的活力。它…