用 Python 从零开始创建神经网络(七):梯度下降(Gradient Descent)/导数(Derivatives)

news/2024/11/17 4:15:14/

梯度下降(Gradient Descent)/导数(Derivatives)

  • 引言
  • 1. 参数对输出的影响
  • 2. 斜率(The Slope)
  • 3. 数值导数(The Numerical Derivative)
  • 4. 解析导数(The Analytical Derivative)
  • 5. 总结
    • 规则:

引言

随机更改并搜索最佳权重和偏差并未取得成功,主要原因是权重和偏差的可能组合是无限的,我们需要比单纯的运气更智能的方法才能取得任何成功。每个权重和偏差对损失的影响程度也不同——这种影响取决于参数本身以及当前样本,这些样本是第一层的输入。然后这些输入值被权重相乘,所以输入数据影响神经元的输出,并影响权重对损失的影响。相同的原理适用于下一层的偏差和参数,它们以前一层的输出作为输入。这意味着对输出值的影响取决于参数以及样本——这就是为什么我们要分别计算每个样本的损失值。最后,权重或偏差如何影响总体损失的函数并不一定是线性的。为了知道如何调整权重和偏差,我们首先需要了解它们对损失的影响。

需要注意的一个概念是,我们提到权重和偏差及其对损失函数的影响。然而,损失函数中并不包含权重或偏差。这个函数的输入是模型的输出,而神经元的权重和偏差影响这个输出。因此,即使我们是从模型的输出计算损失,而不是权重/偏差,这些权重和偏差直接影响了损失。

在接下来的章节中,我们将通过解释偏导数、梯度、梯度下降和反向传播来详细描述这一过程。基本上,我们将计算每一个单独的权重和偏差如何改变损失值(它对此有多大影响),给定一个样本(因为每个样本产生一个单独的输出,因此也产生一个单独的损失值),以及如何改变这个权重或偏差以减少损失值。记住——我们的目标是减少损失,我们将通过使用梯度下降来实现这一点。另一方面,梯度是偏导数计算的结果,我们将使用链式法则通过反向传播来更新所有的权重和偏差。如果这听起来还没有多大意义,不用担心;我们将在这一章和接下来的章节中解释所有这些术语以及如何执行这些操作。

要理解偏导数,我们需要从导数开始,它们是偏导数的一种特殊情况——从只取单个参数的函数中计算得出。


1. 参数对输出的影响

让我们从一个简单的函数开始,探索“影响”的含义。

一个非常简单的函数 y = 2 x y = 2x y=2x,其中 x x x为输入:

python">def f(x):return 2*x

现在让我们围绕这个函数编写一些代码来可视化数据 —— 我们将导入NumPyMatplotlib,创建一个从0到4的5个输入值的数组,计算每个输入值的函数输出,并将结果以连续点之间的线条形式绘制出来。这些点的坐标是输入 x x x和函数输出 y y y

python">import matplotlib.pyplot as plt
import numpy as npdef f(x):return 2*xx = np.array(range(5))
y = f(x)print(x)
print(y)plt.plot(x, y)
plt.show()
python">>>>
[0 1 2 3 4]
[0 2 4 6 8]

在这里插入图片描述


2. 斜率(The Slope)

这看起来像是 f ( x ) = 2 x f(x) = 2x f(x)=2x函数的输出,这是一条直线。你如何定义 x x x y y y的影响呢?有些人会说,“ y y y x x x的两倍”。另一种描述这种线性函数影响的方式来自代数学:斜率。你可能还记得学校里学过的一个短语“纵横比”(即“rise over run”)。一条直线的斜率是:

在这里插入图片描述
斜率是 y y y的变化量除以 x x x的变化量,或者用数学语言来说 —— Δ y \Delta y Δy除以 Δ x \Delta x Δx。那么 f ( x ) = 2 x f(x) = 2x f(x)=2x的斜率是多少呢?

要计算斜率,首先我们需要在函数图形上取任意两点,并对它们进行减法运算以计算变化量。减去这些点意味着分别减去它们的 x x x y y y维度。 y y y的变化量除以 x x x的变化量得到的结果就是斜率:

在这里插入图片描述
接着写代码,我们将所有 x x x值保存在一个单维度的NumPy数组 x x x中,所有结果保存在另一个单维度数组 y y y中。为了执行相同的操作,我们将取 x [ 0 ] x[0] x[0] y [ 0 ] y[0] y[0]作为第一个点,然后取 x [ 1 ] x[1] x[1] y [ 1 ] y[1] y[1]作为第二个点。现在我们可以计算它们之间的斜率:

python">import matplotlib.pyplot as plt
import numpy as npdef f(x):return 2*xx = np.array(range(5))
y = f(x)plt.plot(x, y)
plt.show()print((y[1]-y[0]) / (x[1]-x[0]))
python">>>>
2.0

这条直线的斜率是 2 2 2并不令人意外。我们可以说 x x x y y y的影响度量是 2 2 2。我们可以以相同的方式计算任何线性函数的斜率,包括那些不那么明显的线性函数。

那么对于一个非线性函数如 f ( x ) = 2 x 2 f(x) = 2x^2 f(x)=2x2呢?

python">def f(x):return 2*x**2

此函数创建一个不形成直线的图形:
在这里插入图片描述

我们可以测量这条曲线的斜率吗?根据我们选择使用的 2 点,我们将测量不同的斜率:

python">y = f(x) # Calculate function outputs for new functionprint(x)
print(y)
python">>>>
[0 1 2 3 4]
[ 0 2 8 18 32]

现在讨论第一点:

python">print((y[1]-y[0]) / (x[1]-x[0]))
python">>>>
2

另一个:

python">print((y[3]-y[2]) / (x[3]-x[2]))
python">>>>
10

在这里插入图片描述

代码的可视化: https://nnfs.io/bro

我们如何衡量 x x x y y y在这个非线性函数中的影响呢?微积分提出,我们可以通过测量 x x x点(对于函数的一个特定输入值)的切线斜率来衡量,这给出了瞬时斜率(该点的斜率),即导数。切线是通过在曲线上的两个“无限接近”的点之间画一条线来创建的,但这条曲线必须在导数点处是可微分的。这意味着它必须是连续且平滑的(我们无法计算在我们可以描述为“尖角”的东西上的斜率,因为它包含无数个斜率)。然后,因为这是一条曲线,所以没有单一的斜率。斜率取决于我们在哪里测量它。举一个直接的例子,我们可以通过使用 x x x点和另一个也在 x x x点但加上一个非常小的增量,例如0.0001的点来近似这个函数在 x x x处的导数。这个数字是一个常见的选择,因为它不会引入太大的误差(在估计导数时)或导致整个表达式在数值上不稳定(由于浮点数分辨率, Δ x \Delta x Δx可能会四舍五入到0)。这让我们能够像以前一样计算斜率,但是在两个非常接近的点上进行,从而得到 x x x处斜率的良好近似:

python">import matplotlib.pyplot as plt
import numpy as npdef f(x):return 2*x**2p2_delta = 0.0001x1 = 1
x2 = x1 + p2_delta # add deltay1 = f(x1) # result at the derivation point
y2 = f(x2) # result at the other, close pointapproximate_derivative = (y2-y1)/(x2-x1)print(approximate_derivative)
python">>>>
4.0001999999987845

我们很快就会了解到,当 x = 1 x=1 x=1时, 2 x 2 2x^2 2x2的导数应该恰好是4。我们看到的差异(约为4.0002)来自于计算切线的方法。我们选择了一个足够小的 δ \delta δ,以尽可能准确地近似导数,但又足够大以防止四舍五入误差。

更详细地说,一个无限小的 δ \delta δ值将近似一个准确的导数;然而, δ \delta δ值需要在数值上稳定,意味着,我们的 δ \delta δ不能超过Python浮点精度的限制(不能太小,否则可能会被四舍五入为0,而我们知道,除以0是“非法”的)。因此,我们的解决方案受限于估计导数和保持数值稳定之间,因此引入了这个小但可见的误差。


3. 数值导数(The Numerical Derivative)

这种计算导数的方法称为数值微分——使用两个无限接近的点计算切线的斜率,或者像代码解决方案中那样——计算由两个“足够接近”的点构成的切线的斜率。我们可以通过以下内容来可视化为什么我们要在两个接近的点上执行这个操作:

在这里插入图片描述
在这里插入图片描述

代码的可视化: https://nnfs.io/cat

我们可以看到,这两个点越接近,切线就显得越正确。

继续进行数值微分,让我们可视化切线并观察它们如何随着计算位置的不同而变化。首先,我们将使用Numpyarange()函数使这个函数的图形更加细致,从而允许我们以更小的步骤绘图。np.arange()函数接收起始点、终点和步长参数,允许我们每次取0.001这样的小步长:

python">import matplotlib.pyplot as plt
import numpy as npdef f(x):return 2*x**2# np.arange(start, stop, step) to give us smoother line
x = np.arange(0, 5, 0.001)
y = f(x)plt.plot(x, y)
plt.show()

在这里插入图片描述

为了绘制这些切线,我们将推导出某一点处切线的函数,并在该点的图形上绘制切线。直线的函数是 y = m x + b y = mx + b y=mx+b。其中 m m m是我们已经计算出的斜率或近似导数, x x x是输入,这使得 b b b,即 y y y截距,需要我们来计算。斜率保持不变,但目前,您可以通过 y y y截距“上移”或“下移”线。我们已经知道 x x x m m m,但 b b b仍然未知。假设为了图形的目的令 m = 1 m=1 m=1,看看这实际意味着什么:

在这里插入图片描述

代码的可视化: https://nnfs.io/but

要计算 b b b,公式是 b = y − m x b = y - mx b=ymx

在这里插入图片描述

到目前为止,我们已经使用了两个点——我们想要计算导数的点和与其“足够接近”的点来计算导数的近似值。现在,根据上述 b b b的方程,导数的近似值,以及同一个“足够接近”的点(具体来说是它的 x x x y y y坐标),我们可以将它们代入方程中,得到该导数点处切线的 y y y截距。使用代码:

python">import matplotlib.pyplot as plt
import numpy as npdef f(x):return 2*x**2# np.arange(start, stop, step) to give us smoother line
x = np.arange(0, 5, 0.001)
y = f(x)plt.plot(x, y)# The point and the "close enough" point
p2_delta = 0.0001
x1 = 2
x2 = x1+p2_deltay1 = f(x1)
y2 = f(x2)print((x1, y1), (x2, y2))# Derivative approximation and y-intercept for the tangent line
approximate_derivative = (y2-y1)/(x2-x1)
b = y2 - approximate_derivative*x2# We put the tangent line calculation into a function so we can call
# it multiple times for different values of x
# approximate_derivative and b are constant for given function
# thus calculated once above this function
def tangent_line(x):return approximate_derivative*x + b# plotting the tangent line
# +/- 0.9 to draw the tangent line on our graph
# then we calculate the y for given x using the tangent line function
# Matplotlib will draw a line for us through these points
to_plot = [x1-0.9, x1, x1+0.9]
plt.plot(to_plot, [tangent_line(i) for i in to_plot])print('Approximate derivative for f(x)', f'where x = {x1} is {approximate_derivative}')plt.show()
python">>>>
(2, 8) (2.0001, 8.000800020000002)
Approximate derivative for f(x) where x = 2 is 8.000199999998785

在这里插入图片描述

橙色的线是函数 f ( x ) = 2 x 2 f(x) = 2x^2 f(x)=2x2 x = 2 x=2 x=2处的近似切线。我们为什么要关心这个?你很快会发现,我们只关心这条切线的斜率,但是可视化和理解切线都非常重要。我们关心切线的斜率,因为它告诉我们 x x x在特定点对这个函数的影响,被称为瞬时变化率。我们将使用这个概念来确定一个特定的权重或偏置对给定样本的总损失函数的影响。现在,对于不同的 x x x值,我们可以观察到对函数的不同影响。我们可以继续之前的代码,查看不同输入( x x x)的切线 - 我们将代码的一部分放在一个循环中,遍历示例 x x x值,并绘制多条切线:

python">import matplotlib.pyplot as plt
import numpy as npdef f(x):return 2*x**2
# np.arange(start, stop, step) to give us a smoother curve
x = np.array(np.arange(0,5,0.001))
y = f(x)plt.plot(x, y)colors = ['k','g','r','b','c']def approximate_tangent_line(x, approximate_derivative):return (approximate_derivative*x) + bfor i in range(5):p2_delta = 0.0001x1 = ix2 = x1+p2_deltay1 = f(x1)y2 = f(x2)print((x1, y1), (x2, y2))approximate_derivative = (y2-y1)/(x2-x1)b = y2-(approximate_derivative*x2)to_plot = [x1-0.9, x1, x1+0.9]plt.scatter(x1, y1, c=colors[i])plt.plot([point for point in to_plot], [approximate_tangent_line(point, approximate_derivative) for point in to_plot], c=colors[i])print('Approximate derivative for f(x)', f'where x = {x1} is {approximate_derivative}')plt.show()
python">>>>
(0, 0) (0.0001, 2e-08)
Approximate derivative for f(x) where x = 0 is 0.00019999999999999998
(1, 2) (1.0001, 2.00040002)
Approximate derivative for f(x) where x = 1 is 4.0001999999987845
(2, 8) (2.0001, 8.000800020000002)
Approximate derivative for f(x) where x = 2 is 8.000199999998785
(3, 18) (3.0001, 18.001200020000002)
Approximate derivative for f(x) where x = 3 is 12.000199999998785
(4, 32) (4.0001, 32.00160002)
Approximate derivative for f(x) where x = 4 is 16.000200000016548

在这里插入图片描述

对于这个简单的函数 f ( x ) = 2 x 2 f(x) = 2x^2 f(x)=2x2,我们通过近似导数(即切线的斜率)并没有支付高昂的代价,并且得到了足够接近我们需求的值。

问题在于,我们神经网络中实际使用的函数并不简单。损失函数包含了所有的层、权重和偏置——这是一个在多个维度上运作的绝对庞大的函数!使用数值微分计算导数需要对单个参数更新进行多次前向传递(我们将在第10章讨论参数更新)。我们需要执行前向传递作为参考,然后通过增量值更新单个参数,并再次通过我们的模型执行前向传递以查看损失值的变化。接下来,我们需要计算导数并恢复我们为这次计算所做的参数更改。我们必须对每个权重和偏置以及每个样本重复这一过程,这将非常耗时。我们还可以将这种方法视为强行计算导数的方法。再次强调,正如我们快速涵盖了许多术语,导数是作为单个参数输入的函数的切线的斜率。我们将利用这种能力计算损失函数在每个权重和偏置点的斜率——这将引导我们进入多变量函数,这是一个涉及多个参数的函数,是下一章的主题——偏导数。


4. 解析导数(The Analytical Derivative)

现在我们对什么是导数有了更好的理解,如何计算数值导数(也称为通用导数),以及为什么这种方法不适合我们,我们可以继续探讨解析导数,这是我们将在代码中实现的实际导数解法。

在数学中,解决问题通常有两种方法:数值方法和解析方法。数值解决方法涉及计算一个数字来找到解决方案,就像上面的近似导数方法一样。数值解也是一种近似。另一方面,解析方法提供了精确且在计算方面更快的解决方案。然而,正如我们将很快学到的,识别给定函数的导数的解析解在复杂性上会有所不同,而数值方法的复杂度从不增加——它总是调用两次方法,用两个输入来计算某一点的近似导数。一些解析解非常明显,有些可以通过简单的规则计算,而一些复杂的函数可以分解成更简单的部分,并使用所谓的链式法则进行计算。我们可以利用已经证明的某些函数的导数解,而其他一些函数——像我们的损失函数——可以通过上述方法的组合来解决。

要使用解析方法计算函数的导数,我们可以将它们分解为简单的基本函数,找到这些函数的导数,然后应用链式法则(我们很快会解释这一点)来得到完整的导数。为了开始建立直觉,让我们从简单的函数及其各自的导数开始。简单常数函数的导数:

在这里插入图片描述

在这里插入图片描述

代码的可视化: https://nnfs.io/cow

当计算函数的导数时,回想一下导数可以被解释为斜率。在这个例子中,这个函数的结果是一条水平线,因为任何 x x x的输出值都是1:

通过观察,很明显导数等于0,因为从一个 x x x值到任何其他 x x x值都没有变化(即,没有斜率)。

到目前为止,我们通过取一个单一参数 x x x来计算函数的导数,在每个例子中我们都是这样做的。这在计算偏导数时会有所改变,因为偏导数涉及到具有多个参数的函数,并且我们将一次只对其中一个参数计算导数。目前,对于导数,它总是相对于单一参数。为了表示导数,我们可以使用撇号记法,例如对于函数 f ( x ) f(x) f(x),我们加上一个撇号 ( ′ ) (') () f ′ ( x ) f'(x) f(x)。对于我们的例子, f ( x ) = 1 f(x) = 1 f(x)=1,其导数 f ′ ( x ) = 0 f'(x) = 0 f(x)=0。我们还可以使用的另一种记法称为莱布尼茨记法——依赖于撇号记法和用莱布尼茨记法书写导数的多种方式如下:


这些记法的含义相同 —— 函数的导数(相对于 x x x)。在接下来的例子中,我们将使用这两种记法,因为有时使用一种记法或另一种记法更方便。我们也可以在同一个等式中同时使用这两种记法。

5. 总结

规则:

函数的常数倍的导数等于该函数导数的常数倍:

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

我们在这里使用 x x x的值而不是整个函数 f ( x ) f(x) f(x),因为整个函数的导数的计算方式有些不同。我们将在下一章中解释这个概念,并结合链式法则进行讲解。

既然我们已经了解了什么是导数以及如何解析地计算它们,稍后我们将在代码中实现,我们可以更进一步,在下一章中讨论偏导数。

本章的章节代码、更多资源和勘误表:https://nnfs.io/ch7


http://www.ppmy.cn/news/1547620.html

相关文章

七:如何用Chrome的Network面板分析HTTP报文

在Web开发和调试中,分析HTTP请求和响应报文可以帮助开发者了解浏览器和服务器之间的通信细节,定位并解决各种问题。Chrome浏览器的Network(网络)面板是一个强大的开发工具,它可以详细展示HTTP请求的各个方面,包括请求方法、状态码、头部信息、负载数据等。本文将介绍如何…

Mock.js生成随机数据,拦截 Ajax 请求

Mock.js 是一个用于模拟数据的 JavaScript 库,特别适合用于前端开发过程中生成假数据进行接口测试。它可以拦截 Ajax 请求并生成随机数据,还可以模拟服务器的响应来加速前端开发。 一、安装 Mock.js 可以通过以下几种方式引入 Mock.js: CDN…

1909. 删除一个元素使数组严格递增【简单】

题目描述 给你一个下标从 0 开始的整数数组 nums ,如果 恰好 删除 一个 元素后,数组 严格递增 ,那么请你返回 true ,否则返回 false 。如果数组本身已经是严格递增的,请你也返回 true 。 数组 nums 是 严格递增 的定…

tensorflow案例6--基于VGG16的猫狗识别(准确率99.8%+),以及tqdm、train_on_batch的简介

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 前言 本次还是学习API和如何搭建神经网络为主,这一次用VGG16去对猫狗分类,效果还是很好的,达到了99.8% 文章目录 1、tqdm…

IDEA部署AI代写插件

前言 Hello大家好,当下是AI盛行的时代,好多好多东西在AI大模型的趋势下都变得非常的简单。 比如之前想画一幅风景画得先去采风,然后写实什么的,现在你只需描述出你想要的效果AI就能够根据你的描述在几分钟之内画出一幅你想要的风景…

Unity简单漫游摄像机

可以wasd漫游场景,q/e来上升下降,可以调用TeleportAndLookAtTarget来传送到一个Transform附近并注视它,挂到摄像机上就能使用. using UnityEngine;public class CameraController : MonoBehaviour {public float moveSpeed 5f; // 相机移动速度public float sprintMultiplier…

Unity学习笔记(4):人物和基本组件

文章目录 前言开发环境新增角色添加组件RigidBody 2D全局项目设置Edit 给地图添加碰撞体 总结 前言 今天不加班,有空闲时间。争取一天学一课,养成习惯 开发环境 Unity 6windows 11vs studio 2022Unity2022.2 最新教程《勇士传说》入门到进阶&#xff…

基于Spring Boot的电子商务系统设计

5 系统实现 系统实现部分就是将系统分析,系统设计部分的内容通过编码进行功能实现,以一个实际应用系统的形式展示系统分析与系统设计的结果。前面提到的系统分析,系统设计最主要还是进行功能,系统操作逻辑的设计,也包括…