深度学习之线性回归模型

devtools/2024/9/23 14:29:50/

看完了李沐老师的动手学深度中线性回归,对单层神经网络有了进一步的认识,针对老师上课的代码,我进行了复现并对代码进行了详细的注释。

# 线性回归从0开始实现import random
import torch
from d2l import torch as d2l# 根据带有噪声的线性模型构造一个人工数据集,使用线性模型参数w=[2,-3.4].T,b=4.2和噪声项epslion生成数据集及标签def synthetic_data(w,b,num_examples):# 生成 y=Xw+b+噪声# 均值为0方差为1的正态随机数,样本大小为num_xampls(行数),列数与w的维度相同X = torch.normal(0,1,(num_examples,len(w)))# 真实值y = torch.matmul(X,w) + b # 加入噪音y +=torch.normal(0,0.01,y.shape)# 返回特征和标签并把标签变为二维张量(1000,1)return X,y.reshape((-1,1))true_w = torch.tensor([2,-3.4])
true_b = 4.2# 生成真实特征和标签
features, labels = synthetic_data(true_w,true_b,1000)# 散点图展示d2l.set_figsize()# detach将张量转变为array数组进行画散点图
d2l.plt.scatter(features[:,1].detach().numpy(),labels.detach().numpy(),1)# 定义一个data_iter函数用于接收批量大小、特征矩阵和标签向量作为输入,生成batch_size的小批量def data_iter(batch_size,features,labels):# 样本数量num_examples = len(features)# 生成每个样本的索引indices = list(range(num_examples))# 这些样本是随机读取的,没有特定顺序(打乱索引顺序)random.shuffle(indices)for i in range(0,num_examples,batch_size):# 构造一个batch_indices即随机索引batch_indices = torch.tensor(indices[i:min(i+batch_size,num_examples)])# 返回随机特征和标签,注意不可以使用returnyield features[batch_indices], labels[batch_indices]# return batch_indicesbatch_size = 10# 查看构造的数据
for X,y in data_iter(batch_size,features,labels):print(X,'\n',y)break# 定义初始化参数# 初始化权重:均值为0方差为1的正太随机数,由于需要计算梯度,因此将requires_grad设置为True
w = torch.normal(0,0.01,size=(2,1),requires_grad=True)# 初始化偏置为0,由于也要计算梯度因此requires_grad也要设置为True
b = torch.zeros(1,requires_grad=True)# 定义线性模型def linreg(X,w,b):return torch.matmul(X,w)+b # 定义损失函数,采用均方损失函数def squared_loss(y_hat,y):return (y_hat-y.reshape(y_hat.shape))**2/2# 定义优化算法,采用随机梯度下降法
# params表示一个张量列表,包含w和b,lr是学习率
def sgd(params,lr,batch_size):"""小批量随机梯度下降"""# 表示更新的时候不参与梯度计算with torch.no_grad():for param in params:# 随机梯度下降法,注意要除以均值param -= lr * param.grad / batch_size# 将梯度设为0,确保下一次计算与上一次结果无关,理论可以看看前面的求导视频param.grad.zero_()# 学习率
lr = 0.03# 表示将数据扫三遍
num_epochs = 3# 网络结构(线性模型)
net = linreg# 损失函数
loss = squared_lossfor epoch in range(num_epochs):for X,y in data_iter(batch_size,features,labels):# 计算X和y的小批量损失l = loss(net(X,w,b),y)# 求和之后计算梯度l.sum().backward()# 梯度下降法更新参数sgd([w,b],lr,batch_size)# 使用最后得到的w,b去预测全部数据并计算与真实标签之间的误差with torch.no_grad():train_l = loss(net(features,w,b),labels)# 展示每一次的平均误差print(f'epoch {epoch + 1}, loss{float(train_l.mean()):f}')# 计算真实参数与训练得到的参数之间的误差print(f'w的估计误差:{true_w-w.reshape(true_w.shape)}')
print(f'b的估计误差:{true_b-b}')


http://www.ppmy.cn/devtools/42230.html

相关文章

【scikit-learn005】支持向量机(Support Vector Machines, SVM)ML模型实战及经验总结(更新中)

1.一直以来想写下基于scikit-learn训练AI算法的系列文章,作为较火的机器学习框架,也是日常项目开发中常用的一款工具,最近刚好挤时间梳理、总结下这块儿的知识体系。 2.熟悉、梳理、总结下scikit-learn框架支持向量机(Support Vec…

Vue3:分类管理综合案例实现

综合案例 实现分类管理功能 路由 在main.js中引入router 访问根路径’/后跳转到布局容器 加载布局容器后重定向到’/nav/manage’ 加载我们需要的组件 这样可以在布局容器中切换功能模块时,只对需要修改的组件进行重新加载 const router createRouter({history: create…

MPLAB X IDE编译attiny1616工程报错却无报错信息

MPLAB X IDE(XC-8编译器)编译报错,无具体错误内容,仅显示需要xc-8 pro的警告。 内存占用率显示为81%,未超标。 原因:软件使用了microchip的bootloader功能。应用程序起始地址(也是bootloader结束地址)设置错…

将Flutter程序打包为ios应用并进行安装使用

如果直接执行flutter build ios: Building com.example.myTimeApp for device (ios-release)...════════════════════════════════════════════════════════════════════════════════No vali…

RTMP低延迟推流

人总是需要压力才能进步, 最近有个项目, 需要我在RK3568上, 推流到公网, 最大程度的降低延迟. 废话不多说, 先直接看效果: 数据经过WiFi发送到Inenter的SRS服务器, 再通过网页拉流的. 因为是打金任务, 所以逼了自己一把, 把RTMP推流好好捋一遍. 先说说任务目标, 首先是MPP编码…

如何加密电脑文件夹?重要文件夹怎么加密?

文件夹可以帮助我们管理电脑数据,而文件夹并不具有安全保护功能,很容易导致数据泄露。因此,我们需要加密保护电脑文件夹。那么,如何加密电脑文件夹呢?下面我们就来了解一下。 EFS加密 EFS加密是Windows提供的数据加密…

万亿国债即将发行,普通人能分一杯羹吗?信任为何提前亮起红灯?

财政部最新公告揭示:《2024年国债发行计划》正式出炉,涵盖一系列长期至超长期限的国债,涵盖20年、30年及50年期限。这一消息瞬间点燃了市场的讨论热情,激发了民众对于国家债务投资的兴趣与疑虑。 一、超长国债,你准备好…

python怎么安装matplotlib

1、登陆官方网址“https://pypi.org/project/matplotlib/#description”,下载安装包。 2、选择合适的安装包,下载下来。 3、将安装包放置到python交互命令窗口的当前目录下。 4、打开windows的命令行窗口,通过"pip install"这个命令…