【代码pycharm】动手学深度学习v2-08 线性回归 + 基础优化算法

devtools/2024/11/27 3:51:58/

课程链接

线性回归的从零开始实现

import random
import torch
from d2l import torch as d2l# 人造数据集
def synthetic_data(w,b,num_examples):X=torch.normal(0,1,(num_examples,len(w)))y=torch.matmul(X,w)+by+=torch.normal(0,0.01,y.shape) # 加入噪声return X,y.reshape(-1,1) # y从行向量转为列向量
true_w=torch.tensor([2,-3.4])
true_b=4.2
features,labels=synthetic_data(true_w,true_b,1000)print('features:',features[0],'\nlabels:',labels[0])#绘图展示
d2l.set_figsize()
d2l.plt.scatter(features[:,1].detach().numpy(),labels.detach().numpy(),1);
d2l.plt.show()
# 读数据集
def data_iter(batch_size,features,labels):num_examples=len(features) #看一下有多少个样本indices=list(range(num_examples))# 生成0-999的元组,然后将range()返回的可迭代对象转为一个列表random.shuffle(indices)# 将序列的所有元素随机排序(打乱下标)for i in range(0,num_examples,batch_size): #从0到最后,每次取batch_size个大小batch_indices=torch.tensor(indices[i:min(i+batch_size,num_examples)]) #超出样本个数没有拿满的话取最小值yield features[batch_indices],labels[batch_indices]batch_size=10
for X,y in data_iter(batch_size,features,labels):#给一些样本标号,每一次随机从里面选取b个样本返回print(X,'\n',y)break#定义初始化模型参数
w=torch.normal(0,0.01,size=(2,1),requires_grad=True)
b=torch.zeros(1,requires_grad=True)
#定义模型
def linreg(X,w,b):return torch.matmul(X,w)+b#定义损失函数
def squared_loss(y_hat,y): #均方损失return (y_hat-y.reshape(y_hat.shape))**2/2
#定义优化算法
def sgd(params,lr,batch_size):with torch.no_grad():for param in params:param-=lr*param.grad/batch_sizeparam.grad.zero_()#训练过程
lr=0.03
num_epochs=3
net=linreg
loss=squared_loss
for epoch in range(num_epochs):for X,y in data_iter(batch_size,features,labels):l=loss(net(X,w,b),y)l.sum().backward()sgd([w,b],lr,batch_size)with torch.no_grad():train_l=loss(net(features,w,b),labels)print(f'epoch{epoch+1},loss{float(train_l.mean()):f}')#比较真实参数和训练得来的参数评估训练的成功程度
print(f'w的估计误差:{true_w-w.reshape(true_w.shape)}')
print(f'b的估计误差:{true_b-b}')

运行结果
在这里插入图片描述
在这里插入图片描述

线性回归的简洁实现

import random
import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2l
from torch import nn
#使用框架生成数据集
true_w=torch.tensor([2,-3.4])
true_b=4.2
features,labels=d2l.synthetic_data(true_w,true_b,1000)
#使用框架现有的API读取数据
def load_array(data_arrays,batch_size,is_train=True):dataset=data.TensorDataset(*data_arrays)return data.DataLoader(dataset,batch_size,shuffle=is_train)
batch_size=10
data_iter=load_array((features,labels),batch_size)
print(next(iter(data_iter)))
# 模型的定义
#使用框架预定义好的层
net=nn.Sequential(nn.Linear(2,1))
# 初始化模型参数
net[0].weight.data.normal_(0,0.01)
net[0].bias.data.fill_(0)
loss=nn.MSELoss()
trainer=torch.optim.SGD(net.parameters(),lr=0.03)
#训练
num_epochs=3
for epoch in range(num_epochs):for X,y in data_iter:l=loss(net(X),y)trainer.zero_grad()l.backward()trainer.step()l=loss(net(features),labels)print(f'epoch{epoch+1},loss{l:f}')

在这里插入图片描述


http://www.ppmy.cn/devtools/137298.html

相关文章

初学 flutter 环境变量配置

一、jdk(jdk11) 1)配置环境变量 新增:JAVA_HOMEC:\Program Files\Java\jdk-11 //你的jdk目录 在path新增:%JAVA_HOME%\bin2)验证是否配置成功(cmd运行命令) java java -version …

Java 字符串截取详解:常见场景与方法

Java 字符串截取详解:常见场景与方法 在 Java 开发中,截取字符串是一个非常常见的操作,无论是获取文件名还是提取某些特定内容。本文详细介绍了截取字符串最后一位及其他常见截取操作的多种方法,帮助开发者快速上手。 目录 截取…

Linux tcpdump 详解教程

简介 tcpdump 是一款在 Linux 平台上广泛使用的网络抓包工具。它可以捕获整个 TCP/IP 协议族的数据包,并支持对网络层、协议、主机、端口等进行过滤。tcpdump 提供了强大的过滤功能,允许使用 and、or、not 等逻辑语句来筛选数据包,非常适合用…

springboot基于Android的华蓥山旅游导航系统

摘 要 华蓥山旅游导航系统是一款专为华蓥山景区设计的智能导览应用,旨在为用户提供便捷的旅游信息服务。该系统通过整合华蓥山的地理信息、景点介绍、交通状况等数据,实现了对景区的全面覆盖。用户可以通过该系统获取实时的旅游资讯、交流论坛、地图等。…

vue3 reactive响应式实现源码

Vue 3 的 reactive 是基于 JavaScript 的 Proxy 实现的,因此它通过代理机制来拦截对象的操作,从而实现响应式数据的追踪。下面是 Vue 3 的 reactive 源码简化版。 Vue 3 reactive 源码简化版 首先,我们需要了解 reactive 是如何工作的&…

道品智能科技移动式水肥一体机:农业灌溉施肥的革新之选

在现代农业的发展进程中,科技的力量正日益凸显。其中,移动式水肥一体机以其独特的可移动性、智能化以及实现水肥一体化的卓越性能,成为了农业领域的一颗璀璨新星。它不仅改变了传统的农业灌溉施肥方式,更为农业生产带来了高效、精…

linux僵尸线程清理

文章目录 1.cleanup_zombies.sh脚本2.terminate_zombie_parents.sh:3.监控僵尸进程monitor_zombies.sh:4. 执行权限5.定时处理6.使用go执行 1.cleanup_zombies.sh脚本 #!/bin/bashecho "检测并尝试清理僵尸进程..."# 查找所有僵尸进程及其父进…

SpringBoot开发——Maven多模块工程最佳实践及详细示例

文章目录 一、前言二、Maven多模块工程的最佳实践1、项目结构清晰2、依赖管理统一3、插件配置统一4、版本控制一致5、模块间通信简化 三、详细示例1、项目结构2、父模块(parent)的pom.xml文件3、子模块(module-api)的pom.xml文件4…