李沐深度学习-多层感知机从零开始

news/2025/2/5 4:06:28/

!!!梯度的产生是由于反向传播,在自定义从零开始编写代码时,第一次反向传播前应该对params参数的梯度进行判断

import torch
import numpy as np
import torch.utils.data as Data
import torchvision.datasets
import torchvision.transforms as transforms
import syssys.path.append("路径")
import d2lzh_pytorch as d2l'''
--------------------------------------------------获取和读取数据
'''
batch_size = 256
train_mnist = torchvision.datasets.FashionMNIST(root='路径',download=True, train=True, transform=transforms.ToTensor())
test_mnist = torchvision.datasets.FashionMNIST(root='路径',download=True, train=False, transform=transforms.ToTensor())
train_iter = Data.DataLoader(train_mnist, batch_size=batch_size, shuffle=True)
test_iter = Data.DataLoader(test_mnist, batch_size=batch_size, shuffle=False)'''
--------------------------------------------------定义模型参数
'''
num_inputs = 784
num_outputs = 10
num_hidden = 256
# 有几个隐藏层就要设置几个参数,简洁实现中,linear网络会自动配置初始参数,自己可以使用init.normal_()设置参数初始值
w1 = torch.tensor(np.random.normal(0, 0.01, (num_inputs, num_hidden)), dtype=torch.float)
b1 = torch.zeros(num_hidden, dtype=torch.float)
w2 = torch.tensor(np.random.normal(0, 0.1, (num_hidden, num_outputs)), dtype=torch.float)
b2 = torch.zeros(num_outputs, dtype=torch.float)
params = [w1, b1, w2, b2]
for param in params:param.requires_grad_(requires_grad=True)
'''
---------------------------------------------定义激活函数
'''def relu(X):return torch.max(input=X, other=torch.tensor(0.0))'''
---------------------------------------------------定义模型
'''# 使用view函数将输入的样本转换成inputs特征数大小的图像
def net(X):X = X.view((-1, num_inputs))H = relu(torch.matmul(X, w1) + b1)  # torch.mm(X, w1) + b1得到隐藏层输出# 对隐藏层变量进行激活函数变换,然后作为下一个全连接层的输入# 第一层不是隐藏层,直接线性计算,隐藏层输出作为输出层输入的时候,对隐藏层进行非线性变换,然后传入输入层return torch.matmul(H, w2) + b2  # 隐藏层作为输出层的输入   n层layer有最多n-2个激活函数'''
-----------------------------------------------------定义损失函数
'''
loss = torch.nn.CrossEntropyLoss()  # 包含了softmax运算和交叉熵运算'''
------------------------------------------------------softmax操作,用于训练模型中训练集准确率调用
'''def softmax(X):X_exp = X.exp()  # 幂指数化partition = X_exp.sum(dim=1, keepdim=True)  # 求和每行的元素值return X_exp / partition  # 做比值得预测概率'''
----------------------------------------------------测试集准确率函数,训练模型中测试集准确率调用
'''def evaluate_accuracy(test_data):acc_num, num = 0.0, 0for X, y in test_data:  # X,y分别是一个元组acc_num += (softmax(net(X)).argmax(dim=1) == y).float().sum().item()num += y.shape[0]return acc_num / num'''
------------------------------------------------------训练模型
'''
num_epochs, lr = 5, 100def train():for epoch in range(num_epochs):train_acc, train_l, test_acc, n, num = 0.0, 0.0, 0.0, 0, 0for X, y in train_iter:  #l = loss(net(X), y)  # CrossEntropyLoss 函数已经是对一个批次内所有样本的平均损失计算了if params[0].grad is not None:  # 第一次训练迭代前是没有梯度产生的,梯度是由于反向传播才产生的for param in params:  # 参数梯度清零param.grad.data.zero_()l.backward()  # 反向传播d2l.sgd(params, lr, batch_size)  # 梯度下降操作train_l += l.item()# net(X)返回每个样本各个类别的预测值,有n个样本返回train_acc += (softmax(net(X)).argmax(dim=1) == y).float().sum().item()  # 累加预测正确个数n += y.shape[0]num += 1test_acc = evaluate_accuracy(test_iter)print(f'epoch %d, loss %.4f, train_acc %.3f, test_acc %.3f'% (epoch + 1, train_l / num, train_acc / n, test_acc))train()

http://www.ppmy.cn/news/1327210.html

相关文章

让uniapp小程序支持多色图标icon:iconfont-tools-cli

前景: uniapp开发小程序项目时,对于iconfont多色图标无法直接支持;若将多色icon下载引入项目则必须关注包体,若将图标放在oss或者哪里管理,加载又是一个问题,因此大多采用iconfont-tools工具,但…

李沐深度学习-d2lzh_pytorch模块实现

d2lzh_pytorch 模块 import random import torch import matplotlib_inline from matplotlib import pyplot as plt import torchvision import torchvision.transforms as transforms import torchvision.datasets import sys from collections import OrderedDict# --------…

【C/Python】用GTK实现多文档窗体程序

一、用C语言 在GTK(GIMP Toolkit)中实现多文档接口(MDI)程序可以使用多种方法。GTK本身并没有提供专用的MDI窗口小部件,但可以使用标签页(Notebook)或多个窗口(Window)来…

TCP缓存(C++)

系统为每个 socket 创建了发送缓冲区和接收缓冲区,应用程序调用 send()/write()函数发送数据的 时候,内核把数据从应用进程拷贝 socket 的发送缓冲区中;应用程序调用 recv()/read()函数接收数据的时候,内核把数据从 socket 的接收…

接口测试 02 -- JMeter入门到实战

前言 JM eter毕竟是做压测的工具,自动化这块还是有缺陷。 如果公司做一些简单的接口自动化,可以考虑使用JMeter快速完成,如果想做完善的接口自动化体系,建议还是基于Python来做。 为什么学习接口测试要先从JMeter开始?…

企业微信上传临时素材errcode:44001,errmsg:empty media data

企业微信,上传临时素材,报错: {“errcode”:44001,“errmsg”:“empty media data [logid:]”}, 开发语言C# 重点代码: formData.Headers.ContentType new MediaTypeHeaderValue(“application/octet-stream”); 解…

一台手机用4年多,国产手机从态度傲慢到跪求消费者换机

分析机构trendforce公布的数据指出,中国消费者的换机周期已延长到51个月,面对消费者对国产手机用脚投票,如今国产手机企业开始采取多方举措,祈求消费者买手机,市场的变化促使国产手机不得不改变态度。 2010年国产手机刚…

第五章 漏洞评估 - 《骇客修成秘籍》

第五章 漏洞评估 作者:Julian Paul Assange 简介 扫描和识别目标的漏洞通常被渗透测试者看做无聊的任务之一。但是,它也是最重要的任务之一。这也应该被当做为你的家庭作业。就像在学校那样,家庭作业和小测验的设计目的是让你熟练通过考试。…