BP 网络的标准学习算法及其实现

server/2024/11/12 21:21:51/

BP 网络的标准学习算法及其实现

一、引言

BP(Back Propagation)神经网络是一种广泛应用于机器学习和人工智能领域的神经网络模型。它通过反向传播算法来调整网络的权重,以最小化预测输出和实际输出之间的误差。BP 网络的标准学习算法对于理解和实现神经网络的训练过程至关重要,本文将详细介绍这些算法并给出相应的代码示例。

二、BP 网络的基本结构

BP 网络通常由输入层、若干隐藏层和输出层组成。每层由多个神经元构成,神经元之间通过合适的权重连接。输入层接收外部数据,隐藏层对数据进行特征提取和转换,输出层输出最终的预测结果。

三、BP 网络的标准学习算法

(一)梯度下降算法

  1. 原理
    • 梯度下降算法是 BP 网络中最基本的优化算法。其目标是通过迭代地调整网络权重,使损失函数(如均方误差函数)的值最小化。损失函数关于权重的梯度表示了损失函数在当前权重下增长最快的方向,而梯度下降算法则沿着梯度的负方向更新权重,以期望找到损失函数的最小值。
    • 对于一个具有多个权重参数 w w w 的 BP 网络,损失函数 J ( w ) J(w) J(w) 的梯度 ∇ J ( w ) \nabla J(w) J(w) 可以通过链式法则计算。在每次迭代中,权重更新公式为: w = w − α ∇ J ( w ) w = w - \alpha \nabla J(w) w=wαJ(w),其中 α \alpha α学习率,它决定了每次更新的步长。如果学习率过大,可能会导致算法无法收敛甚至发散;如果学习率过小,算法收敛速度会很慢。
  2. 代码示例
    以下是一个简单的使用梯度下降算法训练单神经元的代码示例,用于拟合一个简单的线性函数 y = 2 x + 1 y = 2x + 1 y=2x+1
import numpy as np# 激活函数,这里使用恒等函数(在线性回归场景下)
def activation_function(x):return x# 损失函数,这里使用均方误差
def loss_function(y_pred, y_true):return np.mean((y_pred - y_true) ** 2)# 训练数据
x = np.array([[1], [2], [3], [4], [5]])
y = np.array([[3], [5], [7], [9], [11]])# 初始化权重和偏置
weight = np.random.rand()
bias = np.random.rand()
learning_rate = 0.01
epochs = 100for epoch in range(epochs):# 前向传播y_pred = activation_function(x * weight + bias)loss = loss_function(y_pred, y)# 计算梯度d_loss_d_weight = np.mean((y_pred - y) * x)d_loss_d_bias = np.mean(y_pred - y)# 更新权重和偏置weight = weight - learning_rate * d_loss_d_weightbias = bias - learning_rate * d_loss_d_biasif epoch % 10 == 0:print(f'Epoch {epoch}: Loss = {loss}')print(f'Final weight: {weight}, Final bias: {bias}')

(二)随机梯度下降算法(SGD)

  1. 原理
    • 梯度下降算法在每次迭代时都使用整个训练数据集来计算梯度,当数据集很大时,计算成本很高。随机梯度下降算法则每次从训练数据集中随机选择一个样本进行梯度计算和权重更新。这样可以大大加快训练速度,但由于每次只使用一个样本,梯度的估计可能会有较大的噪声,导致收敛路径可能会更加曲折。
    • 在 SGD 中,权重更新公式与梯度下降类似,但每次只针对一个样本计算梯度。例如,对于一个样本 ( x i , y i ) (x_i, y_i) (xi,yi),权重更新公式为: w = w − α ∇ J i ( w ) w = w - \alpha \nabla J_i(w) w=wαJi(w),其中 ∇ J i ( w ) \nabla J_i(w) Ji(w) 是损失函数关于权重 w w w 在样本 i i i 上的梯度。
  2. 代码示例
    以下是使用随机梯度下降算法训练上述线性回归问题的代码:
import numpy as np# 激活函数,这里使用恒等函数(在线性回归场景下)
def activation_function(x):return x# 损失函数,这里使用均方误差
def loss_function(y_pred, y_true):return np.mean((y_pred - y_true) ** 2)# 训练数据
x = np.array([[1], [2], [3], [4], [5]])
y = np.array([[3], [5], [7], [9], [11]])# 初始化权重和偏置
weight = np.random.rand()
bias = np.random.rand()
learning_rate = 0.01for epoch in range(100):for i in range(len(x)):# 随机选择一个样本sample_x = x[i]sample_y = y[i]# 前向传播y_pred = activation_function(sample_x * weight + bias)loss = loss_function(y_pred, sample_y)# 计算梯度d_loss_d_weight = (y_pred - sample_y) * sample_xd_loss_d_bias = y_pred - sample_y# 更新权重和偏置weight = weight - learning_rate * d_loss_d_weightbias = bias - learning_rate * d_loss_d_biasif epoch % 10 == 0:y_pred_all = activation_function(x * weight + bias)loss_all = loss_function(y_pred_all, y)print(f'Epoch {epoch}: Loss = {loss_all}')print(f'Final weight: {weight}, Final bias: {bias}')

(三)小批量梯度下降算法(Mini - Batch Gradient Descent)

  1. 原理
    • 小批量梯度下降算法是梯度下降和随机梯度下降的一种折衷。它每次从训练数据集中选取一小批(mini - batch)样本进行梯度计算和权重更新。这样既可以利用向量化计算的优势(相比随机梯度下降),又可以在一定程度上减少计算量(相比梯度下降),同时也能获得比随机梯度下降更稳定的梯度估计。
    • 假设小批量大小为 m m m,在每次迭代中,从训练数据集 D D D 中随机选取一个小批量样本 B = { ( x 1 , y 1 ) , ( x 2 , y 2 ) , ⋯ , ( x m , y m ) } B = \{(x_1, y_1), (x_2, y_2), \cdots, (x_m, y_m)\} B={(x1,y1),(x2,y2),,(xm,ym)}。权重更新公式为: w = w − α ∇ J B ( w ) w = w - \alpha \nabla J_B(w) w=wαJB(w),其中 ∇ J B ( w ) \nabla J_B(w) JB(w) 是损失函数关于权重 w w w 在小批量样本 B B B 上的梯度。
  2. 代码示例
    以下是使用小批量梯度下降算法训练线性回归问题的代码,假设小批量大小为 2:
import numpy as np# 激活函数,这里使用恒等函数(在线性回归场景下)
def activation_function(x):return x# 损失函数,这里使用均方误差
def loss_function(y_pred, y_true):return np.mean((y_pred - y_true) ** 2)# 训练数据
x = np.array([[1], [2], [3], [4], [5]])
y = np.array([[3], [5], [7], [9], [11]])# 初始化权重和偏置
weight = np.random.rand()
bias = np.random.rand()
learning_rate = 0.01
batch_size = 2
epochs = 100for epoch in range(epochs):for i in range(0, len(x), batch_size):end_index = min(i + batch_size, len(x))batch_x = x[i:end_index]batch_y = y[i:end_index]# 前向传播y_pred = activation_function(batch_x * weight + bias)loss = loss_function(y_pred, batch_y)# 计算梯度d_loss_d_weight = np.mean((y_pred - batch_y) * batch_x)d_loss_d_bias = np.mean(y_pred - batch_y)# 更新权重和偏置weight = weight - learning_rate * d_loss_d_weightbias = bias - learning_rate * d_loss_d_biasif epoch % 10 == 0:y_pred_all = activation_function(x * weight + bias)loss_all = loss_function(y_pred_all, y)print(f'Epoch {epoch}: Loss = {loss_all}')print(f'Final weight: {weight}, Final bias: {bias}')

四、总结

BP 网络的标准学习算法包括梯度下降、随机梯度下降和小批量梯度下降等。梯度下降算法在处理大规模数据集时计算成本高,随机梯度下降算法收敛路径可能不稳定,而小批量梯度下降算法在两者之间取得了较好的平衡。在实际应用中,需要根据数据集的大小、计算资源和模型的复杂度等因素选择合适的学习算法,以有效地训练 BP 网络并获得良好的预测性能。同时,这些算法还可以进一步改进和优化,例如使用自适应学习率等方法来提高训练效率和收敛速度。


http://www.ppmy.cn/server/140781.html

相关文章

git 多账号配置

windows下git多账号配置详解_git配置多个用户名和密码-CSDN博客 windows下git多账号配置详解_git配置多个用户名和密码-CSDN博客 windows下git多账号配置详解_git配置多个用户名和密码-CSDN博客

Ascend Extension for PyTorch的源码解析

1 源码下载 Ascend对pytorch代码的适配,可从以下链接中获取。 Ascend/pytorch 执行如下命令即可。 git clone https://gitee.com/ascend/pytorch.git2 目录结构解析 源码下载后,如果需要编译torch-npu,最好保持pytorch的源码版本匹配&…

设计模式-七个基本原则之一-开闭原则 + SpringBoot案例

开闭原则:(SRP) 面向对象七个基本原则之一 对扩展开放:软件实体(类、模块、函数等)应该能够通过增加新功能来进行扩展。对修改关闭:一旦软件实体被开发完成,就不应该修改它的源代码。 要看实际场景,比如组内…

【系统面试篇】其他相关题目——虚拟内存、局部性原理、分页、分块、页面置换算法

目录 一、相关问题 1. 什么是虚拟内存?为什么需要虚拟内存? (1)内存扩展 (2)内存隔离 (3)物理内存管理 (4)页面交换 (5)内存映…

clickhouse自增id的处理

msyql 中创建数据表的时候可以通过AUTO_INCREMENT 来实现,clickhouse中可以通过其他方式来处理 一、 默认值 创建表时可以实用默认值,该列值可以自动递增。如下所示 CREATE TABLE my_table ( id UInt32 DEFAULT IDENTITY(AUTO_INCREMENT), name Strin…

亲测有效:Maven3.8.1使用Tomcat8插件启动项目

我本地maven的settings.xml文件中的配置&#xff1a; <mirror><id>aliyunmaven</id><mirrorOf>central</mirrorOf><name>阿里云公共仓库</name><url>https://maven.aliyun.com/repository/public</url> </mirror>…

vue种ref跟reactive的区别?

‌Vue中的ref和reactive的主要区别在于它们处理的数据类型、实现原理以及使用方式。‌ 处理的数据类型 ‌ref‌&#xff1a;可以处理基本数据类型&#xff08;如数字、字符串、布尔值&#xff09;和对象。ref通过Object.defineProperty()的get和set方法来实现响应式&#xff…

使用 Redux 在 Flutter鸿蒙next 中实现状态管理

在 Flutter 中进行状态管理是开发应用程序时的一个关键问题。Flutter 提供了多种解决方案来管理应用的状态&#xff0c;其中 Redux 是一种广泛使用且功能强大的状态管理库。虽然 Redux 最初是为 JavaScript 和 React 设计的&#xff0c;但它的核心概念非常适用于 Flutter&#…