C/C++开发,opencv-ml库学习,支持决策树(DTrees)应用

ops/2024/9/24 23:24:52/

目录

DTrees%EF%BC%89-toc" style="margin-left:80px;">一、决策树DTrees

1.1决策树算法简介

1.2 OpenCV决策树

DTrees%20%E5%BA%93%E5%BA%94%E7%94%A8-toc" style="margin-left:80px;">二、cv::ml::DTrees 库应用

2.1 训练及样本数据准备

2.2 程序编译

2.3 main.cpp全代码


DTrees%EF%BC%89" style="background-color:transparent;">一、决策树DTrees

1.1决策树算法简介

        决策树(Decision Tree)是一种基本的分类与回归方法。它主要是通过对训练数据进行归纳学习来形成一棵包含内部节点和叶子节点的树。内部节点表示一个特征或属性的判断条件,叶子节点表示一个类别或者某个具体的值。决策树的学习过程包括特征选择、决策树的生成和决策树的剪枝。

        在分类问题中,决策树通过一系列的判断条件将样本划分到不同的类别中。每个内部节点对应一个特征属性的判断,根据该特征属性的不同取值,将样本划分到不同的子节点中。这个过程一直递归进行,直到达到叶子节点,即样本被划分到某个具体的类别中。

        决策树也存在一些缺点:

  1. 容易过拟合:如果决策树的深度过大,或者对训练数据过于敏感,可能会导致过拟合现象,即模型在训练数据上表现很好,但在测试数据上表现较差。
  2. 对连续特征的处理能力较弱决策树在处理连续特征时需要进行离散化处理,这可能会损失一些信息。

        为了提高决策树的性能,通常会采用一些改进方法,如集成学习(如随机森林、梯度提升决策树等)和剪枝技术(通过删除一些子树或叶子节点来简化模型,防止过拟合)。

1.2 OpenCV决策树

   cv::ml::DTrees 是 OpenCV(Open Source Computer Vision Library)中机器学习模块的一个类,它代表决策树(Decision Trees)。决策树是一种非参数监督学习方法,用于分类和回归。在分类问题中,决策树学习一个从输入特征到输出类别标签的映射。在回归问题中,它学习一个从输入特征到连续输出值的映射。

        在 OpenCV 中使用 cv::ml::DTrees 类,你可以训练决策树模型,并用它来预测新数据的类别或值。

Opencv中的DTrees类cv::Ptr<cv::ml::DTrees> dtree = cv::ml::DTrees::create();
模型参数:1. setMaxDepth(int) 必须 树最大深度 输入参数为正整数,这个参数限制了树的最大层数,有助于防止过拟合。2. setCVFolds(int) 必须 交叉验证 一般为03. setUseSurrogates(bool) 非必须 是否建立代替分裂点 输入 bool4. setMinSampleCount(int) 非必须 节点最小样本数 当样本数量过小 则不细分5. setUselSERule(bool) 非必须 表示是否严格剪枝 6. setTruncatePrunedTree(bool) 非必须 分支是否完全移除7. setRegressionAccuracy(float) 非必须 对于回归树,这个参数决定了何时停止划分。当节点的响应值的最大偏差小于这个值时,划分会停止。8. setMaxCategories(int) 非必须 对于分类问题,该参数指定了用于找到最优分割点的最大类别数。当特征的可能取值超过这个值时,会构建决策树的子树。

DTrees%20%E5%BA%93%E5%BA%94%E7%94%A8">二、cv::ml::DTrees 库应用

2.1 训练及样本数据准备

        参考本专栏博文《C/C++开发,opencv-ml库学习,支持向量机(SVM)应用-CSDN博客》下载MNIST 数据集(手写数字识别),并解压。

        同时参考该博文“2.4 SVM(支持向量机)实时识别应用”的章节资料,利用python代码解压t10k-images.idx3-ubyte出图片数据文件。

2.2 DTrees 应用

        创建了一个 cv::ml::DTrees 对象,并设置了训练数据和终止条件。接着,我们调用 train 方法来训练决策树模型。最后,我们使用训练好的模型来预测一个新样本的类别。

// 创建决策树对象  cv::Ptr<cv::ml::DTrees> dtree = cv::ml::DTrees::create();  dtree->setMaxDepth(30);          // 设置树的最大深度  dtree->setCVFolds(0);dtree->setMinSampleCount(1);   // 设置分裂内部节点所需的最小样本数 std::cout << "create dtree object finish!" << std::endl;  // 训练决策树--trainingData训练数据,labelsMat训练标签cv::Ptr<cv::ml::TrainData> td = cv::ml::TrainData::create(trainingData, cv::ml::ROW_SAMPLE, labelsMat);  std::cout << "create TrainData object finish!" << std::endl; if(dtree->train(td)){std::cout << "dtree training finish!" << std::endl;}else{std::cout << "dtree training fail!" << std::endl; }......cv::Mat testResp;float response = dtree->predict(testData,testResp); 

        训练完模型化保存输出。PS:本文是基于前往支持向量机(SVM)代码快速验证,有些名词会较怪异。

dtree->save("mnist_svm.xml");

        然后通道调用mnist_svm.xml,实现决策树的分类识别

cv::Ptr<cv::ml::DTrees> dtree = cv::ml::StatModel::load<cv::ml::DTrees>("mnist_svm.xml");
......
//预测图片
float ret = dtree->predict(image);
std::cout << "predict val = "<< ret << std::endl;
2.2 程序编译

        和前一篇讲述支持向量机(SVM)应用的博文编译类似,采用opencv+mingw+makefile方式编译:

#/bin/sh
#win32
CX= g++ -DWIN32 
#linux
#CX= g++ -Dlinux BIN 		:= ./
TARGET      := opencv_ml02.exe
FLAGS		:= -std=c++11 -static
SRCDIR 		:= ./
#INCLUDES
INCLUDEDIR 	:= -I"../../opencv_MinGW/include" -I"./"
#-I"$(SRCDIR)"
staticDir   := ../../opencv_MinGW/x64/mingw/staticlib/
#LIBDIR		:= $(staticDir)/libopencv_world460.a\
#			   $(staticDir)/libade.a \
#			   $(staticDir)/libIlmImf.a \
#			   $(staticDir)/libquirc.a \
#			   $(staticDir)/libzlib.a \
#			   $(wildcard $(staticDir)/liblib*.a) \
#			   -lgdi32 -lComDlg32 -lOleAut32 -lOle32 -luuid 
#opencv_world放弃前,然后是opencv依赖的第三方库,后面的库是MinGW编译工具的库LIBDIR 	    := -L $(staticDir) -lopencv_world460 -lade -lIlmImf -lquirc -lzlib \-llibjpeg-turbo -llibopenjp2 -llibpng -llibprotobuf -llibtiff -llibwebp \-lgdi32 -lComDlg32 -lOleAut32 -lOle32 -luuid 
source		:= $(wildcard $(SRCDIR)/*.cpp) $(TARGET) :$(CX) $(FLAGS) $(INCLUDEDIR) $(source)  -o $(BIN)/$(TARGET) $(LIBDIR)clean:rm  $(BIN)/$(TARGET)

        make编译,make clean 清除可重新编译。

        运行效果,同样数据样本,相比svm算法训练结果,其准确率不算高,大家可以尝试调整参数验证:

2.3 main.cpp全代码

        main.cpp源代码,由于是基于前一篇博文《C/C++开发,opencv-ml库学习,支持向量机(SVM)应用-CSDN博客》快速移用实现的,有很多支持向量机(SVM)应用的痕迹,采用的数据样本也非较合适的,仅仅是为了阐述c++ opencv 决策树DTrees)应用说明。

#include <opencv2/opencv.hpp>  
#include <opencv2/ml/ml.hpp>  
#include <opencv2/imgcodecs.hpp>
#include <iostream>  
#include <vector>  
#include <iostream>
#include <fstream>int intReverse(int num)
{return (num>>24|((num&0xFF0000)>>8)|((num&0xFF00)<<8)|((num&0xFF)<<24));
}std::string intToString(int num)
{char buf[32]={0};itoa(num,buf,10);return std::string(buf);
}cv::Mat read_mnist_image(const std::string fileName) {int magic_number = 0;int number_of_images = 0;int img_rows = 0;int img_cols = 0;cv::Mat DataMat;std::ifstream file(fileName, std::ios::binary);if (file.is_open()){std::cout << "open images file: "<< fileName << std::endl;file.read((char*)&magic_number, sizeof(magic_number));//formatfile.read((char*)&number_of_images, sizeof(number_of_images));//images numberfile.read((char*)&img_rows, sizeof(img_rows));//img rowsfile.read((char*)&img_cols, sizeof(img_cols));//img colsmagic_number = intReverse(magic_number);number_of_images = intReverse(number_of_images);img_rows = intReverse(img_rows);img_cols = intReverse(img_cols);std::cout << "format:" << magic_number<< " img num:" << number_of_images<< " img row:" << img_rows<< " img col:" << img_cols << std::endl;std::cout << "read img data" << std::endl;DataMat = cv::Mat::zeros(number_of_images, img_rows * img_cols, CV_32FC1);unsigned char temp = 0;for (int i = 0; i < number_of_images; i++) {for (int j = 0; j < img_rows * img_cols; j++) {file.read((char*)&temp, sizeof(temp));//svm data is CV_32FC1float pixel_value = float(temp);DataMat.at<float>(i, j) = pixel_value;}}std::cout << "read img data finish!" << std::endl;}file.close();return DataMat;
}cv::Mat read_mnist_label(const std::string fileName) {int magic_number;int number_of_items;cv::Mat LabelMat;std::ifstream file(fileName, std::ios::binary);if (file.is_open()){std::cout << "open label file: "<< fileName << std::endl;file.read((char*)&magic_number, sizeof(magic_number));file.read((char*)&number_of_items, sizeof(number_of_items));magic_number = intReverse(magic_number);number_of_items = intReverse(number_of_items);std::cout << "format:" << magic_number << "  ;label_num:" << number_of_items << std::endl;std::cout << "read Label data" << std::endl;//data type:CV_32SC1,channel:1LabelMat = cv::Mat::zeros(number_of_items, 1, CV_32SC1);for (int i = 0; i < number_of_items; i++) {unsigned char temp = 0;file.read((char*)&temp, sizeof(temp));LabelMat.at<unsigned int>(i, 0) = (unsigned int)temp;}std::cout << "read label data finish!" << std::endl;}file.close();return LabelMat;
}//change path for real paths
std::string trainImgFile = "D:\\workForMy\\OpenCVLib\\opencv_demo\\opencv_ml01\\train-images.idx3-ubyte";
std::string trainLabeFile = "D:\\workForMy\\OpenCVLib\\opencv_demo\\opencv_ml01\\train-labels.idx1-ubyte";
std::string testImgFile = "D:\\workForMy\\OpenCVLib\\opencv_demo\\opencv_ml01\\t10k-images.idx3-ubyte";
std::string testLabeFile = "D:\\workForMy\\OpenCVLib\\opencv_demo\\opencv_ml01\\t10k-labels.idx1-ubyte";void train_SVM()
{//read train images, data type CV_32FC1cv::Mat trainingData = read_mnist_image(trainImgFile);//images data normalizationtrainingData = trainingData/255.0;std::cout << "trainingData.size() = " << trainingData.size() << std::endl;std::cout << "trainingData.type() = " << trainingData.type() << std::endl;  std::cout << "trainingData.rows = " << trainingData.rows << std::endl; std::cout << "trainingData.cols = " << trainingData.cols << std::endl; //read train label, data type CV_32SC1cv::Mat labelsMat = read_mnist_label(trainLabeFile);std::cout << "labelsMat.size() = " << labelsMat.size() << std::endl; std::cout << "labelsMat.type() = " << labelsMat.type() << std::endl;  std::cout << "labelsMat.rows = " << labelsMat.rows << std::endl; std::cout << "labelsMat.cols = " << labelsMat.cols << std::endl; std::cout << "trainingData & labelsMat finish!" << std::endl;  // //create SVM model// cv::Ptr<cv::ml::SVM> svm = cv::ml::SVM::create();  // //set svm args,type and KernelTypes// svm->setType(cv::ml::SVM::C_SVC);  // svm->setKernel(cv::ml::SVM::POLY);  // //KernelTypes POLY is need set gamma and degree// svm->setGamma(3.0);// svm->setDegree(2.0);// //Set iteration termination conditions, maxCount is importance// svm->setTermCriteria(cv::TermCriteria(cv::TermCriteria::EPS | cv::TermCriteria::COUNT, 1000, 1e-8)); // std::cout << "create SVM object finish!" << std::endl;  // std::cout << "trainingData.rows = " << trainingData.rows << std::endl; // std::cout << "trainingData.cols = " << trainingData.cols << std::endl; // std::cout << "trainingData.type() = " << trainingData.type() << std::endl; // // svm model train // svm->train(trainingData, cv::ml::ROW_SAMPLE, labelsMat);  // std::cout << "SVM training finish!" << std::endl; // 创建决策树对象  cv::Ptr<cv::ml::DTrees> dtree = cv::ml::DTrees::create();  dtree->setMaxDepth(30);          // 设置树的最大深度  dtree->setCVFolds(0);dtree->setMinSampleCount(1);   // 设置分裂内部节点所需的最小样本数 std::cout << "create dtree object finish!" << std::endl;  // 训练决策树--trainingData训练数据,labelsMat训练标签cv::Ptr<cv::ml::TrainData> td = cv::ml::TrainData::create(trainingData, cv::ml::ROW_SAMPLE, labelsMat);  std::cout << "create TrainData object finish!" << std::endl; if(dtree->train(td)){std::cout << "dtree training finish!" << std::endl;}else{std::cout << "dtree training fail!" << std::endl; }// svm model test  cv::Mat testData = read_mnist_image(testImgFile);//images data normalizationtestData = testData/255.0;std::cout << "testData.rows = " << testData.rows << std::endl; std::cout << "testData.cols = " << testData.cols << std::endl; std::cout << "testData.type() = " << testData.type() << std::endl; //read test label, data type CV_32SC1cv::Mat testlabel = read_mnist_label(testLabeFile);cv::Mat testResp;// float response = svm->predict(testData,testResp); float response = dtree->predict(testData,testResp); // std::cout << "response = " << response << std::endl; testResp.convertTo(testResp,CV_32SC1);int map_num = 0;for (int i = 0; i <testResp.rows&&testResp.rows==testlabel.rows; i++){if (testResp.at<int>(i, 0) == testlabel.at<int>(i, 0)){map_num++;}// else{// 	std::cout << "testResp.at<int>(i, 0) " << testResp.at<int>(i, 0) << std::endl;// 	std::cout << "testlabel.at<int>(i, 0) " << testlabel.at<int>(i, 0) << std::endl;// }}float proportion  = float(map_num) / float(testResp.rows);std::cout << "map rate: " << proportion * 100 << "%" << std::endl;std::cout << "SVM testing finish!" << std::endl; //save svm model// svm->save("mnist_svm.xml");dtree->save("mnist_svm.xml");
}void prediction(const std::string fileName,cv::Ptr<cv::ml::DTrees> dtree)
// void prediction(const std::string fileName,cv::Ptr<cv::ml::SVM> svm)
{//read img 28*28 sizecv::Mat image = cv::imread(fileName, cv::IMREAD_GRAYSCALE);//uchar->float32image.convertTo(image, CV_32F);//image data normalizationimage = image / 255.0;//28*28 -> 1*784image = image.reshape(1, 1);//预测图片float ret = dtree->predict(image);std::cout << "predict val = "<< ret << std::endl;
}std::string imgDir = "D:\\workForMy\\OpenCVLib\\opencv_demo\\opencv_ml01\\t10k-images\\";
std::string ImgFiles[5] = {"image_0.png","image_10.png","image_20.png","image_30.png","image_40.png",};
void predictimgs()
{//load svm model// cv::Ptr<cv::ml::SVM> svm = cv::ml::StatModel::load<cv::ml::SVM>("mnist_svm.xml");//load DTrees modelcv::Ptr<cv::ml::DTrees> dtree = cv::ml::StatModel::load<cv::ml::DTrees>("mnist_svm.xml");for (size_t i = 0; i < 5; i++){prediction(imgDir+ImgFiles[i],dtree);}
}int main()  
{  train_SVM();predictimgs();	return 0;  
}


http://www.ppmy.cn/ops/16495.html

相关文章

jsp+springboot+java二手车交易管理系统258u6

设计而成的系统要有以下目标&#xff1a;管理员和用户能够跳转到不同的页面当中。因此要把系统的目标设置为如下几项&#xff1a; (1) 系统在操作上不能过于复杂。 (2) 用户对应着不同的角色 (3) 设计完成的数据库要有能够处理并发和安全的作用 (4) 设计完成的管理…

TiDB系列之:TiCDC使用Changefeed完成数据同步任务

TiDB系列之:TiCDC使用Changefeed完成数据同步任务 一、Changefeed二、Changefeed 状态流转三、操作Changefeed四、cdc cli管理同步任务1.创建同步任务2.查询同步任务列表3.查询特定同步任务4.停止同步任务5.恢复同步任务6.删除同步任务7.更新同步任务配置8.管理同步子任务处理…

【原创】springboot+mysql企业智慧办公OA管理系统

个人主页&#xff1a;程序猿小小杨 个人简介&#xff1a;从事开发多年&#xff0c;Java、Php、Python、前端开发均有涉猎 博客内容&#xff1a;Java项目实战、项目演示、技术分享 文末有作者名片&#xff0c;希望和大家一起共同进步&#xff0c;你只管努力&#xff0c;剩下的交…

用户中心 -- 代码理解

一、删除表 & if 删除表 1.1 DROP TABLE IF EXISTS user 和 DROP TABLE user 网址&#xff1a; 用户管理第2节课 -- idea 2023.2 创建表--【本人】-CSDN博客 二、 代码 2.1 清空表中数据 的 命令 【truncate 清空】 网址&#xff1a; 用户管理第2节课 -- idea 2…

SpringMVC基础篇(四)

文章目录 1.视图1.基本介绍1.视图介绍2.为什么需要自定义视图 2.自定义视图实例1.思路分析2.代码实例1.view.jsp2.接口3.配置自定义视图解析器springDispatcherServlet-servlet.xml4.自定义视图MyView.java5.view_result.jsp6.结果展示 3.自定义视图执行流程4.自定义视图执行流…

美团一面复活赛4/18

语言表述还是有问题。。fuck 1.自我介绍 2.项目中遇到哪些问题&#xff0c;怎么解决的 太菜了没答好。。面试官没为难&#xff08;购票扣减逻辑要复习一下&#xff09; 3.mq发生了数据积压怎么去解&#xff1f; 面试官感觉我不会&#xff0c;也没为难我 4.Integer20 Integer2…

深耕“星光电务”党建品牌 引领保障企业高质量发展

在日前闭幕的2024年首届全国企业党务工作者论坛中&#xff0c;中铁十一局集团电务工程有限公司提交的论文《深耕“星光电务”党建品牌 引领保障企业高质量发展》荣获优秀论文奖。该论文由陈柯、刘敏之、徐干、姜亦珂联合撰写&#xff0c;展示了他们在党建工作中的创新实践与显著…

【C++ 哈希应用】

文章目录 位图概念代码实现海量数据处理 布隆过滤器概念代码实现海量数据处理 哈希切割海量数据处理 位图 概念 一个值在给定的集合中有两种状态&#xff0c;在或不在&#xff0c;要表示这种状态&#xff0c;最少可以用一个比特位&#xff0c;比特位为1表示在&#xff0c;比特…