神经网络的公式推导与代码实现(论文复现)

news/2024/12/22 20:49:48/

神经网络的公式推导与代码实现(论文复现)

本文所涉及所有资源均在传知代码平台可获取

概述

本文将详细推导一个简单的神经网络模型的正向传播、反向传播、参数更新等过程,并将通过一个手写数字识别的例子,使用python手写和pytorch分别实现,能够让读者深刻地理解神经网络的具体参数更新训练的工作流程,文末将包含数据+代码+PPT。

这些内容是基于神经网络和机器学习的通用知识,正向传播和反向传播,如今几乎所有的深度学习模型的训练都是基于这样相同或者相似的方法进行训练的,有助于帮助我们更加深入的理解深度学习模型。

引言

多层感知机(Multilayer Perceptron,简称MLP)是神经网络的一种。MLP是一种前馈神经网络,它包含一个或多个隐藏层,以及非线性激活函数,这使得MLP能够学习和模拟复杂的非线性关系。MLP是最基础也是最广泛研究的神经网络类型之一,本文将以一个MLP模型来展开。

MLP的结构通常如下:

输入层:接收外部输入数据。

隐藏层:一个或多个隐藏层,每层包含多个神经元。隐藏层负责从输入数据中提取特征并进行初步的非线性变换。

输出层:输出网络的预测结果,对于分类问题,输出层通常使用softmax激活函数进行多类分类。

MLP的训练过程通常包括以下几个步骤:

前向传播:输入数据通过网络,通过每个神经元的加权和和激活函数,最终得到输出。
计算损失:使用损失函数(如均方误差、交叉熵等)计算网络输出与真实标签之间的差异。

反向传播:根据损失函数的梯度,计算每一层的权重对损失的贡献,即梯度。

权重更新:使用梯度下降或其他优化算法(如Adam、RMSprop等)根据梯度更新网络的权重和偏置。

MLP在许多领域都有应用,包括图像识别、语音识别、自然语言处理、游戏AI等。随着深度学习的发展,MLP作为深度神经网络的基础,其结构和训练方法也在不断地被改进和优化。

实际上,几乎所有的深度学习模型中都会有MLP的身影,相当于深度学习模型的骨架,特别是在深度学习模型中最后一步,通常会接个MLP来使得输出的维度符合我们任务的需求,例如我们当前需要要对手写数字识别,那就是一个10分类问题,最后输出可以通过接一个MLP变成10维,每一维代表一个分类,从而顺利地使模型适配我们的任务。

神经网络公式推导

在这里插入图片描述

假设我们有这么一个神经网络,由输入层、一层隐藏层、输出层构成:(这里为了方便,不考虑偏置bias)

在这里插入图片描述

在这里插入图片描述

前向传播(forward)

首先,我们可以试着表示一下y1
如模型图所示可以表示为:

在这里插入图片描述

那么我要表示yj呢?

在这里插入图片描述

其中j=1时,就是y1的表示,j=m时,就是ym的表示。

同理我们可以得到:

在这里插入图片描述

ok表示输出层第k个神经元的预测值,这就是我们需要的输出。
至此,正向传播完毕

反向传播(backward)

光正向传播,我们只能得到模型的预测值,不能更新模型的参数,也就是说,正向传播的时候,模型是不会被更新的。

因为我们得到了模型输出的预测值,并且我们手上有对应的真实值,我们就能够将误差反向传播,更新模型参数。

具体操作怎么操作呢?

首先,我们需要定义误差,即预测值和真实值差了多少,以此来决定模型参数更新的方向和力度。

这里我们采用简单的差的平方的损失函数:

在这里插入图片描述

注意,这里只是更新输出层第k个神经元所反馈的误差。

隐藏层和输出层的权重更新
首先根据已知如下:

输出层预测值ok

在这里插入图片描述

激活函数Sigmoid

在这里插入图片描述

那我们可以试着展开一下Ek

在这里插入图片描述

因为我们现在需要更新的是wjk,因此展开到wjk我们就能有一个比较形象的认识了。

根据梯度下降法可得,我们现在只需要求出

在这里插入图片描述

在这里插入图片描述

接下来我们分别求出:

在这里插入图片描述

在这里插入图片描述

我们先给出激活函数的导数推导过程:

在这里插入图片描述

就是使用复合函数除的求导法则进行求导。我们可以发现sigmoid函数求导之后还是挺好看的。

接下来就是计算两个导数即可。

在这里插入图片描述
在这里插入图片描述

一眼就能看出来了吧,就是别忘了里面的-ok也要导,负号别漏了,然后是

在这里插入图片描述

这个可能会有点困难,但是仔细看看,发现还是很简单的;首先

在这里插入图片描述

在这里插入图片描述

(链式求导法)因此:

在这里插入图片描述

那么这个结果计算起来就比较简单了;既然如此,将结果拼起来就是我们要求的结果了:

在这里插入图片描述

在这里插入图片描述

全是已知的,不就可以更新参数了嘛;因此,加个学习率这层权重更新推导就大功告成了

在这里插入图片描述

输入层和隐藏层的权重更新;如果上面的推导看懂了,下面的推导就非常简单了,无非就是多展开一级,多求一次导数而已;首先(前面已经推到过了)

在这里插入图片描述

那么我们可以将误差再展开一级(接着链导下去):

在这里插入图片描述

那么下面这个就非常直观了

在这里插入图片描述

同样的,我们也分别求出三次的导数,最后拼起来就行了。

在这里插入图片描述

至此分别求出来了,拼起来就是我们要的结果了:

在这里插入图片描述

通过观察,里面全是已知的变量;那么更新公式也就有了:

在这里插入图片描述

数据集介绍

实验数据就是mnist手写数据集

在这里插入图片描述

第一列为label,表示这个图片是什么数字;后面都为图片的像素值,表示图片的数据;模型的输入就是像素值,输出就是预测值,即通过像素预测出是什么数字。

核心代码

其中比较关键的就是那两个参数的更新公式;隐藏层和输出层的权重更新:

在这里插入图片描述

输入层和隐藏层的权重更新:

在这里插入图片描述

数据集+python手写代码+pytorch代码+ppt都在附件里哦

运行结果

在这里插入图片描述

在这里插入图片描述

总结

感觉从推导到代码实现也是一个反复的过程,从推导发现代码写错了,写不出代码了就要去看看推导的过程,这个过程让我对反向传播有了较全面的理解。

我们发现,手写代码运行时间要一分多钟而pytorch其实只要10s不到,毕竟框架,底层优化很多,用起来肯定用框架。

以及二者准确率有一些差距,可能是因为pytorch里使用了交叉熵损失函数,比较适合分类任务;手写的并没有分batch,而是所有数据直接更新参数,但是pytorch里分了batch,分batch能够使得模型训练速度加快(并行允许),也使得模型参数更新的比较平稳。

文章代码资源点击附件获取


http://www.ppmy.cn/news/1527231.html

相关文章

使用vant UI实现时间段选择

需求&#xff1a;选择时间段或者选择日期&#xff0c;时间段不允许跨月&#xff0c;选完开始时间后&#xff0c;结束时间可选 “开始日期~当月最后一天” 格式&#xff1a;2023-01-01~2023-01-23 或者 2023-01-01 这里使用vantUI 示例代码: <van-fieldlabel"日期&quo…

(CS231n课程笔记)深度学习之损失函数详解(SVM loss,Softmax,熵,交叉熵,KL散度)

学完了线性分类&#xff0c;我们要开始对预测结果进行评估&#xff0c;进而优化权重w&#xff0c;提高预测精度&#xff0c;这就要用到损失函数。 损失函数&#xff08;Loss Function&#xff09;是机器学习模型中的一个关键概念&#xff0c;用于衡量模型的预测结果与真实标签…

openCV的python频率域滤波

在OpenCV中实现频率域滤波通常涉及到傅里叶变换(Fourier Transform)和其逆变换(Inverse Fourier Transform)。傅里叶变换是一种将图像从空间域转换到频率域的数学工具,这使得我们可以更容易地在图像的频域内进行操作,如高通滤波、低通滤波等。 下面,我将提供一个使用Py…

linux-L3-linux 复制文件

linux 中要将文件file1.txt复制到目录dir中&#xff0c;可以使用以下命令 cp file1.txt dir/复制文件 cp /path/to/source/file /path/to/destination移动 mv /path/to/source/file /path/to/destination复制文件夹内的文件 cp -a /path/to/source/file /path/to/destinati…

Rust的常量

【图书介绍】《Rust编程与项目实战》-CSDN博客 《Rust编程与项目实战》(朱文伟&#xff0c;李建英)【摘要 书评 试读】- 京东图书 (jd.com) Rust编程与项目实战_夏天又到了的博客-CSDN博客 3.3.1 常量的定义 常量和变量是高级程序设计语言中数据的两种表现形式。这里我们先…

Go语言并发编程之select语句详解

在Go语言的并发模型中,channel是用于在goroutine之间进行通信的主要工具,而select语句则是将多个channel结合在一起的关键机制。通过select语句,开发者可以同时监控多个channel的状态,从而构建更为复杂和灵活的并发逻辑。本文将详细介绍select语句的原理和用法,并通过多个…

62. 圆圈中最后剩下的数字

comments: true difficulty: 简单 edit_url: https://github.com/doocs/leetcode/edit/main/lcof/%E9%9D%A2%E8%AF%95%E9%A2%9862.%20%E5%9C%86%E5%9C%88%E4%B8%AD%E6%9C%80%E5%90%8E%E5%89%A9%E4%B8%8B%E7%9A%84%E6%95%B0%E5%AD%97/README.md 面试题 62. 圆圈中最后剩下的数字…

硬件工程师笔试面试学习汇总——器件篇目录

目录 一、器件篇目录 1、电阻(Resistors) 1.1、 基础 1.2、相关问题 1.3、上拉电阻 1.4、下拉电阻 2、电容(Capacitors) 2.1、基础 2.2、相关问题 3、电感(Inductors) 3.1、基础 3.2、相关问题 4、二极管(Diodes) 4.1、基础 4.2、相关问题 5、三极管 5.1…