多元回归梯度下降算法实现(SGD优化)(数据集随机生成)

news/2024/10/31 3:26:55/

多元回归梯度下降算法实现(SGD优化)(数据集随机生成)

下面就是代码。其实博主做了很多实验,实验效果好不好,跟数据集的质量,跟学习率的选择,SGD 优化器batch的选择都很重要。

下面看一下代码叭:


import torch
from torch.autograd import Variable
from torch.utils import data
X =torch.normal(0,100,(100,4))
w=torch.tensor([1,2,3,4])Y =torch.matmul(X, w.type(dtype=torch.float))  + torch.normal(0, 1, (100, ))
print(Y)
Y=Y.reshape((-1, 1))
print(Y.type())
print(w.type())
print(X.type())
#将X,Y转成200 batch大小,1维度的数据def load_array(data_arrays, batch_size, is_train=True):dataset = data.TensorDataset(*data_arrays)return data.DataLoader(dataset, batch_size, shuffle=is_train)data_iter = load_array((X, Y), 32)model = torch.nn.Sequential(torch.nn.Linear(4, 1))w=torch.tensor([1,0.1,0.2,0.3])
b=torch.randn(1)
w=Variable(w,requires_grad=True)b=Variable(b,requires_grad=True)
print(w)def loss_function(w,x,y,choice,b):if choice==1:return torch.abs(torch.sum(w@x)+b-y)else:#    print("fdasf:",torch.sum(w@x),y)#   print(torch.pow(torch.sum(w@x)-y,2))return torch.sum(w@x)-y
index=0n=1000
batch=32
learning_rating=0.000001
def SGD(batch):grad=Variable(torch.tensor([0.0]),requires_grad=True)for j in range(batch):try:#  print(w,X[index],Y[index],b,)#    print(loss_function(w,X[index],Y[index],b,2))#  print(torch.sum(w@X[index]),Y[index])grad=(torch.sum(w@X[index])-Y[index])*(-1)*X[index]+gradloss=loss_function(w,X[index],Y[index],2,b)index=index+1except:index=0return  grad/batch,loss/batchwhile n:n=n-1grad,loss=SGD(batch)w.data=w.data+learning_rating*grad*w.data# print("b",b)print("w:",w)print("loss:",loss)#  b.data=b.data-(learning_rating*b.grad.data)#   print("b",b)

http://www.ppmy.cn/news/617543.html

相关文章

gitlab 配置QQ邮箱

gitlab 配置QQ邮箱 gitlab版本官方文档邮箱厂商需要修改的配置文件修改的配置内容调试控制台正确测试结果[rootxxxxxxxxxxxxxx gitlab]# gitlab-rails consoleGitLab: 11.2.3 (06cbee3) GitLab Shell: 8.1.1 postgresql: 9.6.8 gitlab版本 我安装的是ee版11.2.3,跟…

qq协议 0825 和 0836 udp 登录包解析

qq协议 0825 和 0836 udp 登录包解析 参考使用工具:概念解释udp报文解析0825 udp 发送包报文原始数据:解析 0825 返回包原始数据解析 0836 发送包原始数据解析 参考 0825包参考: https://www.cnblogs.com/mRRRR/p/5288931.html 虽然是2016年的, 但是里面的结构大体还是不变 参…

查询QQ会员账号信息API接口

接口地址: https://api.hackeus.cn/api/qqvip 请求协议: HTTP、HTTPS 请求方式: GET/POST 返回格式: JSON 请求示例: https://api.hackeus.cn/api/qqvip?api_key您的apiKey&qq598765401&skeyHackApi 请求…

ros 如何禁止qq

ros是个非常好用的路由设备,我用它实现了公司的带宽管理,有效保障了视频会议和业务应用的带宽。但是ros没有完善的应用管理功能,一直想用ros实现禁止qq的功能,网上查了查,内容很乱,于是自己动手抓包试验。我…

【matlab之QQ图】

文章目录 1.QQ图原理:2.效果图:3.代码: 1.QQ图原理: 数据中一串数目的每个点都是该数据的某分位点,把这些点的(称为样本分位数点)和相应的理论上的分位数配对做出散点图,如果该数据服从正态分布…

Shell 脚本配置发送QQ邮件

文章目录 方法一1进入QQ邮箱网页界面客服端2用root用户,执行脚本 方法二: 方法一 1进入QQ邮箱网页界面客服端 ①点击设置 》点击账户 ② 开启SMTP服务 》生成授权码 2用root用户,执行脚本 #!/bin/bash. /etc/init.d/functionsif [ $# -…

AndroidQQ登录

AndroidQQ登录 一、注册腾讯开放平台账号 1. 在腾讯开放平台注册账号获取开发者资格:http://open.qq.com/ 2. 注册完成后点击右上角的管理中心,点击右下角的创建应用 3. 创建应用之后就会获取到APPID和APPKEY 4. 下载SDKjar包:http://wi…

QQ传输协议分析

一、 实验目的: 在虚拟机下NAT模式下通过Wireshark抓包,分析QQ的传输模式。了解QQ在传输信息过程中用到的协议。分析在Nat模式下,信息传输的穿透性。 二、 实验环境: Win7 专业版32位(在虚拟机里面)。 Win…