【动手学习深度学习--逐行代码解析合集】02线性回归的简洁实现

news/2024/11/8 15:07:34/

【动手学习深度学习】逐行代码解析合集

02线性回归的简洁实现


视频链接:B站-动手学习深度学习
课程主页:https://courses.d2l.ai/zh-v2/
教材:https://zh-v2.d2l.ai/

线性回归的简洁实现-代码

以下代码是在PyCharm中运行的

import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2l"====================1、生成数据集===================="
def synthetic_data(w, b, num_examples):  #@save"""生成y=Xw+b+噪声"""X = torch.normal(0, 1, (num_examples, len(w)))  # 生成均值为0,方差为1,1000行2列的矩阵Xy = torch.matmul(X, w) + by += torch.normal(0, 0.01, y.shape)  #噪声矩阵的size需要和y相同# reshape(-1,1)代表将二维数组重整为一个一列的二维数组# reshape(1,-1)代表将二维数组重整为一个一行的二维数组return X, y.reshape((-1, 1))  # reshape(-1,1)中的-1代表无意义true_w = torch.tensor([2, -3.4])
true_b = 4.2
# 得到的features为【1000,2】,labels为1列
features, labels = d2l.synthetic_data(true_w, true_b, 1000)"====================2、读取数据集===================="
# 布尔值is_train表示是否希望数据迭代器对象在每个迭代周期内打乱数据。
def load_array(data_arrays, batch_size, is_train=True):  #@save"""构造一个PyTorch数据迭代器"""# TensorDataset用来对 tensor 进行打包,包装成dataset。类似 python 中的 zip 功能dataset = data.TensorDataset(*data_arrays)return data.DataLoader(dataset, batch_size, shuffle=is_train)  #将数据划分为小批量batch_size = 10
data_iter = load_array((features, labels), batch_size)
# 使用iter构造Python迭代器,并使用next从迭代器中获取第一项。
# print(next(iter(data_iter)))"====================3、初始化参数与定义函数===================="
# nn是神经网络的缩写
from torch import nn
# 全连接层在Linear类中定义,第一个参数指定输入特征形状,第二个参数指定输出特征形状
net = nn.Sequential(nn.Linear(2, 1))"初始化模型参数"
# net[0]选择网络中的第一个图层
# 使用weight.data和bias.data方法访问参数。
# 使用替换方法normal_和fill_来重写参数值。
net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)"定义损失函数  均方误差"
# 计算均方误差使用的是MSELoss类,也称为平方L2范数
loss = nn.MSELoss()"定义优化算法  小批量梯度下降算法"
# 优化器就是需要根据网络反向传播的梯度信息来更新网络的参数,以起到降低loss函数计算值的作用
# 在PyTorch中optim模块中实现了小批量梯度下降算法
# 从net.parameters()中获得需要优化的参数:权重w和偏置b
trainer = torch.optim.SGD(net.parameters(), lr=0.03)"====================4、训练===================="
# 对于每一个小批量,我们会进行以下步骤:
# 1、通过调用net(X)生成预测并计算损失l(前向传播)。
# 2、通过进行反向传播来计算梯度。
# 3、通过调用优化器来更新模型参数。
num_epochs = 3
for epoch in range(num_epochs):for X, y in data_iter:l = loss(net(X) ,y)trainer.zero_grad()   # 梯度清零l.backward()  # 对损失函数进行反向传播,计算梯度trainer.step()  # 更新学习率# 再算一遍loss只是为了输出结果l = loss(net(features), labels)print(f'epoch {epoch + 1}, loss {l:f}')
'''
输出:
epoch 1, loss 0.000181
epoch 2, loss 0.000109
epoch 3, loss 0.000109
'''
w = net[0].weight.data
print('w的估计误差:', true_w - w.reshape(true_w.shape))
b = net[0].bias.data
print('b的估计误差:', true_b - b)
'''
输出:
w的估计误差: tensor([-0.0009, -0.0002])
b的估计误差: tensor([0.0010])
'''

http://www.ppmy.cn/news/673963.html

相关文章

win10推送_微软 Win10 最稳版本 2004 正式版推送!最低配置要求汇总,全新 UI 虚拟桌面,支持几乎所有 CPU...

5月28日消息 ,微软开始推送2020 Windows 10更新五月版(2004版本)系统更新,现在包括微软Windows 10官网工具、MSDN订阅网站都已经发布了Windows 10版本2004的官方ISO镜像下载。大家可以按需求下载和安装、收藏使用。 微软2020 Windows 10更新五月版(版本2…

微软Windows 11正式发布!看看怎么下载安装包或升级!

Windows 11 的发布日期是今天,即 2021 年 10 月 5 日,用户可以准备下载世界上使用最广泛的桌面操作系统的最新更新。同时,Office 2021 正式版也于今天一起发布,以下是您需要了解的有关更新下载和 Windows 11 兼容计算机的所有信息…

用PE安装win11系统

https://cloud.189.cn/t/7ne6zeiQvI3a (访问码:d12x) win11系统下载 https://cloud.189.cn/t/EV73EfzmER32 (访问码:pmb6) pe制作软件下载 win11p2p下载: ed2k://|file|zh-cn_windows_11_business_editions_updated_march_2022_x64_dvd_7df6eae1.iso|5582235648|…

IC设计行业都有哪些不错的公司(外企篇)

当越来越多的同学选择IC修真院之后,咨询的问题就从最开始的该怎么入行IC设计,慢慢变成了该怎么选择offer,该怎么挑选公司? 非科班的同学本就对IC行业知之甚少,更别提具体的公司状况,他们不清楚为什么不推荐…

linux内核大版本演进,Linux史上最重大的版本: 内核 5.8 今天发布了

在最新的更新中,Linux内核5.8版本引入了众多的驱动程序支持、安全改进和优化。 Linux之父Linus Torvalds确实在Linux 5.8 RC1版本中提到了这一点: 不过,5.8仍然是最好的,尽管它没有任何一项脱颖而出的功能特性。是的,驱…

Windows 11 安卓子系统安装教程

Windows 11 上周开始引入了对安卓子系统的支持,不过目前只对 Beta 通道用户开放,而正式版系统和 Dev 通道的用户还无法使用,好在 Windows 一向比较开放,我们可以通过「偷渡」的方式自行安装,极客之选今天就把 Win11 安…

linux 历史重大更新整理

3.8 CPU热插拔支持;改进ACPI电源管理;改善XFS文件系统;支持64位ARMv8/AArch64;放弃支持旧的i386处理器,减少内耗复杂度;Video 4 Linux 2驱动支持 DMA-BUF;在某些工作负荷下减少物理内存占用&am…

Win11是否支持老硬件?老电脑能装win11吗?

Win11支持与时代相符的混合工作环境,侧重于在灵活多变的体验中提高用户的工作效率。但是有不少小伙伴们反映说不知道win11是否支持老硬件?今天给朋友们介绍win11是否支持老硬件,还有不清楚小伙伴一起学习一下吧。 Win11是否支持老硬件介绍&am…