用 Python 从零开始创建神经网络(七):梯度下降(Gradient Descent)/导数(Derivatives)

server/2024/11/19 10:38:22/

梯度下降(Gradient Descent)/导数(Derivatives)

  • 引言
  • 1. 参数对输出的影响
  • 2. 斜率(The Slope)
  • 3. 数值导数(The Numerical Derivative)
  • 4. 解析导数(The Analytical Derivative)
  • 5. 总结
    • 规则:

引言

随机更改并搜索最佳权重和偏差并未取得成功,主要原因是权重和偏差的可能组合是无限的,我们需要比单纯的运气更智能的方法才能取得任何成功。每个权重和偏差对损失的影响程度也不同——这种影响取决于参数本身以及当前样本,这些样本是第一层的输入。然后这些输入值被权重相乘,所以输入数据影响神经元的输出,并影响权重对损失的影响。相同的原理适用于下一层的偏差和参数,它们以前一层的输出作为输入。这意味着对输出值的影响取决于参数以及样本——这就是为什么我们要分别计算每个样本的损失值。最后,权重或偏差如何影响总体损失的函数并不一定是线性的。为了知道如何调整权重和偏差,我们首先需要了解它们对损失的影响。

需要注意的一个概念是,我们提到权重和偏差及其对损失函数的影响。然而,损失函数中并不包含权重或偏差。这个函数的输入是模型的输出,而神经元的权重和偏差影响这个输出。因此,即使我们是从模型的输出计算损失,而不是权重/偏差,这些权重和偏差直接影响了损失。

在接下来的章节中,我们将通过解释偏导数、梯度、梯度下降和反向传播来详细描述这一过程。基本上,我们将计算每一个单独的权重和偏差如何改变损失值(它对此有多大影响),给定一个样本(因为每个样本产生一个单独的输出,因此也产生一个单独的损失值),以及如何改变这个权重或偏差以减少损失值。记住——我们的目标是减少损失,我们将通过使用梯度下降来实现这一点。另一方面,梯度是偏导数计算的结果,我们将使用链式法则通过反向传播来更新所有的权重和偏差。如果这听起来还没有多大意义,不用担心;我们将在这一章和接下来的章节中解释所有这些术语以及如何执行这些操作。

要理解偏导数,我们需要从导数开始,它们是偏导数的一种特殊情况——从只取单个参数的函数中计算得出。


1. 参数对输出的影响

让我们从一个简单的函数开始,探索“影响”的含义。

一个非常简单的函数 y = 2 x y = 2x y=2x,其中 x x x为输入:

python">def f(x):return 2*x

现在让我们围绕这个函数编写一些代码来可视化数据 —— 我们将导入NumPyMatplotlib,创建一个从0到4的5个输入值的数组,计算每个输入值的函数输出,并将结果以连续点之间的线条形式绘制出来。这些点的坐标是输入 x x x和函数输出 y y y

python">import matplotlib.pyplot as plt
import numpy as npdef f(x):return 2*xx = np.array(range(5))
y = f(x)print(x)
print(y)plt.plot(x, y)
plt.show()
python">>>>
[0 1 2 3 4]
[0 2 4 6 8]

在这里插入图片描述


2. 斜率(The Slope)

这看起来像是 f ( x ) = 2 x f(x) = 2x f(x)=2x函数的输出,这是一条直线。你如何定义 x x x y y y的影响呢?有些人会说,“ y y y x x x的两倍”。另一种描述这种线性函数影响的方式来自代数学:斜率。你可能还记得学校里学过的一个短语“纵横比”(即“rise over run”)。一条直线的斜率是:

在这里插入图片描述
斜率是 y y y的变化量除以 x x x的变化量,或者用数学语言来说 —— Δ y \Delta y Δy除以 Δ x \Delta x Δx。那么 f ( x ) = 2 x f(x) = 2x f(x)=2x的斜率是多少呢?

要计算斜率,首先我们需要在函数图形上取任意两点,并对它们进行减法运算以计算变化量。减去这些点意味着分别减去它们的 x x x y y y维度。 y y y的变化量除以 x x x的变化量得到的结果就是斜率:

在这里插入图片描述
接着写代码,我们将所有 x x x值保存在一个单维度的NumPy数组 x x x中,所有结果保存在另一个单维度数组 y y y中。为了执行相同的操作,我们将取 x [ 0 ] x[0] x[0] y [ 0 ] y[0] y[0]作为第一个点,然后取 x [ 1 ] x[1] x[1] y [ 1 ] y[1] y[1]作为第二个点。现在我们可以计算它们之间的斜率:

python">import matplotlib.pyplot as plt
import numpy as npdef f(x):return 2*xx = np.array(range(5))
y = f(x)plt.plot(x, y)
plt.show()print((y[1]-y[0]) / (x[1]-x[0]))
python">>>>
2.0

这条直线的斜率是 2 2 2并不令人意外。我们可以说 x x x y y y的影响度量是 2 2 2。我们可以以相同的方式计算任何线性函数的斜率,包括那些不那么明显的线性函数。

那么对于一个非线性函数如 f ( x ) = 2 x 2 f(x) = 2x^2 f(x)=2x2呢?

python">def f(x):return 2*x**2

此函数创建一个不形成直线的图形:
在这里插入图片描述

我们可以测量这条曲线的斜率吗?根据我们选择使用的 2 点,我们将测量不同的斜率:

python">y = f(x) # Calculate function outputs for new functionprint(x)
print(y)
python">>>>
[0 1 2 3 4]
[ 0 2 8 18 32]

现在讨论第一点:

python">print((y[1]-y[0]) / (x[1]-x[0]))
python">>>>
2

另一个:

python">print((y[3]-y[2]) / (x[3]-x[2]))
python">>>>
10

在这里插入图片描述

代码的可视化: https://nnfs.io/bro

我们如何衡量 x x x y y y在这个非线性函数中的影响呢?微积分提出,我们可以通过测量 x x x点(对于函数的一个特定输入值)的切线斜率来衡量,这给出了瞬时斜率(该点的斜率),即导数。切线是通过在曲线上的两个“无限接近”的点之间画一条线来创建的,但这条曲线必须在导数点处是可微分的。这意味着它必须是连续且平滑的(我们无法计算在我们可以描述为“尖角”的东西上的斜率,因为它包含无数个斜率)。然后,因为这是一条曲线,所以没有单一的斜率。斜率取决于我们在哪里测量它。举一个直接的例子,我们可以通过使用 x x x点和另一个也在 x x x点但加上一个非常小的增量,例如0.0001的点来近似这个函数在 x x x处的导数。这个数字是一个常见的选择,因为它不会引入太大的误差(在估计导数时)或导致整个表达式在数值上不稳定(由于浮点数分辨率, Δ x \Delta x Δx可能会四舍五入到0)。这让我们能够像以前一样计算斜率,但是在两个非常接近的点上进行,从而得到 x x x处斜率的良好近似:

python">import matplotlib.pyplot as plt
import numpy as npdef f(x):return 2*x**2p2_delta = 0.0001x1 = 1
x2 = x1 + p2_delta # add deltay1 = f(x1) # result at the derivation point
y2 = f(x2) # result at the other, close pointapproximate_derivative = (y2-y1)/(x2-x1)print(approximate_derivative)
python">>>>
4.0001999999987845

我们很快就会了解到,当 x = 1 x=1 x=1时, 2 x 2 2x^2 2x2的导数应该恰好是4。我们看到的差异(约为4.0002)来自于计算切线的方法。我们选择了一个足够小的 δ \delta δ,以尽可能准确地近似导数,但又足够大以防止四舍五入误差。

更详细地说,一个无限小的 δ \delta δ值将近似一个准确的导数;然而, δ \delta δ值需要在数值上稳定,意味着,我们的 δ \delta δ不能超过Python浮点精度的限制(不能太小,否则可能会被四舍五入为0,而我们知道,除以0是“非法”的)。因此,我们的解决方案受限于估计导数和保持数值稳定之间,因此引入了这个小但可见的误差。


3. 数值导数(The Numerical Derivative)

这种计算导数的方法称为数值微分——使用两个无限接近的点计算切线的斜率,或者像代码解决方案中那样——计算由两个“足够接近”的点构成的切线的斜率。我们可以通过以下内容来可视化为什么我们要在两个接近的点上执行这个操作:

在这里插入图片描述
在这里插入图片描述

代码的可视化: https://nnfs.io/cat

我们可以看到,这两个点越接近,切线就显得越正确。

继续进行数值微分,让我们可视化切线并观察它们如何随着计算位置的不同而变化。首先,我们将使用Numpyarange()函数使这个函数的图形更加细致,从而允许我们以更小的步骤绘图。np.arange()函数接收起始点、终点和步长参数,允许我们每次取0.001这样的小步长:

python">import matplotlib.pyplot as plt
import numpy as npdef f(x):return 2*x**2# np.arange(start, stop, step) to give us smoother line
x = np.arange(0, 5, 0.001)
y = f(x)plt.plot(x, y)
plt.show()

在这里插入图片描述

为了绘制这些切线,我们将推导出某一点处切线的函数,并在该点的图形上绘制切线。直线的函数是 y = m x + b y = mx + b y=mx+b。其中 m m m是我们已经计算出的斜率或近似导数, x x x是输入,这使得 b b b,即 y y y截距,需要我们来计算。斜率保持不变,但目前,您可以通过 y y y截距“上移”或“下移”线。我们已经知道 x x x m m m,但 b b b仍然未知。假设为了图形的目的令 m = 1 m=1 m=1,看看这实际意味着什么:

在这里插入图片描述

代码的可视化: https://nnfs.io/but

要计算 b b b,公式是 b = y − m x b = y - mx b=ymx

在这里插入图片描述

到目前为止,我们已经使用了两个点——我们想要计算导数的点和与其“足够接近”的点来计算导数的近似值。现在,根据上述 b b b的方程,导数的近似值,以及同一个“足够接近”的点(具体来说是它的 x x x y y y坐标),我们可以将它们代入方程中,得到该导数点处切线的 y y y截距。使用代码:

python">import matplotlib.pyplot as plt
import numpy as npdef f(x):return 2*x**2# np.arange(start, stop, step) to give us smoother line
x = np.arange(0, 5, 0.001)
y = f(x)plt.plot(x, y)# The point and the "close enough" point
p2_delta = 0.0001
x1 = 2
x2 = x1+p2_deltay1 = f(x1)
y2 = f(x2)print((x1, y1), (x2, y2))# Derivative approximation and y-intercept for the tangent line
approximate_derivative = (y2-y1)/(x2-x1)
b = y2 - approximate_derivative*x2# We put the tangent line calculation into a function so we can call
# it multiple times for different values of x
# approximate_derivative and b are constant for given function
# thus calculated once above this function
def tangent_line(x):return approximate_derivative*x + b# plotting the tangent line
# +/- 0.9 to draw the tangent line on our graph
# then we calculate the y for given x using the tangent line function
# Matplotlib will draw a line for us through these points
to_plot = [x1-0.9, x1, x1+0.9]
plt.plot(to_plot, [tangent_line(i) for i in to_plot])print('Approximate derivative for f(x)', f'where x = {x1} is {approximate_derivative}')plt.show()
python">>>>
(2, 8) (2.0001, 8.000800020000002)
Approximate derivative for f(x) where x = 2 is 8.000199999998785

在这里插入图片描述

橙色的线是函数 f ( x ) = 2 x 2 f(x) = 2x^2 f(x)=2x2 x = 2 x=2 x=2处的近似切线。我们为什么要关心这个?你很快会发现,我们只关心这条切线的斜率,但是可视化和理解切线都非常重要。我们关心切线的斜率,因为它告诉我们 x x x在特定点对这个函数的影响,被称为瞬时变化率。我们将使用这个概念来确定一个特定的权重或偏置对给定样本的总损失函数的影响。现在,对于不同的 x x x值,我们可以观察到对函数的不同影响。我们可以继续之前的代码,查看不同输入( x x x)的切线 - 我们将代码的一部分放在一个循环中,遍历示例 x x x值,并绘制多条切线:

python">import matplotlib.pyplot as plt
import numpy as npdef f(x):return 2*x**2
# np.arange(start, stop, step) to give us a smoother curve
x = np.array(np.arange(0,5,0.001))
y = f(x)plt.plot(x, y)colors = ['k','g','r','b','c']def approximate_tangent_line(x, approximate_derivative):return (approximate_derivative*x) + bfor i in range(5):p2_delta = 0.0001x1 = ix2 = x1+p2_deltay1 = f(x1)y2 = f(x2)print((x1, y1), (x2, y2))approximate_derivative = (y2-y1)/(x2-x1)b = y2-(approximate_derivative*x2)to_plot = [x1-0.9, x1, x1+0.9]plt.scatter(x1, y1, c=colors[i])plt.plot([point for point in to_plot], [approximate_tangent_line(point, approximate_derivative) for point in to_plot], c=colors[i])print('Approximate derivative for f(x)', f'where x = {x1} is {approximate_derivative}')plt.show()
python">>>>
(0, 0) (0.0001, 2e-08)
Approximate derivative for f(x) where x = 0 is 0.00019999999999999998
(1, 2) (1.0001, 2.00040002)
Approximate derivative for f(x) where x = 1 is 4.0001999999987845
(2, 8) (2.0001, 8.000800020000002)
Approximate derivative for f(x) where x = 2 is 8.000199999998785
(3, 18) (3.0001, 18.001200020000002)
Approximate derivative for f(x) where x = 3 is 12.000199999998785
(4, 32) (4.0001, 32.00160002)
Approximate derivative for f(x) where x = 4 is 16.000200000016548

在这里插入图片描述

对于这个简单的函数 f ( x ) = 2 x 2 f(x) = 2x^2 f(x)=2x2,我们通过近似导数(即切线的斜率)并没有支付高昂的代价,并且得到了足够接近我们需求的值。

问题在于,我们神经网络中实际使用的函数并不简单。损失函数包含了所有的层、权重和偏置——这是一个在多个维度上运作的绝对庞大的函数!使用数值微分计算导数需要对单个参数更新进行多次前向传递(我们将在第10章讨论参数更新)。我们需要执行前向传递作为参考,然后通过增量值更新单个参数,并再次通过我们的模型执行前向传递以查看损失值的变化。接下来,我们需要计算导数并恢复我们为这次计算所做的参数更改。我们必须对每个权重和偏置以及每个样本重复这一过程,这将非常耗时。我们还可以将这种方法视为强行计算导数的方法。再次强调,正如我们快速涵盖了许多术语,导数是作为单个参数输入的函数的切线的斜率。我们将利用这种能力计算损失函数在每个权重和偏置点的斜率——这将引导我们进入多变量函数,这是一个涉及多个参数的函数,是下一章的主题——偏导数。


4. 解析导数(The Analytical Derivative)

现在我们对什么是导数有了更好的理解,如何计算数值导数(也称为通用导数),以及为什么这种方法不适合我们,我们可以继续探讨解析导数,这是我们将在代码中实现的实际导数解法。

在数学中,解决问题通常有两种方法:数值方法和解析方法。数值解决方法涉及计算一个数字来找到解决方案,就像上面的近似导数方法一样。数值解也是一种近似。另一方面,解析方法提供了精确且在计算方面更快的解决方案。然而,正如我们将很快学到的,识别给定函数的导数的解析解在复杂性上会有所不同,而数值方法的复杂度从不增加——它总是调用两次方法,用两个输入来计算某一点的近似导数。一些解析解非常明显,有些可以通过简单的规则计算,而一些复杂的函数可以分解成更简单的部分,并使用所谓的链式法则进行计算。我们可以利用已经证明的某些函数的导数解,而其他一些函数——像我们的损失函数——可以通过上述方法的组合来解决。

要使用解析方法计算函数的导数,我们可以将它们分解为简单的基本函数,找到这些函数的导数,然后应用链式法则(我们很快会解释这一点)来得到完整的导数。为了开始建立直觉,让我们从简单的函数及其各自的导数开始。简单常数函数的导数:

在这里插入图片描述

在这里插入图片描述

代码的可视化: https://nnfs.io/cow

当计算函数的导数时,回想一下导数可以被解释为斜率。在这个例子中,这个函数的结果是一条水平线,因为任何 x x x的输出值都是1:

通过观察,很明显导数等于0,因为从一个 x x x值到任何其他 x x x值都没有变化(即,没有斜率)。

到目前为止,我们通过取一个单一参数 x x x来计算函数的导数,在每个例子中我们都是这样做的。这在计算偏导数时会有所改变,因为偏导数涉及到具有多个参数的函数,并且我们将一次只对其中一个参数计算导数。目前,对于导数,它总是相对于单一参数。为了表示导数,我们可以使用撇号记法,例如对于函数 f ( x ) f(x) f(x),我们加上一个撇号 ( ′ ) (') () f ′ ( x ) f'(x) f(x)。对于我们的例子, f ( x ) = 1 f(x) = 1 f(x)=1,其导数 f ′ ( x ) = 0 f'(x) = 0 f(x)=0。我们还可以使用的另一种记法称为莱布尼茨记法——依赖于撇号记法和用莱布尼茨记法书写导数的多种方式如下:


这些记法的含义相同 —— 函数的导数(相对于 x x x)。在接下来的例子中,我们将使用这两种记法,因为有时使用一种记法或另一种记法更方便。我们也可以在同一个等式中同时使用这两种记法。

5. 总结

规则:

函数的常数倍的导数等于该函数导数的常数倍:

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

我们在这里使用 x x x的值而不是整个函数 f ( x ) f(x) f(x),因为整个函数的导数的计算方式有些不同。我们将在下一章中解释这个概念,并结合链式法则进行讲解。

既然我们已经了解了什么是导数以及如何解析地计算它们,稍后我们将在代码中实现,我们可以更进一步,在下一章中讨论偏导数。

本章的章节代码、更多资源和勘误表:https://nnfs.io/ch7


http://www.ppmy.cn/server/143154.html

相关文章

【使用 Docker 搭建云原生后端环境的详细教程】

安装 Docker: 对于 CentOS 7/8: 卸载旧版本(如果存在):sudo yum remove docker \docker-client \docker-client-latest \docker-common \docker-latest \docker-latest-logrotate \docker-logrotate \<

taro框架h5项目打包后页面空白 解决办法

最近正在用taro框架&#xff0c;写一个h5页面&#xff0c;本地打开页面好好的&#xff0c;打包之后页面就一片空白&#xff0c;经过各方搜查&#xff0c;找到了解决办法&#xff0c;以此记录下来&#xff0c;希望可以帮助到和我遇到同样问题朋友们 如果Nginx设置了二层目录&am…

Spring Boot3自定义starter

1、加入必要依赖 plugins {id javaid org.springframework.boot version 3.2.6id io.spring.dependency-management version 1.1.5 } group org.example.test.starter version 1.1.0jar{enabledtrue// resolveMainClassName }java {toolchain {languageVersion JavaLanguage…

开源项目低代码表单设计器FcDesigner扩展自定义组件

开源项目低代码表单设计器FcDesigner中的通过将自定义组件集成到设计器中&#xff0c;您可以添加额外的界面元素和功能&#xff0c;从而增强设计器的适用性和灵活性。以下是详细步骤&#xff0c;以帮助您创建、导入、注册和配置自定义组件。 源码地址: Github | Gitee | 文档 …

CompressAI安装!!!

我就不说废话了&#xff0c;直接给教程&#xff0c;还是非常简单的 但是我看了好多帖子&#xff0c;都没有说明情况 一定要看最后最后的那个注释 正片开始&#xff1a; 一共有三种方式&#xff1a; 第一种就是本机安装&#xff1a; 在网址上下载对应版本Links for compre…

IP数据云 识别和分析tor、proxy等各类型代理

在网络上使用代理&#xff08;tor、proxy、relay等&#xff09;进行访问的目的是为了规避网络的限制、隐藏真实身份或进行其他的不正当行为。 对代理进行识别和分析可以防止恶意攻击、监控和防御僵尸网络和提高防火墙效率等&#xff0c;同时也可以对用户行为进行分析&#xff…

SOA(面向服务架构)全面解析

1. 引言 什么是SOA&#xff08;面向服务架构&#xff09; SOA&#xff08;Service-Oriented Architecture&#xff0c;面向服务架构&#xff09;是一种将应用程序功能以“服务”的形式进行模块化设计的架构风格。这些服务是独立的功能模块&#xff0c;它们通过定义明确的接口…

最长连续序列

题目描述 给定一个未排序的整数数组 nums &#xff0c;找出数字连续的最长序列&#xff08;不要求序列元素在原数组中连续&#xff09;的长度。 请你设计并实现时间复杂度为 O(n) 的算法解决此问题。 示例 1&#xff1a; 输入&#xff1a;nums [100,4,200,1,3,2] 输出&#…