Matlab 深度学习工具箱 案例学习与测试————求二阶微分方程

devtools/2024/11/22 14:53:49/
clc
clear% 定义输入变量
x = linspace(0,2,10000)';% 定义网络的层参数
inputSize = 1;
layers = [featureInputLayer(inputSize,Normalization="none")fullyConnectedLayer(10)sigmoidLayerfullyConnectedLayer(1)sigmoidLayer];
% 创建网络
net = dlnetwork(layers);% 训练轮数
numEpochs = 15;
% 每个Batch的数据个数
miniBatchSize = 100;

% SGDM优化方法设置的参数
initialLearnRate = 0.5;
learnRateDropFactor = 0.5;
learnRateDropPeriod = 5;
momentum = 0.9;
velocity = [];

% 损失函数里面考虑初始条件的系数
icCoeff = 7;% ArrayDatastore
ads = arrayDatastore(x,IterationDimension=1);
% 创建一个用于处理管理学习>深度学习数据的对象
mbq = minibatchqueue(ads, ...MiniBatchSize=miniBatchSize, ...PartialMiniBatch="discard", ...MiniBatchFormat="BC");% 用于迭代过程监控
numObservationsTrain = numel(x);
numIterationsPerEpoch = floor(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;% 创建监控对象 
% 由于计时器在您创建监控器对象时启动,因此请确保在靠近训练循环的位置创建对象。
monitor = trainingProgressMonitor( ...Metrics="LogLoss", ...Info=["Epoch" "LearnRate"], ...XLabel="Iteration");% Train the network using a custom training loop
epoch = 0;
iteration = 0;
learnRate = initialLearnRate;
start = tic;% Loop over epochs.
while epoch < numEpochs  && ~monitor.Stopepoch = epoch + 1;% Shuffle data,打乱数据.mbq.shuffle% Loop over mini-batches.while hasdata(mbq) && ~monitor.Stopiteration = iteration + 1;% Read mini-batch of data.X = next(mbq);% Evaluate the model gradients and loss using dlfeval and the modelLoss function.[loss,gradients] = dlfeval(@modelLoss, net, X, icCoeff);% Update network parameters using the SGDM optimizer.[net,velocity] = sgdmupdate(net,gradients,velocity,learnRate,momentum);% Update the training progress monitor.recordMetrics(monitor,iteration,LogLoss=log(loss));updateInfo(monitor,Epoch=epoch,LearnRate=learnRate);monitor.Progress = 100 * iteration/numIterations;end% Reduce the learning rate.if mod(epoch,learnRateDropPeriod)==0learnRate = learnRate*learnRateDropFactor;end
endxTest = linspace(0,4,1000)';yModel = minibatchpredict(net,xTest);yAnalytic = exp(-xTest.^2);figure;
plot(xTest,yAnalytic,"-")
hold on
plot(xTest,yModel,"--")
legend("Analytic","Model")

学习>深度学习中,被求导的对象(样本/输入)一般是多元的(向量x),绝大多数情况是标量y对向量x进行求导,很少向量y对向量x进行求导,否则就会得到复杂的微分矩阵。所以经常把一个样本看做一个整体,它包含多个变量(属性),对其所有属性求导后再加和,就得到了这个样本的偏导数之和。

% 损失函数
function [loss,gradients] = modelLoss(net, X, icCoeff)% 前向传播计算y = forward(net,X);% Evaluate the gradient of y with respect to x. % Since another derivative will be taken, set EnableHigherDerivatives to true.dy = dlgradient(sum(y,"all"),X,EnableHigherDerivatives=true);% Define ODE loss.eq = dy + 2*y.*X;% Define initial condition loss.ic = forward(net,dlarray(0,"CB")) - 1;% Specify the loss as a weighted sum of the ODE loss and the initial condition loss.loss = mean(eq.^2,"all") + icCoeff * ic.^2;% Evaluate model gradients.gradients = dlgradient(loss, net.Learnables);end

http://www.ppmy.cn/devtools/136035.html

相关文章

“人工智能+高职”:VR虚拟仿真实训室的发展前景

在当今科技日新月异的时代&#xff0c;人工智能&#xff08;AI&#xff09;与虚拟现实&#xff08;VR&#xff09;技术的融合正逐步改变着各行各业&#xff0c;教育领域也不例外。特别是在高等职业教育&#xff08;简称“高职”&#xff09;体系中&#xff0c;VR虚拟仿真实训室…

2024年亚太地区数学建模C题完整思路

题目 随着人们消费理念的发展&#xff0c;宠物行业作为一个新兴产业&#xff0c;由于经济的快速发展和人均收入的提高&#xff0c;正在全球范围内逐渐积聚力量。1992年&#xff0c;中国小动物保护协会成立&#xff1b;1993年&#xff0c;皇家宠物食品&#xff08;Royal Canin&…

Pycharm

Pycharm PycharmPycharm汉化Pycharm基本设置 Pycharm PyCharm是一种Python IDE&#xff08;Integrated Development Environment&#xff0c;集成开发环境&#xff09;&#xff0c;带有一整套可以帮助用户在使用Python语言开发时提高其效率的工具&#xff0c;比如调试、语法高…

FileProvider高版本使用,跨进程传输文件

高版本的android对文件权限的管控抓的很严格,理论上两个应用之间的文件传递现在都应该是用FileProvider去实现,这篇博客来一起了解下它的实现原理。 首先我们要明确一点,FileProvider就是一个ContentProvider,所以需要在AndroidManifest.xml里面对它进行声明: <provideran…

图论之最小生成树计数(最小生成树的应用)

题目 2401: 信息学奥赛一本通T1492-最小生成树计数 时间限制: 2s 内存限制: 192MB 提交: 18 解决: 8 题目描述 原题来自&#xff1a;JSOI 2008 现在给出了一个简单无向加权图。你不满足于求出这个图的最小生成树&#xff0c;而希望知道这个图中有多少个不同的最小生成树。&…

深入计算机语言之C++:STL之vector的模拟实现

&#x1f511;&#x1f511;博客主页&#xff1a;阿客不是客 &#x1f353;&#x1f353;系列专栏&#xff1a;从C语言到C语言的渐深学习 欢迎来到泊舟小课堂 &#x1f618;博客制作不易欢迎各位&#x1f44d;点赞⭐收藏➕关注 ​ 一、实现基本框架 1.1 结构的定义 &#x1f…

安装textlive 2024

安装textlive 2024 Texlive Texlive 官方虽然不推荐使用镜像文件安装&#xff0c;但是奈何校园网它限速&#xff0c;下载6.4G需要7h&#xff0c;所以采用iso安装方式&#xff0c;这样只需要联网下载一下东西&#xff0c;或者是直接拷贝安装。 据我推测&#xff0c;iso镜像解压…

Scala中的Array

Array:数组 可修改的&#xff1a;ArrayBuffer 不可修改的&#xff1a;Array 需要导入包 import scala.collection.mutable.ArrayBuffer 可修改的&#xff1a; ArrayBuffer def main(args: Array[String]): Unit {//1.新建val arr1ArrayBuffer(1,2,3)//2.添加arr1 4arr1.in…