深度学习【迭代梯度下降法求解线性回归】

devtools/2025/3/21 22:12:53/


梯度下降法

梯度下降法是一种常用迭代方法,其目的是让输入向量找到一个合适的迭代方向,使得输出值能达到局部最小值。在拟合线性回归方程时,我们把损失函数视为以参数向量为输入的函数,找到其梯度下降的方向并进行迭代,就能找到最优的参数值。

1.计算对于给定的线性模型 (y = wx + b) 的均方误差(MSE)。它接受截距 (b)、斜率 (w) 和点集 (points),然后遍历所有点,计算每个点的预测值,与真实值之差的平方和,最后返回平均误差。

2.更新w和b

3.多次迭代后,得到最优的w和b,也就是y=wx+b这个模型对于给定数据集的最优

这里给定数据集:100个(x,y)

import torch
import numpy as np#计算给定点集的线性回归的误差  y = wx + b
def compute_error_for_line_given_points(b,w,points):total_error = 0for i in range(len(points)):x = points[i,0]y = points[i,1]total_error += (y - (w*x + b))**2return total_error/float(len(points))#梯度下降法求解线性回归  w = w - learning_rate * w_gradient, b = b - learning_rate * b_gradient
def step_gradient(b_current,w_current,points,learning_rate):b_gradient = 0w_gradient = 0n = float(len(points))for i in range(len(points)):x = points[i,0]y = points[i,1]b_gradient += -(2/n) * (y - ((w_current*x) + b_current))w_gradient += -(2/n) * x * (y - ((w_current*x) + b_current))new_b = b_current - (learning_rate * b_gradient)new_w = w_current - (learning_rate * w_gradient)return [new_b,new_w]#迭代梯度下降法求解线性回归
def gradient_descent_runner(points,starting_b,starting_w,learning_rate,num_iterations):b = starting_bw = starting_wfor i in range(num_iterations):b,w = step_gradient(b,w,points,learning_rate)return [b,w]def run():points = np.genfromtxt('data.csv', delimiter=',')learning_rate = 0.0001initial_b = 0initial_w = 0num_iterations = 1000print("Starting gradient descent at b = {0}, w = {1}, error = {2}".format(initial_b,initial_w,compute_error_for_line_given_points(initial_b,initial_w,points)))print("Running...")[b,w] = gradient_descent_runner(points,initial_b,initial_w,learning_rate,num_iterations)print("After {0} iterations b = {1}, w = {2}, error = {3}".format(num_iterations,b,w,compute_error_for_line_given_points(b,w,points)))if __name__ == '__main__':run()

执行结果:


http://www.ppmy.cn/devtools/169001.html

相关文章

Linux的Shell编程

一、什么是Shell 1、为什么要学习Shell Linux运维工程师在进行服务器集群管理时,需要编写Shell程序来进行服务器管理。 对于JavaEE和Python程序员来说,工作的需要。Boss会要求你编写一些Shell脚本进行程序或者是服务器的维护,比如编写一个…

Socket 、WebSocket、Socket.IO详细对比

WebSocket、Socket 和 Socket.IO 是网络通信中常用的技术,它们在功能、使用场景和实现方式上有明显的异同点。以下是它们的详细对比: 1. Socket 定义 Socket 是一个通用的网络编程接口,用于在网络上实现进程间通信(IPC&#xff0…

cool-admin-midway 使用腾讯云cos上传图片

说明:在使用cool-admin这个低代码平台时,发现官方的cos上传插件有问题,总是报错 substring,故自己找解决方案,修改本地的upload方法改为云端上传。 解决方案: 安装腾讯云cos的nodeJS SDK pnpm i cos-node…

CMS漏洞-WordPress篇

一.姿势一:后台修改模板拿WebShell 1.使用以下命令开启docker cd /www/wwwroot / vulhub / wordpress / pwnscriptum docker - compose up - d 如果发现不能开启,可以检查版本和端口 2.访问网址登录成功后 外观 👉编辑 👉404.…

Python第六章04:列表操作练习题

# 列表常用功能练习题 """ 有一个列表,内容是:[21,25,21,23,22,20],记录一批学生的年龄请通过列表的功能(方法),对齐进行: 1.定义这个列表,并用变量接收它 2.追加一个数字31&…

【SpringCloud】Eureka、LoadBalancer和Nacos

🔥个人主页: 中草药 🔥专栏:【中间件】企业级中间件剖析 一、微服务 单体架构 单体架构是一种传统的软件架构方式,它将一个应用程序的所有功能模块(如用户认证、订单处理、数据存储等)都打包在…

Web3网络生态中数据保护合规性分析

Web3网络生态中数据保护合规性分析 在这个信息爆炸的时代,Web3网络生态以其独特的去中心化特性,逐渐成为数据交互和价值转移的新平台。Web3,也被称为去中心化互联网,其核心理念是将数据的控制权归还给用户,实现数据的…

ubuntu24.04安装VMware Tools

虚拟机创建ubuntu24.04,安装VMware Tools, sudo apt update sudo apt install open-vm-tools sudo reboot 之后可以创建共享文件夹用于主机和虚拟机之间传输文件。 在虚拟机-设置-选项-共享文件夹,中点选“总是启用”并添加共享路径和在…