经典分类模型回顾5—DenseNet实现图像分类(matlab)

news/2024/11/3 4:20:11/

DenseNet,全称为Densely Connected Convolutional Networks,中文名为密集连接卷积网络,是由李沐等人在2017年提出的一种深度神经网络架构。 

DenseNet旨在解决深度神经网络中的梯度消失问题和参数数量过多的问题,通过构建密集连接的方式,使得网络能够更好地利用之前的特征,从而获得更好的性能。DenseNet的核心思想是:把网络中前面的层与后面的层进行连接,让前面的层的输出成为后面的层的输入。这样,整个卷积网络就变得非常紧凑,同时也避免了梯度消失的问题。

DenseNet的优点在于:参数少、计算速度快、准确率高。因此,DenseNet在图像识别、目标检测、图像分割等任务中都取得了很好的表现。

DenseNet是一种深度神经网络架构,它具有特殊的连接方式,可以有效地减少网络中的参数量,提高模型的准确性和稳定性。在图像分类任务中,DenseNet常常被使用。

在MATLAB中,可以使用深度学习工具箱来搭建和训练DenseNet模型。下面是一个简单的例子,展示如何使用深度学习工具箱来训练一个DenseNet模型进行CIFAR-10图像分类。

1. 准备数据

首先需要下载CIFAR-10数据集,可以使用MATLAB自带的数据集下载工具来获取数据集。

```MATLAB
cifar10Data = fullfile(tempdir, 'cifar-10-matlab');
if ~exist(cifar10Data, 'dir')
    cifar10Data = fullfile(toolboxdir('vision'), 'visiondata', 'cifar10');
    if ~exist(cifar10Data, 'dir')
        mkdir(cifar10Data);
        url = 'https://www.cs.toronto.edu/~kriz/cifar-10-matlab.tar.gz';
        helperCIFAR10Data.download(url, cifar10Data);
    end
end
```

2. 加载数据

使用 `imageDatastore` 函数将数据加载到MATLAB中。在此过程中,可以对图像进行增强处理,以提高模型的训练效果。

```MATLAB
% Load training and test data
[trainingImages, trainingLabels, testImages, testLabels] = helperCIFAR10Data.load(cifar10Data);

% Construct an imageDatastore object
trainingSet = imageDatastore(trainingImages, ...
    'labels', trainingLabels, ...
    'ReadFcn', @helperCIFAR10Data.readFunction);

testSet = imageDatastore(testImages, ...
    'labels', testLabels, ...
    'ReadFcn', @helperCIFAR10Data.readFunction);

% Prepare the data for training
inputSize = [32 32 3];
numClasses = 10;

% Apply data augmentation
augmentedTrainingSet = augmentedImageDatastore(inputSize, ...
    trainingSet, ...
    'ColorPreprocessing', 'gray2rgb', ...
    'RandCropSize', [28 28], ...
    'RandCropType', 'random', ...
    'RandRotation', [-8 8], ...
    'RandXReflection', true);
```

3. 构建DenseNet模型

使用 `densenet201` 函数从深度学习工具箱中加载DenseNet-201模型。

```MATLAB
net = densenet201;
```

可以使用 `analyzeNetwork` 函数来可视化模型架构。

```MATLAB
analyzeNetwork(net);
```

4. 训练模型

使用 `trainingOptions` 函数来配置训练选项。

```MATLAB
options = trainingOptions('sgdm', ...
    'InitialLearnRate', 0.001, ...
    'MaxEpochs', 50, ...
    'MiniBatchSize', 128, ...
    'VerboseFrequency', 50, ...
    'Plots', 'training-progress');
```

使用 `trainNetwork` 函数来训练模型。

```MATLAB
trainedNet = trainNetwork(augmentedTrainingSet, net, options);
```

5. 测试模型

使用 `classify` 函数来进行分类。

```MATLAB
predictedLabels = classify(trainedNet, testSet);
accuracy = mean(predictedLabels == testSet.Labels)
```

6. 可视化结果

使用 `montage` 函数来可视化测试集中的前20张图像及其分类结果。

```MATLAB
numImages = 20;
idx = randsample(numel(testSet.Files), numImages);
figure
montage(testSet.Files(idx), 'Size', [4 5]);
title('Test Images');

predictedLabels = classify(trainedNet, testSet);
label = cellstr(predictedLabels);
label = strcat(label, ", ", cellstr(num2str(testSet.Labels)));
groundTruth = cellstr(label);
groundTruth = strcat("Ground Truth: ", groundTruth);

predicted = cellstr(predictedLabels);
predicted = strcat("Prediction: ", predicted);

for i = 1:numImages
    text(i*32-25,32+10,groundTruth(idx(i)),'FontSize',8)
    text(i*32-25,32+20,predicted(idx(i)),'FontSize',8)
end
```

完整代码如下:

```MATLAB
cifar10Data = fullfile(tempdir, 'cifar-10-matlab');
if ~exist(cifar10Data, 'dir')
    cifar10Data = fullfile(toolboxdir('vision'), 'visiondata', 'cifar10');
    if ~exist(cifar10Data, 'dir')
        mkdir(cifar10Data);
        url = 'https://www.cs.toronto.edu/~kriz/cifar-10-matlab.tar.gz';
        helperCIFAR10Data.download(url, cifar10Data);
    end
end

[trainingImages, trainingLabels, testImages, testLabels] = helperCIFAR10Data.load(cifar10Data);

% Construct an imageDatastore object
trainingSet = imageDatastore(trainingImages, ...
    'labels', trainingLabels, ...
    'ReadFcn', @helperCIFAR10Data.readFunction);

testSet = imageDatastore(testImages, ...
    'labels', testLabels, ...
    'ReadFcn', @helperCIFAR10Data.readFunction);

% Prepare the data for training
inputSize = [32 32 3];
numClasses = 10;

% Apply data augmentation
augmentedTrainingSet = augmentedImageDatastore(inputSize, ...
    trainingSet, ...
    'ColorPreprocessing', 'gray2rgb', ...
    'RandCropSize', [28 28], ...
    'RandCropType', 'random', ...
    'RandRotation', [-8 8], ...
    'RandXReflection', true);

% Load pre-trained DenseNet-201 network
net = densenet201;

% Configure training options
options = trainingOptions('sgdm', ...
    'InitialLearnRate', 0.001, ...
    'MaxEpochs', 50, ...
    'MiniBatchSize', 128, ...
    'VerboseFrequency', 50, ...
    'Plots', 'training-progress');

% Train the network
trainedNet = trainNetwork(augmentedTrainingSet, net, options);

% Test the network
predictedLabels = classify(trainedNet, testSet);
accuracy = mean(predictedLabels == testSet.Labels)

% Visualize the results
numImages = 20;
idx = randsample(numel(testSet.Files), numImages);
figure
montage(testSet.Files(idx), 'Size', [4 5]);
title('Test Images');

predictedLabels = classify(trainedNet, testSet);
label = cellstr(predictedLabels);
label = strcat(label, ", ", cellstr(num2str(testSet.Labels)));
groundTruth = cellstr(label);
groundTruth = strcat("Ground Truth: ", groundTruth);

predicted = cellstr(predictedLabels);
predicted = strcat("Prediction: ", predicted);

for i = 1:numImages
    text(i*32-25,32+10,groundTruth(idx(i)),'FontSize',8)
    text(i*32-25,32+20,predicted(idx(i)),'FontSize',8)
end
```


http://www.ppmy.cn/news/30115.html

相关文章

滚蛋吧,正则表达式!

大家好,我是良许。 不知道大家有没有被正则表达式支配过的恐惧?看着一行火星文一样的表达式,虽然每一个字符都认识,但放在一起直接就让人蒙圈了~ 你是不是也有这样的操作,比如你需要使用「电子邮箱正则表达式」&…

vue简介与环境搭建

该栏目会非定期出教学视频,文档一般与视频关联,感谢观看。b站搜索博主同名,可观看教学视频。 一、Node.js 什么是node?是一个基于Chrome V8引擎的JavaScript运行环境,使用了一个事件驱动、非阻塞式I/O模型&#xff0c…

时序预测 | MATLAB实现IWOA-BiLSTM和BiLSTM时间序列预测(改进的鲸鱼算法优化双向长短期记忆神经网络)

时序预测 | MATLAB实现IWOA-BiLSTM和BiLSTM时间序列预测(改进的鲸鱼算法优化双向长短期记忆神经网络) 目录时序预测 | MATLAB实现IWOA-BiLSTM和BiLSTM时间序列预测(改进的鲸鱼算法优化双向长短期记忆神经网络)预测效果基本介绍程序设计参考资料预测效果 基本介绍 MATLAB实现IWO…

【算法】DFS与BFS

作者:指针不指南吗 专栏:算法篇 🐾题目的模拟很重要!!🐾 文章目录1.区别2.DFS2.1 排列数字2.2 n-皇后问题3.BFS3.1走迷宫1.区别 搜索类型数据结构空间用途过程DFSstackO( n )不能用于最短路搜索到最深处&a…

基于Java的某石材公司货物管理系统的设计与实现

技术:Java、JSP等摘要:随着信息化技术的发展,计算机的应用已迅速扩展到企事业管理与办公自动化领域,而数据库技术也被广泛应用。电脑操作及管理日趋简化,电脑知识日趋普及,同时市场经济快速多变、竞争激烈&…

Spark使用Log4j将日志发送到Kafka

文章目录自定义KafkaAppender修改log4j.properties配置启动命令配置添加参数启动之后可以在Kafka中查询发送数据时区问题-自定义实现JSONLayout解决自定义JSONLayout.java一键应用可能遇到的异常ClassNotFoundException: xxx.KafkaLog4jAppenderUnexpected problem occured dur…

软工2023个人作业二——软件案例分析

项目内容这个作业属于哪个课程2023年北航敏捷软件工程这个作业的要求在哪里个人作业-软件案例分析我在这个课程的目标是学习并掌握现代软件开发和项目管理技术,体验敏捷开发工作流程这个作业在哪个具体方面帮助我实现目标从软件工程角度分析比较我们所熟悉的软件&am…

做程序界中的死神,锻造合适的斩魂刀

标题解读:标题中的死神,是源自《死神》动漫里面的角色,斩魂刀是死神的武器,始解是斩魂刀的初始解放形态,卐解是斩魂刀的觉醒解放形态,也是死神的大招。意旨做程序界中程序员的佼佼者,一步一步最…