【算法】梯度下降

ops/2024/9/20 9:17:22/ 标签: 算法

一、引言

        梯度下降算法(Gradient Descent)是一种一阶迭代优化算法,用于求解最小化目标函数的问题,广泛应用于机器学习和人工智能中的参数优化。

         用于优化问题的迭代算法,尤其在机器学习和深度学习中广泛用于最小化损失函数,以找到模型参数的最佳值。其基本思想是根据当前参数的梯度(即损失函数相对于参数的导数)逐步更新参数,从而使损失函数值逐渐减小。

二、算法原理

        梯度下降算法的核心原理是利用负梯度方向作为搜索方向,因为在多元函数的某一点处,函数值沿着负梯度方向下降最快。算法步骤包括:

        选择初始点:在函数定义域内任选一个初始点。

        计算梯度:在当前点计算目标函数的梯度(导数)。

        参数更新:根据梯度和一个预先设定的学习率来更新参数。

        迭代:重复步骤2和3,直到满足停止条件。

三、数据结构

梯度下降算法中涉及的数据结构主要包括:

  • 参数向量:存储模型参数的向量。
  • 梯度向量:存储目标函数关于参数的导数的向量。

四、算法使用场景

梯度下降算法适用于以下场景:

  • 机器学习:在训练过程中优化模型参数。
  • 深度学习:用于神经网络的权重调整。
  • 经济学:求解资源分配问题。
  • 线性回归:用于拟合线性模型,并通过最小化均方误差(MSE)来找到最佳参数。
  • 逻辑回归:用于分类任务,通过最小化交叉熵损失来优化模型。

  • 神经网络:用于复杂模型的训练,通过反向传播和梯度下降优化网络的权重。

五、算法实现

Python实现的简单梯度下降算法示例:

python

def gradient_descent(gradient_func, start_point, learning_rate, epochs):parameters = start_pointfor _ in range(epochs):grad = gradient_func(parameters)parameters -= learning_rate * gradreturn parameters# 假设的目标函数和梯度函数
def objective_function(w):return w**2 + 10 * w + 20def gradient(w):return 2 * w + 10# 初始参数,学习率,迭代次数
start_point = 20
learning_rate = 0.1
epochs = 10# 执行梯度下降
optimal_params = gradient_descent(gradient, start_point, learning_rate, epochs)
print(f"Optimal parameters: {optimal_params}")

六、其他同类算法对比

与梯度下降算法相比,其他优化算法包括:

  • 牛顿法:使用二阶导数(Hessian矩阵)来寻找最优点,收敛速度快但计算复杂。
  • 随机梯度下降(SGD):每次迭代使用一个样本的梯度来更新参数,适合大数据集。由于其更新频繁且噪声较大,可能导致收敛不稳定,但在实践中能加速收敛。

  • 批量梯度下降(Batch Gradient Descent):使用所有训练数据来计算梯度,收敛稳定但可能较慢,特别是在数据集非常大时,因为每次迭代都需要计算整个数据集的梯度。

  • 小批量梯度下降(Mini-batch Gradient Descent):结合了批量梯度下降和随机梯度下降的优点。将数据集分成小批量,每次迭代使用一个小批量来计算梯度。这种方法在稳定性和效率之间取得了平衡。

  • 动量法(Momentum):在梯度下降中引入了过去梯度的加权平均,以加速收敛并减少震荡。公式中引入了动量项,使得更新不仅依赖当前梯度,还依赖于之前的更新方向。

  • 自适应学习率方法(如AdaGrad, RMSprop, Adam)

    • AdaGrad:根据每个参数的历史梯度的平方和调整学习率,适合稀疏数据。
    • RMSprop:对梯度平方的移动平均进行平滑,以调整学习率,能够处理非平稳目标。
    • Adam:结合了动量法和RMSprop的优点,使用动量和学习率的自适应调整,以加速收敛和提高鲁棒性。

七、多语言实现

使用Java、C++和Go语言实现的梯度下降算法的简单示例:

Java

import org.apache.commons.math3.linear.*;public class GradientDescent {public static RealVector gradientDescent(RealMatrix X, RealVector y, RealVector theta, double learningRate, int iterations) {int m = y.getDimension();RealVector costHistory = new ArrayRealVector(iterations);for (int i = 0; i < iterations; i++) {RealVector predictions = X.operate(theta);RealVector errors = predictions.subtract(y);RealVector gradients = X.transpose().operate(errors).mapDivide(m);theta = theta.subtract(gradients.mapMultiply(learningRate));double cost = computeCost(X, y, theta);costHistory.setEntry(i, cost);}return theta;}public static double computeCost(RealMatrix X, RealVector y, RealVector theta) {RealVector predictions = X.operate(theta);RealVector errors = predictions.subtract(y);double cost = errors.dotProduct(errors) / (2 * y.getDimension());return cost;}public static void main(String[] args) {// Sample datadouble[][] data = {{1, 1}, {1, 2}, {1, 3}};double[] labels = {1, 2, 3};RealMatrix X = MatrixUtils.createRealMatrix(data);RealVector y = new ArrayRealVector(labels);RealVector theta = new ArrayRealVector(new double[]{0.1, 0.2});double learningRate = 0.01;int iterations = 1000;RealVector optimizedTheta = gradientDescent(X, y, theta, learningRate, iterations);System.out.println("Optimized Theta: " + optimizedTheta);}
}

C++

#include <iostream>
#include <Eigen/Dense>using namespace Eigen;double computeCost(const MatrixXd& X, const VectorXd& y, const VectorXd& theta) {int m = y.size();VectorXd predictions = X * theta;VectorXd error = predictions - y;double cost = (1.0 / (2 * m)) * error.dot(error);return cost;
}VectorXd gradientDescent(const MatrixXd& X, const VectorXd& y, VectorXd theta, double learningRate, int iterations) {int m = y.size();VectorXd costHistory(iterations);for (int i = 0; i < iterations; ++i) {VectorXd predictions = X * theta;VectorXd error = predictions - y;VectorXd gradients = (1.0 / m) * X.transpose() * error;theta -= learningRate * gradients;costHistory(i) = computeCost(X, y, theta);}return theta;
}int main() {// Sample dataMatrixXd X(3, 2);X << 1, 1, 1, 2, 1, 3;VectorXd y(3);y << 1, 2, 3;VectorXd theta(2);theta << 0.1, 0.2;double learningRate = 0.01;int iterations = 1000;VectorXd optimizedTheta = gradientDescent(X, y, theta, learningRate, iterations);std::cout << "Optimized Theta:\n" << optimizedTheta << std::endl;return 0;
}

Go

package mainimport ("fmt""gonum.org/v1/gonum/mat"
)// Compute Cost Function
func computeCost(X, y *mat.Dense, theta *mat.VecDense) float64 {m, _ := X.Dims()predictions := mat.NewVecDense(m, nil)predictions.MulVec(X, theta)error := mat.NewVecDense(m, nil)error.SubVec(predictions, y)cost := mat.Dot(error, error) / (2 * float64(m))return cost
}// Gradient Descent
func gradientDescent(X, y *mat.Dense, theta *mat.VecDense, learningRate float64, iterations int) (*mat.VecDense, []float64) {m, _ := X.Dims()costHistory := make([]float64, iterations)for i := 0; i < iterations; i++ {predictions := mat.NewVecDense(m, nil)predictions.MulVec(X, theta)error := mat.NewVecDense(m, nil)error.SubVec(predictions, y)gradients := mat.NewVecDense(X.Caps().Cols, nil)gradients.MulVec(X.T(), error)gradients.Scale(1/float64(m), gradients)theta.SubVec(theta, gradients.Scale(learningRate, gradients))costHistory[i] = computeCost(X, y, theta)}return theta, costHistory
}func main() {// Sample dataX := mat.NewDense(3, 2, []float64{1, 1, 1, 2, 1, 3})y := mat.NewVecDense(3, []float64{1, 2, 3})theta := mat.NewVecDense(2, []float64{0.1, 0.2})learningRate := 0.01iterations := 1000optimizedTheta, costHistory := gradientDescent(X, y, theta, learningRate, iterations)fmt.Println("Optimized Theta:\n", optimizedTheta)fmt.Println("Final Cost:", costHistory[iterations-1])
}

 八、实际的服务应用场景代码框架

        示例代码框架,演示如何使用实现的线性回归模型在 Flask(Python)中构建一个简单的服务。
Python Flask 应用

from flask import Flask, request, jsonify
import numpy as npapp = Flask(__name__)class LinearRegression:
def __init__(self, learning_rate=0.01, n_iterations=1000):
self.learning_rate = learning_rate
self.n_iterations = n_iterations
self.weights = None
self.bias = Nonedef fit(self, X, y):
n_samples, n_features = X.shape
self.weights = np.zeros(n_features)
self.bias = 0for _ in range(self.n_iterations):
y_predicted = np.dot(X, self.weights) + self.bias
dw = (1 / n_samples) * np.dot(X.T, (y_predicted - y))
db = (1 / n_samples) * np.sum(y_predicted - y)self.weights -= self.learning_rate * dw
self.bias -= self.learning_rate * dbdef predict(self, X):
return np.dot(X, self.weights) + self.bias# 创建模型实例
model = LinearRegression()@app.route('/train', methods=['POST'])
def train():
data = request.json
X = np.array(data['X'])
y = np.array(data['y'])
model.fit(X, y)
return jsonify({'message': 'Model trained successfully'}), 200@app.route('/predict', methods=['POST'])
def predict():
data = request.json
X = np.array(data['X'])
predictions = model.predict(X).tolist()
return jsonify({'predictions': predictions}), 200if __name__ == '__main__':
app.run(debug=True)


        安装 Flask:pip install Flask
        运行 Flask 应用:python app.py
        发送训练请求:curl -X POST http://127.0.0.1:5000/train -H "Content-Type: application/json" -d '{"X": [[1], [2], [3]], "y": [2, 3, 4]}'
        发送预测请求:curl -X POST http://127.0.0.1:5000/predict -H "Content-Type: application/json" -d '{"X": [[4], [5]]}'
 


http://www.ppmy.cn/ops/93794.html

相关文章

2024有哪些好用的图纸加密软件,10款图纸加密软件排行榜

在信息安全愈加重要的今天&#xff0c;企业和个人都越来越注重图纸及设计文件的安全性。无论是工业设计、建筑设计还是其他涉及机密信息的图纸文件&#xff0c;加密软件都成为了保护知识产权的关键工具。下面&#xff0c;我们将介绍2024年最值得关注的十款图纸加密软件&#xf…

TypeScript学习第十三篇 - 泛型

在编译期间不确定变量的类型&#xff0c;在调用时&#xff0c;由开发者指定具体的类型。 1. 如何给arg参数和函数指定类型&#xff1f; function identity(arg){return arg; }identity(1) identity(jack) identity(true) identity([]) identity(null)定义的时候&#xff0c;无…

C#调用c++的dll方法,动态调用c++dll的方法

文章目录 一、创建c的dll1.新建项目2.删除vs自建的.cpp和.h文件3.新建Algorithm.h和Algorithm.cpp4.编译c1.编译2.解决报错3.再次编译可以看到已经成功。4.查看成功输出的dll。 二、创建c#项目1.创建一个console控制台程序。2.把dll拷贝到c#生成的程序根目录。3.在c#的program.…

【Pyspark-驯化】一文搞懂Pyspark中对json数据处理使用技巧:get_json_object

【Pyspark-驯化】一文搞懂Pyspark中对json数据处理使用技巧&#xff1a;get_json_object 本次修炼方法请往下查看 &#x1f308; 欢迎莅临我的个人主页 &#x1f448;这里是我工作、学习、实践 IT领域、真诚分享 踩坑集合&#xff0c;智慧小天地&#xff01; &#x1f387; …

使用WebSocket实现一个简易的聊天室

我这里的框架是SpringBoot 首先&#xff0c;我们要有一个前端页面 <!DOCTYPE html> <html xmlns:th"http://www.thymeleaf.org"xmlns:layout"http://www.ultraq.net.nz/web/thymeleaf/layout"layout:decorate"layout"> <head&g…

论文分享|MLLMs中多种模态(图像/视频/音频/语音)的tokenizer梳理

本文旨在对任意模态输入-任意模态输出 (X2X) 的LLM的编解码方式进行简单梳理&#xff0c;同时总结一些代表性工作。 注&#xff1a;图像代表Image&#xff0c;视频代表Video&#xff08;不含声音&#xff09;&#xff0c;音频代表 Audio/Music&#xff0c;语音代表Speech 各种…

eBPF编程指南(一):eBPF初体验

1 什么是EBPF&#xff1f; EBPF是一种可以让程序员在内核态执行自己的程序的机制&#xff0c;但是&#xff0c;为了安全起见&#xff0c;无法像内核模块一样随意调用内核的函数&#xff0c;只能调用一些bpf提前定义好的函数。为了让内核执行程序员自己的代码&#xff0c;需要指…

字符串值提取工具-03-java 调用 groovy

值提取系列 值提取系列 字符串值提取工具-01-概览 字符串值提取工具-02-java 调用 js 字符串值提取工具-03-java 调用 groovy 字符串值提取工具-04-java 调用 java? Janino 编译工具 字符串值提取工具-05-java 调用 shell 字符串值提取工具-06-java 调用 python 字符串…

kali-linux 常用命令大集合(目录、文件查看与编辑,登录、电源、帮助等相关命令详解)

目录 目录查看-ls 帮助命令 帮助命令&#xff1a;whatis 帮助命令&#xff1a;help 帮助命令&#xff1a;man 帮助命令&#xff1a;info 登录命令 登录命令&#xff1a;login 登录命令&#xff1a;last 登录命令&#xff1a;exit 切换用户&#xff1a;su/sudo 命令-…

现在画原型都用什么工具?

现在画原型时&#xff0c;Axure是一款广泛使用的工具&#xff0c;尤其在需要高度交互性和逻辑复杂性的项目中&#xff0c;如企业级应用、大型软件项目等&#xff0c;Axure更是首选。 https://ffhog9.axshare.com https://1zvcwx.axshare.com/start.html 以下是对Axure及其优…

全网最最最详细的haproxy详解!!!

1 什么是负载均衡 负载均衡&#xff08;Load Balancing&#xff09;是一种将网络请求或工作负载分散到多个服务器或计算机资源上的技术&#xff0c;以实现优化资源使用、提高系统吞吐量、增强数据冗余和故障容错能力、以及减少响应时间的目的。在分布式系统、云计算环境、Web服…

[C#]实现GRPC通讯的服务端和客户端实例

最近要做两个软件之间消息的通讯&#xff0c;学习了一下GRPC框架的通讯。根据官方资料做了一个实例。 官方资料请参考&#xff1a;Create a .NET Core gRPC client and server in ASP.NET Core | Microsoft Learn 开发平台&#xff1a;Visual Studio 2022 开发前提条件&#x…

微服务部分面试问题(面试篇)

分布式事务 CAP定理 分布式系统有三个指标&#xff1a; Consistency&#xff08;一致性&#xff09; Availability&#xff08;可用性&#xff09; Partition tolerance &#xff08;分区容错性&#xff09; 它们的第一个字母分别是 C、A、P。Eric Brewer认为任何分布式系…

8.6 MySQL

[rootmysql ~]# sed -i $aexport PATH/usr/local/mysql/bin/:$PATH /etc/profile //加到环境变量 [rootmysql ~]# source /etc/profile //使配置环境生效配置开机自启 [rootmysql ~]# chkconfig --list //列举 注&#xff1a;该输出结果只显…

网络编程----TCP/IP协议

使用TCP/IP协议实现客户端和服务器端进行通信: 1.服务器端(test1.c): #include <sys/socket.h> #include <sys/types.h> #include <arpa/inet.h> #include <stdio.h> #include <unistd.h>// 创建服务器端 int main() {//1.创建套接字int serfd…

计算机毕业设计 校园新闻管理系统 Java+SpringBoot+Vue 前后端分离 文档报告 代码讲解 安装调试

&#x1f34a;作者&#xff1a;计算机编程-吉哥 &#x1f34a;简介&#xff1a;专业从事JavaWeb程序开发&#xff0c;微信小程序开发&#xff0c;定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事&#xff0c;生活就是快乐的。 &#x1f34a;心愿&#xff1a;点…

2024年最新版小程序云开发数据模型的开通步骤,开始开发微信小程序前的准备工作,认真看完奥!

小程序官方又改版了&#xff0c;搞得石头哥不得不紧急的再新出一版&#xff0c;教大家开通最新版的数据模型。官方既然主推数据模型&#xff0c;那我们就先看看看新版的数据模型到底是什么。 一&#xff0c;什么是数据模型 数据模型是什么 数据模型是一个用于组织和管理数据的…

第七节 流编辑器sed(stream editor)(7.2)

三,模式空间中的编辑操作 3,1,地址定界 地址定界示例说明不写地址定界表示对文件所有行进行处理num1,num21,3或者1,$对文件的1-3行进行处理或者1-$(表示文件的最后一行)num1,N1,3对文件的num1行和之后n行进行处理first~step1~2对文件的1,3,5,7,…的行内容进行处理/pattern//^…

RCE安全漏洞 贷齐乐系统Sql注入漏洞

远程代码执行介绍在phpstudy上创建网站在本地数据库中创建数据库--ctf&#xff0c;并创建users表&#xff0c;往表中插入数据 远程代码执行&#xff08;Remote Code Execution&#xff0c;RCE&#xff09;是一种严重的网络安全漏洞&#xff0c;它允许攻击者通过输入恶意代码直…

P3957 [NOIP2017 普及组] 跳房子(青春版代码)

[NOIP2017 普及组] 跳房子 - 洛谷 核心思路 单调队列优化dp 顺序 先让合法答案入队 再删去越界答案 判断非空 后 求 答案 一个答案合法 当且仅当 l < dis < r 记 调了n久&#xff0c;找题解调。 竟发现几乎没有用 STL deque 的。 故写了个青春版题解。 AC 代码…