神经网络的数学——一个完整的例子

ops/2024/11/29 4:55:42/

神经网络是一种人工智能方法,它教导计算机以类似于人脑的方式处理数据。神经网络通过输入多个数据实例、预测输出、找出实际答案与机器答案之间的误差,然后微调权重以减少此误差来进行学习。

虽然神经网络看起来非常复杂,但它实际上是线性代数和多元微积分的巧妙运用。本文旨在全面介绍破坏神经网络的数学原理。

假设和预备知识

神经网络需要对大学水平的微积分和线性代数有扎实的理解。在可汗学院网站上可以找到很好的复习资料(链接在上一句中)。本例中必不可少的算法是梯度下降,本视频对此进行了很好的解释。

对于与神经网络更相关的课程,Adam Dhalla 的这个视频仅教您此示例所需的微积分和线性代数的必要领域。

神经网络基础

我们将使用的示例是:

通常,输入层(绿色)是来自数据集的输入变量输出层(红色)是神经网络预测值。在隐藏层和输出层中,对每个节点进行加权和(用s表示),然后应用激活函数(用a表示),根据所需的激活函数对值进行归一化。

将数据从输入端通过网络馈送到输出端的过程称为前向传播。观察前向传播的错误率并将错误反馈回网络以微调神经网络权重的过程称为反向传播。我们在反向传播之前先进行前向传播。

前向传播

注意:在这个例子中,我使用sigmoid函数作为激活函数(激活函数用作将输入映射到一定范围内 - 对于 sigmoid 来说,范围是 (0, 1))。

隐藏层

隐藏层 1:

隐藏层 2:

隐藏层 3:

输出层

输出层 1:

输出层 2:

均方误差 (MSE) 计算

均方误差是预期输出和实际输出之间差异的度量。我们正在寻找较低的 MSE 分数,这表明模型与数据的拟合度更高。我们将使用梯度下降法来降低该值。

反向传播

现在已经计算出预测值,神经网络需要根据预测误差调整其权重。这是通过反向传播完成的。

对于此示例,考虑学习率为 0.1

反向传播背后的一般数学思想是应用链式法则来找到误差函数随权重变化的变化。以权重 7 为例:

所有三个部分方程均可从我们的工作中推导出来。

首先,

第二,

最后,

因此,把这三个术语放在一起,

该公式可以适用于连接隐藏层和输出层的所有权重。

注意:作者通常会使用 delta 来写方程:δ₀₁= (a₀₁−expected₁) × a₀₁ × (1−a₀₁),因此方程可以写成 ∂E₀₁ / ∂w₇ = δ₀₁ × aₕ₁

现在我们得到了误差函数的梯度。

我们想应用梯度下降来获得权重 w₇ 的新值。新的 w₇(我们可以将其符号化为 w₇')可以通过从 w₇ 中减去学习率乘以梯度来获得。

一般来说,对于输出神经元:

输出层

现在,应用示例中的实数来查找 w₇ 到 w₁₂ 的新值

输出层 1:

输出层 2:

隐藏层(衍生)

找到一种方法来优化隐藏层权重具有更大的推导量——本节中的任何内容都与计算无关,因此如果需要,可以随意跳过此部分。

考虑更新 w₁ 的权重——原则上,更新任何权重在围绕偏微分旋转方面都会具有相同风格的公式。

然而这一次,我们离输出神经元更远了——因此,为了找到这个方程右侧各个分量的值,还需要进行更多的“链接”……

对于一阶导数:

在哪里:

现在,由于我们之前已经计算了 δ₀₁ 和 δ₀₂(参见本文输出层部分所做的计算),我们可以将这些增量的值代入方程中。

因此,加权和相对于前一层的神经元的导数本质上就是相应的权重。

现在,用这些值代替部分误差项:

∂aₕ₁ / ∂sₕ₁ 的值只是 S 型函数的导数

∂sₕ₁ / ∂w₁ 的值是前一层神经元的输出(在本例中,由于只有一个隐藏层,所以是输入层神经元)

综上所述:

我希望您能看到这些步骤中发生了什么——可以进行类似的工作过程来找到所有权重的公式(我不会展示)。

但本质上,要找到更新权重的值,首先计算权重输出神经元的增量,然后从增量中减去旧权重,乘以增量,再乘以权重输入神经元的先前值。

如果这很难理解,那么下面的计算可能会帮助您了解数字上发生的情况。

隐藏层(计算)

先前计算的 delta 值:
δ₀₁ = -0.0984
δ₀₂ = 0.1479

隐藏层 1:

隐藏层 2:

隐藏层 3:

完成了!

具有更新权重的神经网络

结束语

以下是 3 层神经网络前向和反向传播的完整示例。

通常,神经网络在多个数据实例上进行训练,也可以进行多次迭代训练(我们称之为时期)。这样做会根据实例逐渐增加/减少权重,直到神经网络针对一组实例进行优化。

这个过程非常费力,而且数学运算量很大——幸好这就是我们用计算机模拟所有这些工作的原因。像PyTorch这样的库抽象了许多数学复杂性,绝对应该用于任何类型的模型训练。

尽管如此,完整的数学演练肯定有助于强化实施该模型时所需的理解。


http://www.ppmy.cn/ops/137566.html

相关文章

【Elasticsearch入门到落地】2、正向索引和倒排索引

接上篇《1、初识Elasticsearch》 上一篇我们学习了什么是Elasticsearch,以及Elastic stack(ELK)技术栈介绍。本篇我们来什么是正向索引和倒排索引,这是了解Elasticsearch底层架构的核心。 上一篇我们学习到,Elasticsearch的底层是由Lucene实…

Linux——环境变量

前言:大佬写博客给别人看,菜鸟写博客给自己看,我是菜鸟。 感言:每天的认知都在被刷新。 1:基本概念 环境变量(environment variables):⼀般是指在操作系统中⽤来指定操作系统运⾏环境的⼀些参数 2&#xf…

如何通过ChatGPT提高自己的编程水平

在编程学习的过程中,开发者往往会遇到各种各样的技术难题和学习瓶颈。传统的学习方法依赖书籍、教程、视频等,但随着技术的不断发展,AI助手的崛起为编程学习带来了全新的机遇。ChatGPT,作为一种强大的自然语言处理工具&#xff0c…

Spring Boot整合Redis Stack构建本地向量数据库相似性查询

Spring Boot整合Redis Stack构建本地向量数据库相似性查询 在微服务架构中,数据的高效存储与快速查询是至关重要的。Redis作为一个高性能的内存数据结构存储系统,不仅可以用作缓存、消息代理,还可以扩展为向量数据库,实现高效的相…

Qt桌面应用开发 第九天(综合项目一 飞翔的鸟)

目录 1.鸟类创建 2.鸟动画实现 3.鼠标拖拽 4.自动移动 5.右键菜单 6.窗口透明化 项目需求: 实现思路: 创建项目导入资源鸟类创建鸟动画实现鼠标拖拽实现自动移动右键菜单窗口透明化 1.鸟类创建 ①鸟类中包含鸟图片、鸟图片的最小值下标和最大值…

【04】Selenium+Python 手动添加Cookie免登录(实例)

一、什么是Cookie? Cookie 是一种由服务器创建并保存在用户浏览器中的小型数据文件。它用于存储用户的相关信息,以便在后续访问同一网站时可以快速检索这些信息。Cookie 主要用于以下几个方面: 1.状态管理: Cookie 可以保存用户…

【简单好抄保姆级教学】javascript调用本地exe程序(谷歌,edge,百度,主流浏览器都可以使用....)

javascript调用本地exe程序 详细操作步骤结果 详细操作步骤 在本地创建一个txt文件依次输入 1.指明所使用注册表编程器版本 Windows Registry Editor Version 5.00这是脚本的第一行,指明了所使用的注册表编辑器版本。这是必需的,以确保脚本能够被正确解…

数据库和缓存的数据一致性 -20241124

问题描述 一致性 缓存中有数据,缓存的数据值数据库中的值缓存中本没有数据,数据库中的值最新值(有请求查询数据库时,会将数据写入缓存,则变为上面的“一致”状态) “数据不一致”: 缓存的数据值…