神经网络中的Adam

ops/2025/3/1 23:49:39/

Adam(Adaptive Moment Estimation)是一种广泛使用的优化算法,结合了RMSprop和动量(Momentum)的优点。它通过计算梯度的一阶矩估计(mean)和二阶矩估计(uncentered variance),为每个参数提供自适应学习率。Adam由Diederik P. Kingma和Jimmy Ba在2014年的论文《Adam: A Method for Stochastic Optimization》中提出。

### Adam的核心思想

Adam的主要特点是:
- **自适应学习率**:根据参数的梯度一阶矩(均值)和二阶矩(方差)自动调整学习率。
- **动量**:引入类似于传统动量的概念来加速SGD在相关方向上的进展,并抑制震荡。
- **偏置校正**:对初期的矩估计进行偏差校正,以应对开始阶段估计不准确的问题。

### 更新规则

对于时间步 \( t \),某个参数 \( w \) 的更新过程如下:

1. **计算梯度**:
   \[ g_t = \nabla_{w} J(w_{t-1}) \]
   
   这里,\( g_t \) 是损失函数 \( J \) 对参数 \( w \) 在时间步 \( t \) 的梯度。

2. **计算一阶矩估计(均值)**:
   \[ m_t = \beta_1 m_{t-1} + (1 - \beta_1)g_t \]
   
3. **计算二阶矩估计(未中心化的方差)**:
   \[ v_t = \beta_2 v_{t-1} + (1 - \beta_2)g_t^2 \]
   
   其中,\( \beta_1 \) 和 \( \beta_2 \) 分别是用于控制一阶矩和二阶矩估计的指数衰减率,默认情况下分别设置为 0.9 和 0.999。

4. **偏差校正**:
   \[ \hat{m}_t = \frac{m_t}{1-\beta_1^t} \]
   \[ \hat{v}_t = \frac{v_t}{1-\beta_2^t} \]
   
   这一步是为了修正初始时刻的偏差,因为在训练初期,\( m_t \) 和 \( v_t \) 可能会偏向零。

5. **参数更新**:
   \[ w_t = w_{t-1} - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \hat{m}_t \]
   
   其中,\( \eta \) 是学习率,\( \epsilon \) 是一个很小的常数(例如 \( 10^{-8} \)),用于确保数值稳定性,避免除以零的情况。

### 特点与优势

- **高效性**:Adam通常比其他自适应学习率方法如Adagrad或RMSprop更快收敛。
- **适用于非平稳目标**:由于其使用了移动平均,因此更适合处理随着时间变化的目标函数。
- **不需要手动调节学习率**:相比标准SGD,Adam减少了对超参数(特别是学习率)精细调节的需求。

### 实践中的应用

Adam因其良好的性能和易用性,在深度学习领域得到了广泛应用。无论是图像识别、自然语言处理还是强化学习等领域,Adam都是首选的优化器之一。下面是一个使用TensorFlow/Keras实现Adam的例子:

```python
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

# 创建模型
model = Sequential([Dense(1, input_shape=(8,))])

# 使用Adam优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

# 编译模型
model.compile(optimizer=optimizer, loss='mse')

# 假设我们有一些数据x_train和y_train
# model.fit(x_train, y_train, epochs=10)
```

在这个例子中,`learning_rate` 参数可以调整以适应特定任务的需求,但默认值通常已经足够有效。此外,Keras的Adam优化器还允许进一步定制化,比如调整 \( \beta_1 \) 和 \( \beta_2 \) 的值等。


http://www.ppmy.cn/ops/162359.html

相关文章

golang安装(1.23.6)

1.切换到安装目录 cd /usr/local 2.下载安装包 wget https://go.dev/dl/go1.23.6.linux-amd64.tar.gz 3.解压安装包 sudo tar -C /usr/local -xzf go1.23.6.linux-amd64.tar.gz 4.配置环境变量 vi /etc/profile export PATH$…

conda怎么迁移之前下载的环境包,把python从3.9升级到3.10

克隆旧环境(保留旧环境作为备份) conda create -n cloned_env --clone old_env 在克隆环境中直接升级 Python conda activate cloned_env conda install python3.10 升级 Python 后出现 所有包导入失败 的问题,通常是因为依赖包与新 Pyth…

【开源库 | jsoncpp】jsoncpp 的使用教程,附带读写json数据的例子源码

😁博客主页😁:🚀https://blog.csdn.net/wkd_007🚀 🤑博客内容🤑:🍭嵌入式开发、Linux、C语言、C、数据结构、音视频🍭 🤣本文内容🤣&a…

unity学习60: 滑动条 和 滚动条 滚动区域

目录 1 滚动条 scrollbar 1.1 创建滚动条 1.2 scrollbar的子物体 1.3 scrollbar的属性 2 滚动视图 scroll View 2.1 创建1个scroll View 2.1.1 实际类比,网页就是一个 scroll view吧 2.2 子物体构成 2.3 核心component : Scroll Rect 3 可视区域 view p…

MCU+RTOS学习笔记1

0 学习内容 06_基于 Cubemx 实现按键控制 LED 灯(裸机) 1 知识学习 参考:1_Jlink介绍和使用 1.JTAG 1.JTAG JTAG是一种国际标准测试协议,主要用于芯片内部测试,标准的JTAG接口为4线:TMS、TCK、TDI、T…

Python--内置模块和开发规范(上)

1. 内置模块 1.1 JSON 模块 核心功能 序列化:Python 数据类型 → JSON 字符串 import json data [{"id": 1, "name": "武沛齐"}, {"id": 2, "name": "Alex"}] json_str json.dumps(data, ensure_a…

2025-02-28 学习记录--C/C++-C语言 scanf 中,%s 不需要加

合抱之木,生于毫末;九层之台,起于累土;千里之行,始于足下。💪🏻 C语言 scanf 中,%s 不需要加 & 格式化符号变量类型是否需要加 &原因%s字符数组不需要数组名本身就是指针&a…

夏普比率(Sharpe Ratio):衡量投资风险与收益的黄金标准(中英双语)

夏普比率(Sharpe Ratio):衡量投资风险与收益的黄金标准 📊📈 📌 什么是夏普比率? 夏普比率(Sharpe Ratio) 由诺贝尔经济学奖得主 威廉夏普(William F. Shar…