机器学习算法模型系列——Adam算法

embedded/2024/11/25 2:04:50/

Adam是一种自适应学习率的优化算法,结合了动量和自适应学习率的特性。

主要思想是根据参数的梯度来动态调整每个参数的学习率。

核心原理包括:

  1. 动量(Momentum):Adam算法引入了动量项,以平滑梯度更新的方向。这有助于加速收敛并减少震荡。

  2. 自适应学习率:Adam算法计算每个参数的自适应学习率,允许不同参数具有不同的学习速度。

  3. 偏差修正(Bias Correction):Adam算法在初期迭代中可能受到偏差的影响,因此它使用偏差修正来纠正这个问题。

Adam相关公式

初始化:

  • 参数:eq?%5Cbeta

  • 学习率:eq?%5Calpha

  • 梯度估计的移动平均(一阶矩):eq?m%3D0

  • 梯度平方的移动平均(二阶矩):eq?v%3D0

  • 时间步数:eq?t%3D0

每个迭代步骤:

  1. eq?t%3Dt+1
  2. 计算梯度:eq?g_%7Bt%7D%20%3D%5Cbigtriangledown%20f%20_%7Bt%7D%28%5Ctheta%20_%7Bt%7D%29

  3. 更新一阶矩:eq?m_%7Bt%7D%20%3D%5Cbeta_%7B1%7D%5Ccdot%20m_%7Bt-1%7D+%281-%5Cbeta_%7B1%7D%29%5Ccdot%20g_%7Bt%7D

  4. 更新二阶矩:eq?v_%7Bt%7D%20%3D%5Cbeta_%7B2%7D%5Ccdot%20v_%7Bt-1%7D+%281-%5Cbeta_%7B2%7D%29%5Ccdot%20g_%7Bt%7D%5E%7B2%7D

  5. 修正偏差(Bias Correction): eq?%5Chat%7Bm%7D_%7Bt%7D%20%3D%5Cfrac%7Bm_%7Bt%7D%7D%7B%281-%5Cbeta_%7B1%7D%5E%7Bt%7D%29%7D和 eq?%5Chat%7Bv%7D_%7Bt%7D%20%3D%5Cfrac%7Bv%7Bt%7D%7D%7B%281-%5Cbeta_%7B2%7D%5E%7Bt%7D%29%7D

  6. 更新参数:eq?%7B%5Ctheta%7D_%7Bt+1%7D%20%3D%7B%5Ctheta%7D_%7Bt%7D%20-%5Calpha%20%5Ccdot%20%5Cfrac%7B%5Chat%7Bm%7D%7Bt%7D%7D%7B%28%5Csqrt%7Bv_%7Bt%7D%7D-%5Cvarepsilon%20%29%7D,其中 eq?%5Cvarepsilon 是一个小的常数,以防分母为零。

项目:基于Adam优化算法的神经网络训练

在这个项目中,我们将使用Adam优化算法来训练一个简单的神经网络,以解决二分类问题。我们将深入讨论Adam算法的原理和公式,并展示如何在Python中实施它。最后,我们将绘制学习曲线,以可视化模型的训练进展。

项目:基于Adam优化算法的神经网络训练

在这个项目中,我们将使用Adam优化算法来训练一个简单的神经网络,以解决二分类问题。我们将深入讨论Adam算法的原理和公式,并展示如何在Python中实施它。最后,我们将绘制学习曲线,以可视化模型的训练进展。

模型训练

使用Python代码实现Adam算法来训练一个二分类的神经网络。

使用Python中的NumPy库来进行计算,并使用一个合成的数据集来演示。

import numpy as np
import matplotlib.pyplot as plt# 定义模型和数据
np.random.seed(42)
X = np.random.rand(100, 2)  # 特征数据
y = (X[:, 0] + X[:, 1] > 1).astype(int)  # 二分类标签# 定义神经网络模型
def sigmoid(x):return 1 / (1 + np.exp(-x))def predict(X, weights):return sigmoid(np.dot(X, weights))# 初始化参数和超参数
theta = np.random.rand(2)  # 参数初始化
alpha = 0.1  # 学习率
beta1 = 0.9  # 一阶矩衰减因子
beta2 = 0.999  # 二阶矩衰减因子
epsilon = 1e-8  # 用于防止分母为零# 初始化Adam算法所需的中间变量
m = np.zeros(2)
v = np.zeros(2)
t = 0# 训练模型
num_epochs = 100
for epoch in range(num_epochs):for i in range(len(X)):t += 1gradient = (predict(X[i], theta) - y[i]) * X[i]m = beta1 * m + (1 - beta1) * gradientv = beta2 * v + (1 - beta2) * gradient**2m_hat = m / (1 - beta1**t)v_hat = v / (1 - beta2**t)theta -= alpha * m_hat / (np.sqrt(v_hat) + epsilon)# 输出训练后的参数
print("训练完成后的参数:", theta)# 定义损失函数
def loss(X, y, weights):y_pred = predict(X, weights)return -np.mean(y * np.log(y_pred) + (1 - y) * np.log(1 - y_pred))# 记录损失值
loss_history = []
for i in range(len(X)):loss_history.append(loss(X[i], y[i], theta))# 绘制损失函数曲线
plt.plot(range(len(X)), loss_history)
plt.xlabel("Iteration")
plt.ylabel("Loss Function Value")
plt.title("Change in Loss Function Over Time")
plt.show()

这个图形将显示损失函数值随着迭代次数的减小而减小,这表明Adam优化算法成功地训练了模型。

 

0c28a34bb095b6c7dd66ad815c82823e.png

 

 


http://www.ppmy.cn/embedded/139360.html

相关文章

单片机_day8_时钟系统+定时器+PWM

目录 1.2 时钟源 1.3 STM32U5时钟源 1.4 时钟树 1.5 STM32CubeMX时钟树配置 2. Systick定时器 2.1 概念 2.2 工作原理 2.3 滴答定时器分析 3. TIM定时器 3.1 基本概念 3.2 STM32U5定时器 3.3 ​​​​​​​定时器框图 3.3.1 ​​​​​​​预分频器 3.3.2 ​​​​​​​…

php 使用mqtt

在 Webman 框架中使用 MQTT 进行消息的发布和订阅,你可以借助 PHP 的 MQTT 客户端库,比如 phpMQTT。以下是一个简单的示例,展示了如何在 Webman 中使用 MQTT 发布和订阅消息。 安装 phpMQTT 首先,你需要通过 Composer 安装 phpMQ…

统计机器学习——线性回归与分类

chapter2 线性回归与分类 例2.1 import pandas as pd data = pd.read_csv("../data/第2章数据/diabetes.csv",index_col=0)Index = data.columns xtitle = [index for index in Index if x. in index] x2title = [index for index in Index if x2. in index] xdata…

【实操之 图像处理与百度api-python版本】

1 cgg带你建个工程 如图 不然你的pip baidu-aip 用不了 先对图片进行一点处理 $ 灰度处理 $ 滤波处理 参考 import cv2 import os def preprocess_images(input_folder, output_folder):# 确保输出文件夹存在if not os.path.exists(output_folder):os.makedirs(output_fol…

数据结构(双向链表——c语言实现)

双向链表相比于单向链表的优势: 1. 双向遍历的灵活性 双向链表:由于每个节点都包含指向前一个节点和下一个节点的指针,因此可以从头节点遍历到尾节点,也可以从尾节点遍历到头节点。这种双向遍历的灵活性使得在某些算法和操作中&a…

EDA实验设计-led灯管动态显示;VHDL;Quartus编程

EDA实验设计-led灯管动态显示;VHDL;Quartus编程 引脚配置实现代码RTL引脚展示现象记录效果展示 引脚配置 #------------------GLOBAL--------------------# set_global_assignment -name RESERVE_ALL_UNUSED_PINS "AS INPUT TRI-STATED" set_…

libjpeg库——图像压缩与解压的核心技术

引言 在数字图像处理领域,图像压缩与解压技术扮演着至关重要的角色。随着数字图像的广泛应用,高效地存储和传输图像数据成为了一项关键技术需求。libjpeg库,作为一个开源的图像压缩解压缩库,凭借其丰富的功能和灵活的接口&#x…

数据库基础(MySQL)

1. 数据库基础 1.1 什么是数据库 存储数据用文件就可以了,为什么还要弄个数据库? 文件保存数据有以下几个缺点: 文件的安全性问题文件不利于数据查询和管理文件不利于存储海量数据文件在程序中控制不方便 数据库存储介质: 磁盘内存 为…