关于深度学习参数寻优的一些介绍

ops/2025/3/19 16:34:10/

深度学习中,参数是十分重要的,严重影响预测的结果。而具体在深度学习中,如何让模型自己找到最合适的参数(权重与偏置等),这就是深度学习一词中“学习”的核心含义。在本文中,我将介绍除梯度下降算法以外的其他几个寻找最优参数的方法,即Momentum、AdaGrad、RMSProp、Adam算法。

一、梯度下降算法

1.1 概述

在之前的文章中,我有介绍过梯度下降算法,即SGD算法,所以这里引用《Deep Learning from Scrach》(斋藤康毅著)一书中的描述,即随机梯度下降就像蒙住了眼睛的旅行家下山,它可以根据哪里坡度大而选择走那条路以到达最低点。如果用公式表示就是:

W\leftarrow W-\eta \frac{\partial L}{\partial W}

其中w就是需要更新的参数,而\eta自然就是学习率。

1.2 缺点

虽然SGD算法在许多时候会有明显效果,但在函数的形状是非均向的,比如蜿蜒形的,此时该方法的搜索效率就会很差。而其原因就是梯度方向没有指向最小值方向。

二、Momentum方法

2.1 原理

那么,在上述问题中,我们可以考虑引入类似于“惯性”的这样一变量,那么原来SGD算法中的公式就变形为:

W\leftarrow W+v

v\leftarrow \alpha v-\eta \frac{\partial L}{\partial W}

即:W\leftarrow W - \eta\frac{\partial L}{\partial W}+\alpha v

其中v可以表示物体在梯度上受到的力,而也正是这个力,使得其速度增加。

2.2 代码
def momentum(f, grad, start, lr=0.1, beta=0.9, n=100, plot=True):x = startv = 0history = [x]for _ in range(n):g = grad(x)v = beta * v + (1 - beta) * gx = x - lr * vhistory.append(x)momentum_path = np.array(history)if plot:plt.scatter(momentum_path, f(momentum_path), c='red', s=20, label="Momentum")plt.xlabel("x")plt.ylabel("f(x)")plt.legend()plt.title("Optimization Paths of Momentum Algorithms")plt.show()return momentum_path

三、AdaGrad方法

3.1 原理

在关于选择学习率的参数上时,如果学习率过小,会导致学习花费过多时间,如果过大,则导致学习发散而不能收敛,因而不能正确运行。那么,我们考虑一种学习率衰减的方法,即随着学习的进行,使得学习率逐渐减小,也就是一种自适应学习率。所以AdaGrad中,Ada来自于单词Adaptive。

公式如下:

h\leftarrow h+\frac{\partial L}{\partial W}\bigodot \frac{\partial L}{\partial W}

W\leftarrow W - \eta \frac{1}{\sqrt{h}}\frac{\partial L}{\partial W}

其中\bigodot表示矩阵逐元素乘法。

3.2 代码
def AdaGrad(f, grad, start, lr=0.1, epsilon=1e-8, n=100, plot=True):x = startcache = 0history = [x]for _ in range(n):g = grad(x)cache += g ** 2x = x - lr * g / (np.sqrt(cache) + epsilon)history.append(x)adaGrad_path = np.array(history)if plot:plt.scatter(adaGrad_path, f(adaGrad_path), c='red', s=20, label="AdaGrad")plt.xlabel("x")plt.ylabel("f(x)")plt.legend()plt.title("Optimization Paths of AdaGrad Algorithms")plt.show()return adaGrad_path

四、RMSProp方法

4.1 原理

在AdaGrad方法中存在一个问题,就是随着学习的深入,更新的幅度就会减小,而在实际上,如果无止境地学习,那么更新量就会变成0,完全不再更新。那么,我们考虑对于过去梯度的遗忘设置为逐步的一个过程,在做加法运算时将新梯度的信息更多反映出来,而这种方法就叫做“指数移动平均”,即呈指数式地去减小过去梯度的尺度。那么,公式为:

h\leftarrow \rho \cdot h+(1-\rho )\cdot (\frac{\partial L}{\partial W}\bigodot \frac{\partial L}{\partial W})

W\leftarrow W - \eta \frac{1}{\sqrt{h}}\frac{\partial L}{\partial W}

其中,\bigodot表示矩阵逐元素乘法。

4.2 代码
def RMSProp(f, grad, start, lr=0.1, dr=0.9, epsilon=1e-8, n=100, plot=True):x = startcache = 0history = [x]for _ in range(n):g = grad(x)cache = dr * cache + (1 - dr) * g ** 2  # 指数加权平均x = x - lr * g / (np.sqrt(cache) + epsilon)  # 更新参数history.append(x)RMSProp_path = np.array(history)if plot:plt.scatter(RMSProp_path, f(RMSProp_path), c='red', s=20, label="RMSProp")plt.xlabel("x")plt.ylabel("f(x)")plt.legend()plt.title("Optimization Paths of RMSProp Algorithms")plt.show()return RMSProp_path

五、Adam方法

5.1 原理

Adam算法结合了之前所有的思想,即动量与递减学习率的思想。那么首先在梯度第一矩估计,即在动量项中,存在式子如下:

m_i\leftarrow m_{i-1}*\beta_1 + (1-\beta_1 )*g_t

其中,g_t表示梯度\frac{\partial L}{\partial W}

而梯度第二矩估计(即递减学习率项)的公式为:

v_t \leftarrow \beta_2 * v_{t-1} + (1-\beta_2)*gt\bigodot gt

其中,\bigodot表示矩阵逐元素乘法。

由于初始时刻,即m_0 = 0v_0 = 0时刻,所以需要进行偏差矫正,其公式为:

\hat{m_t} \leftarrow \frac{m_t}{1-\beta_1^t}                               \hat{v_t} \leftarrow \frac{v_t}{1-\beta_2^t}

最后的参数更新就为:

W_{t-1} \leftarrow W_t - \eta\frac{\hat{m_t}}{\sqrt{\hat{v_t}}+\epsilon }

5.2 代码
def adam(f, grad, strat, lr=0.1, beta1=0.9, beta2=0.999, epsilon=1e-8, n=100, plot=True):x = stratm = 0v = 0t = 0history = [x]for _ in range(n):t += 1g = grad(x)m = beta1 * m + (1 - beta1) * gv = beta2 * v + (1 - beta2) * g**2m_hat = m / (1 - beta1**t)v_hat = v / (1 - beta2**t)x = x - lr * m_hat / (np.sqrt(v_hat) + epsilon)history.append(x)adam_path = np.array(history)if plot:plt.scatter(adam_path, f(adam_path), c='red', s=20, label="Adam")plt.xlabel("x")plt.ylabel("f(x)")plt.legend()plt.title("Optimization Paths of Adam Algorithms")plt.show()return adam_path

此上


http://www.ppmy.cn/ops/167061.html

相关文章

【漫话机器学习系列】141.灵敏度(Sensitivity)

灵敏度(Sensitivity)详解 在统计学和机器学习领域,灵敏度(Sensitivity),也称为召回率(Recall),是一种衡量分类模型在检测正例时的能力的重要指标。灵敏度的计算公式如下…

docker安装部署学习

docker安装部署学习 什么是 Docker?如何理解 Docker?1. 容器化技术 vs. 传统虚拟机2. Docker 的核心概念3. Docker 的四大优势 Docker 的应用场景安装 Docker 引擎1. 卸载旧版本(确保环境干净)2. 安装依赖工具3. 添加 Docker 官方…

动手学深度学习:CNN和LeNet

前言 该篇文章记述从零如何实现CNN,以及LeNet对于之前数据集分类的提升效果。 从零实现卷积核 import torch def conv2d(X,k):h,wk.shapeYtorch.zeros((X.shape[0]-h1,X.shape[1]-w1))for i in range(Y.shape[0]):for j in range(Y.shape[1]):Y[i,j](X[i:ih,j:jw…

c++--vector

1.定义vector vector的定义分为四种 (1)vector() ——————无参构造 (2)vector(size_t n,const value_type& val value_type()) ——————构造并初始化n个val (3)vector(const vector& v1) ———————拷贝构造 (4)vector(inputiterator first,inpu…

如何搭建一个安全经济适用的TRS交易平台?

TRS(总收益互换)一种多方参与的投资方式,也是绝对收益互换(total return swap)的一种形式。 它是一种衍生合约,是一种金融衍生品的合约,是指交易双方在协议期间将参照资产的总收益转移给信用保…

基于web的牙医预约管理系统(源码+lw+部署文档+讲解),源码可白嫖!

摘要 信息化时代,各行各业都以网络为基础飞速发展,而医疗服务行业的发展却进展缓慢,传统的医疗服务行业已经逐渐不满足民众的需求,有些还在以线下预约挂号的方式接待病人,为此设计一个牙医预约管理系统很有必要。此类…

Vue中的publicPath释义

publicPath 部署应用包时的基本URL。用法和 webpack 本身的 output.publicPath 一致,但是 Vue CLI 在一些其他地方也需要用到这个值,所以请始终使用 publicPath 而不要直接修改 webpack 的 output.publicPath。 默认情况下,Vue CLI 会假设你…

注意力机制:让AI拥有黄金七秒记忆的魔法--(注意力机制中的Q、K、V)

注意力机制:让AI拥有"黄金七秒记忆"的魔法–(注意力机制中的Q、K、V) 在注意⼒机制中,查询(Query)、键(Key)和值(Value)是三个关键部分。 ■ 查询…