MATLAB深度学习(二)——如何训练一个卷积神经网路

ops/2024/11/19 11:08:40/

2.1 基本概念

        从数学的角度看,机器学习的目标是建立输入和输出的函数关系,相当于 y = F(x)的过程。F(x)就是我们所说的模型,对于使用者来说,这个模型就是一个黑箱,我们不知道其具体的结构,但是给定一个输出,就可以得到我们想要的结果。F(x)的获得,我们通过的是实验法啊,经过大量数据训练出来的,我们定义一个损失函数L(x),记录真实输出与模型输出的偏差,通过数据的迭代使得损失函数L(x)达到最小。

        在机器学习中,我们需要理解概念的术语的解释:

训练样本用于训练的数据
训练用于训练样本特征统计和归纳的过程
模型总结出的规律、标准
验证用于验证数据集评价模型是否准确
超参数学习速率、迭代层神经元个数等
参数权重、偏置等
泛化模型对新样本的适应力

        过拟合和欠拟合是常见的现象。但是需要说明的是,数据没有过多的这种说法,所谓的过拟合,是模型在训练集上的表现过于优异,模拟考100分你考了100分,99分,但是验证集上,相当于实际考试中你考了40分,换一场考试,换一个新的数据,导致严重误判。欠拟合就是数据过少,模型无法归纳出共性,在训练集和测试集表现都很差。

2.2 实例需求与实现步骤

        第一章里面我们用了工具箱来实现,这一章我们强化一下,用m文件编写,我们构建训练一个三层卷积神经网络,对输入的图像进行预测,计算器预测准确率和RMSE均方根误差。实现步骤具体参考第一章。

        

%% 步骤1:加载和显示图像数据
[XTrain,~,YTrain] = digitTrain4DArrayData;                       
[XValidation,~,YValidation] = digitTest4DArrayData;              % 随机显示20幅训练图像
numTrainImages = numel(YTrain);                                  
figure
idx = randperm(numTrainImages,20);
for i = 1:numel(idx)subplot(4,5,i)    imshow(XTrain(:,:,:,idx(i)))drawnow
end%% 步骤2:构建卷积神经网络
layers = [imageInputLayer([28 28 1])   %输入 像素为28*28,1通道                    convolution2dLayer(3,8,'Padding','same')   %卷积层1 卷积核大小为3*3,卷积核个数为8,卷积方式用0填充          batchNormalizationLayer      %归一化 加快训练网络时的收敛速度                           reluLayer                    %ReLU函数 激活函数                              averagePooling2dLayer(2,'Stride',2)         %平均池化  池化区域为2*2,步长为2              convolution2dLayer(3,16,'Padding','same')   %第二          batchNormalizationLayer                                        reluLayer                                                averagePooling2dLayer(2,'Stride',2)                        convolution2dLayer(3,32,'Padding','same')   %第三               batchNormalizationLayer                                    reluLayer                                               dropoutLayer(0.2)      %随机将20%的输入置0,防止过拟合                                     fullyConnectedLayer(1)         %全连接层输出个数为1                           regressionLayer ];         %用于预测结果                               %% 步骤3:配置训练选项
miniBatchSize = 128;  % 设置小批量的大小为 128
validationFrequency = floor(numel(YTrain)/miniBatchSize);  % 计算验证频率,根据训练数据的数量除以 miniBatchSize 并取整
% trainingOptions 用于配置网络训练的选项
options = trainingOptions('sgdm', ...  % 选择随机梯度下降动量法(SGDM)作为优化器'MiniBatchSize',miniBatchSize, ...  % 指定每次训练的小批量大小为 miniBatchSize'MaxEpochs',30, ...  % 设置训练的最大轮数为 30'InitialLearnRate',0.001, ...  % 设置初始学习率为 0.001'LearnRateSchedule','piecewise', ...  % 学习率调整方式为分段调整'LearnRateDropFactor',0.1, ...  % 每次学习率下降时,下降的比例为 0.1'LearnRateDropPeriod',20, ...  % 每 20 个周期调整一次学习率'Shuffle','every-epoch', ...  % 每轮训练后随机打乱数据'ValidationData',{XValidation,YValidation}, ...  % 指定验证数据为 XValidation 和 YValidation'ValidationFrequency',validationFrequency, ...  % 设置验证的频率'Plots','training-progress', ...  % 启用训练进度的动态绘图'Verbose',true);  % 打印详细的训练信息%% 步骤4:训练网络
net = trainNetwork(XTrain,YTrain,layers,options); % X训练集 Y测试集 网络结构 训练设置%% 步骤5:测试与评估
YPredicted = predict(net,XValidation);  % 使用训练好的网络对验证集数据进行预测
predictionError = YValidation - YPredicted;  % 计算预测误差(真实值减去预测值)% 计算准确率
thr = 10;  % 设置误差阈值为 10
numCorrect = sum(abs(predictionError) < thr);  % 统计误差绝对值小于阈值的预测数量
numValidationImages = numel(YValidation);  % 获取验证集样本总数
Accuracy = numCorrect/numValidationImages;  % 准确率计算为预测正确的样本数除以总样本数% 计算RMSE(均方根误差)的值
squares = predictionError.^2;  % 计算误差的平方
RMSE = sqrt(mean(squares));  % 求均值后开平方,得到均方根误差

         训练选项设置,读者可以进行打开帮助查看,里面还有很多内容,可以进行自动补全。

        因为我有GPU,这里就改成GPU进行训练了


http://www.ppmy.cn/ops/134951.html

相关文章

VSCode 常用的快捷键

Visual Studio Code (VSCode) 提供了丰富的快捷键来提高开发效率。 是常用的 VSCode 快捷键&#xff0c;按功能分类&#xff1a; 1. 基础编辑 Ctrl C / Ctrl V / Ctrl X&#xff1a;复制、粘贴、剪切当前选中的文本。Ctrl Z / Ctrl Y&#xff1a;撤销和重做操作。Ctrl …

简单的MCU与FPGA通过APB总线实现通讯(fpga mcu APB):乘法器为例

测试平台: GW1N4器件内置 M1内核;并且可以设置 APB总线与fpga 逻辑进行交互; 框图: +---------------------+ | | | M1 Microprocessor | <-----------------+ | | | | +-----------------…

DNS服务器Mac地址绑定与ip网路管理命令(Ubuntu24.04)

DNS server Mac绑定 查看 DNS服务器地址 resolvectl statusLink 2 (wlp2s0)Current Scopes: DNS Current DNS Server: 10.10.0.21DNS Servers: 10.10.0.21 10.10.2.21查看路由器中邻居表的内容&#xff0c;每一行表示一个网络设备的IP地址、MAC地址及其状态 ip neigh10.162.…

靓车汽车销售网站(源码+数据库+报告)

基于SpringBoot靓车汽车销售网站&#xff0c;系统包含两种角色&#xff1a;管理员、用户,系统分为前台和后台两大模块&#xff0c;主要功能如下。 前台功能简介&#xff1a; - 首页&#xff1a;展示网站的概要信息和推荐车辆。 - 车辆展示&#xff1a;展示可供销售的汽车。 - …

React状态管理之Redux

React状态管理之Redux 在React应用中&#xff0c;状态管理是一个至关重要的概念。随着应用规模的扩大&#xff0c;组件之间的状态共享和更新变得愈发复杂。Redux作为一个专门用于JavaScript应用&#xff08;尤其是React应用&#xff09;的状态管理库&#xff0c;提供了一种可预…

【Java 集合】Collections 空列表细节处理

问题 如下代码&#xff0c;虽然定义为非空 NonNull&#xff0c;但依然会返回空对象&#xff0c;导致调用侧被检测为空引用。 实际上不是Collections的问题是三目运算符返回了null对象。 import java.util.Collections;NonNullprivate List<String> getInfo() {IccReco…

ODC 如何精确呈现SQL耗时 | OceanBase 开发者工具解析

前言 在程序员或DBA的日常工作中&#xff0c;编写并执行SQL语句如同日常饮食中的一餐一饭&#xff0c;再寻常不过。然而&#xff0c;在使用命令行或黑屏客户端处理SQL时&#xff0c;常会遇到编写难、错误排查缓慢以及查询结果可读性不佳等难题&#xff0c;因此&#xff0c;图形…

redis和mongodb等对比分析

Redis 和 MongoDB 都是非常流行的 NoSQL 数据库,它们在数据存储模型、性能、扩展性等方面有很大的差异。下面是 Redis 和 MongoDB 的对比分析: 1. 数据模型 Redis: 键值存储:Redis 是一个内存数据结构存储,它支持多种数据类型,如字符串、哈希、列表、集合、有序集合等。…