[人工智能]实现神经网络实例

news/2025/3/17 0:18:06/

  1. import numpy as np:导入 NumPy 库,用于数值计算。
  2. 导入 PyTorch 相关库:
    • import torch:导入 PyTorch 库,深度学习框架核心库。
    • from torchvision.datasets import mnist:从torchvision.datasets中导入 MNIST 数据集类,用于下载和加载 MNIST 数据。
    • import torchvision:导入torchvision库,提供了计算机视觉相关的数据集、模型和转换函数等。
    • import torchvision.transforms as transforms:导入数据转换模块,用于对数据进行预处理。
    • from torch.utils.data import DataLoader:导入数据加载器类,方便按批次加载数据。
    • import torch.nn.functional as F:导入 PyTorch 的函数式接口,包含常用的激活函数、损失函数等。
    • import torch.optim as optim:导入优化器模块,用于训练模型时更新参数。
    • import torch.nn as nn:导入神经网络模块,用于构建神经网络模型。
    • from torch.utils.tensorboard import SummaryWriter:导入SummaryWriter类,用于记录训练过程中的数据,方便在 TensorBoard 中可视化。

定义超参数

  1. train_batch_size = 64:定义训练数据的批次大小为 64,即每次训练使用 64 个样本。
  2. test_batch_size = 128:定义测试数据的批次大小为 128。
  3. learning_rate = 0.01:设置学习率为 0.01,控制模型参数更新的步长。
  4. num_epochs = 20:定义训练的轮数为 20,即整个训练数据集将被遍历 20 次。

定义数据预处理函数并加载数据

  1. 定义预处理函数
    • transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) 使用transforms.Compose组合多个数据转换操作。transforms.ToTensor()将数据转换为 PyTorch 的张量格式,transforms.Normalize([0.5], [0.5])对数据进行归一化处理,使数据的均值为 0,标准差为 0.5。
  2. 下载并加载数据
    • train_dataset = mnist.MNIST('../data/', train=True, transform=transform, download=True) 下载并加载训练数据集,数据存储在../data/目录下,设置train=True表示加载训练集,应用定义的预处理函数transform,如果数据不存在则自动下载。
    • test_dataset = mnist.MNIST('../data/', train=False, transform=transform) 加载测试数据集,train=False表示加载测试集,同样应用预处理函数。
  3. 创建数据加载器
    • train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True) 创建训练数据加载器,按照定义的训练批次大小train_batch_size加载数据,shuffle=True表示在每个 epoch 开始时打乱数据顺序。
    • test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False) 创建测试数据加载器,按照测试批次大小test_batch_size加载数据,测试数据一般不需要打乱顺序,所以shuffle=False。

  1. examples = enumerate(test_loader):使用enumerate函数对测试数据加载器test_loader进行包装,使其在迭代时可以同时返回索引和数据。
  2. batch_idx, (example_data, example_targets) = next(examples):通过next函数从examples中获取第一个批次的数据,batch_idx是批次索引,example_data是该批次的图像数据,example_targets是对应的真实标签。
  3. example_data.shape 和 Out[10]: torch.Size([128, 1, 28, 28]):查看第一个批次图像数据的形状,结果表明该批次有 128 个样本,每个样本是 1 通道、大小为 28x28 的图像。

  1. 导入库并设置
    • import matplotlib.pyplot as plt:导入matplotlib库的pyplot模块,用于绘制图像和可视化数据。
    • %matplotlib inline:这是 Jupyter Notebook 中的魔法命令,用于在 Notebook 中直接显示绘制的图像,而不是弹出新窗口。
  2. 获取测试数据
    • examples = enumerate(test_loader):对测试数据加载器test_loader进行枚举,方便获取每个批次的索引和数据。
    • batch_idx, (example_data, example_targets) = next(examples):获取测试数据的第一个批次,example_data包含图像数据,example_targets包含对应的真实标签。
  3. 绘制图像
    • fig = plt.figure():创建一个新的图形对象。
    • for i in range(6):循环 6 次,用于展示 6 张手写数字图像。
    • plt.subplot(2,3,i+1):创建一个 2 行 3 列的子图布局,并指定当前绘制的子图位置。
    • plt.tight_layout():自动调整子图间距,使图像布局更美观。
    • plt.imshow(example_data[i][0], cmap='gray', interpolation='none'):显示第i个样本的图像,cmap='gray'表示以灰度模式显示图像,interpolation='none'表示不进行插值处理。
    • plt.title('Ground Truth: {}'.format(example_targets[i])):设置图像的标题为该样本的真实标签。
    • plt.xticks([]) 和 plt.yticks([]):隐藏图像的横纵坐标轴刻度。

  1. 定义网络类:class Net(nn.Module) 定义了一个名为Net的类,继承自nn.Module,这是 PyTorch 中所有神经网络模型的基类。
  2. __init__方法
    • super(Net, self).__init__():调用父类的构造函数进行初始化。
    • self.flatten = nn.Flatten():定义一个Flatten层,用于将输入的多维图像数据展平为一维向量,方便后续全连接层处理。
    • self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1),nn.BatchNorm1d(n_hidden_1)):使用nn.Sequential将一个线性层nn.Linear和一个批归一化层nn.BatchNorm1d组合在一起,构成网络的第一层。线性层将输入维度in_dim映射到隐藏层维度n_hidden_1 ,批归一化层用于加速训练并提高模型的稳定性。
    • self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2),nn.BatchNorm1d(n_hidden_2)):同理,构成网络的第二层,将维度从n_hidden_1映射到n_hidden_2。
    • self.out = nn.Linear(n_hidden_2, out_dim):定义输出层,将第二层隐藏层的输出维度映射到输出维度out_dim,用于分类任务输出类别概率。
  3. forward方法
    • x=self.flatten(x):对输入x进行展平操作。
    • x = F.relu(self.layer1(x)):将展平后的数据通过第一层,并使用 ReLU 激活函数增加模型的非线性。
    • x = F.relu(self.layer2(x)):通过第二层并再次使用 ReLU 激活函数。
    • x = F.softmax(self.out(x),dim=1):将第二层的输出通过输出层,再使用 Softmax 激活函数将输出转换为概率分布,dim=1指定按行计算,即对每个样本的各个类别输出计算概率。

  1. 外层循环:for epoch in range(num_epochs) 遍历训练的轮数,num_epochs是之前定义的超参数,即整个训练数据集将被重复训练的次数。
  2. 设置训练模式和初始化训练损失
    • model.train():将模型设置为训练模式,此时像BatchNorm和Dropout等层会采用相应的训练行为。
    • train_loss = 0:初始化当前轮次的训练损失为 0。
  3. 动态修改学习率
    • if epoch%5 == 0: 当训练轮数是 5 的倍数时,执行以下操作。
    • optimizer.param_groups[0]['lr']*=0.9 将优化器的学习率乘以 0.9,实现动态调整学习率,随着训练进行逐渐减小学习率,有助于模型更好地收敛。
    • print('学习率:{:.6f}'.format(optimizer.param_groups[0]['lr'])) 打印当前调整后的学习率。
  4. 内层循环:for img, label in train_loader: 遍历训练数据加载器,每次获取一个批次的图像数据img和对应的标签label。
    • 数据移动到设备:img = img.to(device) 和 label = label.to(device) 将图像和标签数据移动到指定的设备(GPU 或 CPU)上。
    • 正向传播:out = model(img) 将图像数据传入模型进行前向传播,得到模型的输出out。
    • 计算损失:loss = criterion(out, label) 使用定义的损失函数criterion(交叉熵损失函数)计算模型输出和真实标签之间的损失值。
    • 反向传播和参数更新
      • optimizer.zero_grad() 在反向传播前将优化器中所有参数的梯度清零,避免梯度累加。
      • loss.backward() 执行反向传播,计算损失值对模型所有可学习参数的梯度。
      • optimizer.step() 根据计算得到的梯度更新模型的参数。
    • 记录损失
      • train_loss += loss.item() 累计当前批次的损失值到当前轮次的总损失中。
      • writer.add_scalar('train-loss', train_loss/len(train_loader), epoch) 将当前轮次的平均损失记录到SummaryWriter中,train_loss/len(train_loader)计算平均损失,epoch为当前轮次。
    • 计算分类准确率
      • _, pred = out.max(1) 从模型输出out中获取每个样本预测概率最高的类别索引pred。
      • num_correct = (pred == label).sum().item() 计算预测正确的样本数量。
      • acc = num_correct / img.shape[0] 计算当前批次的分类准确率。

  1. 初始化评估指标
    • eval_loss = 0:初始化测试集上的损失为 0。
    • eval_acc = 0:初始化测试集上的准确率为 0。
  2. 设置模型为评估模式:model.eval() 将模型设置为评估模式,此时像BatchNorm和Dropout等层会采用与训练模式不同的行为,以确保评估结果的准确性。
  3. 遍历测试数据
    • for img, label in test_loader: 循环遍历测试数据加载器test_loader,每次获取一个批次的图像数据img和标签label。
    • 数据处理和移动
      • img = img.to(device) 和 label = label.to(device) 将图像和标签数据移动到指定设备上。
      • img = img.view(img.size(0), -1) 将图像数据展平,以适应模型输入要求(这里假设模型输入需要一维向量形式)。
    • 正向传播和计算损失
      • out = model(img) 将图像数据传入模型进行前向传播,得到模型的输出out。
      • loss = criterion(out, label) 使用交叉熵损失函数计算模型输出和真实标签之间的损失值。
    • 记录评估损失:eval_loss += loss.item() 累计当前批次的损失值到测试集的总损失中。
    • 计算评估准确率
      • _, pred = out.max(1) 从模型输出out中获取每个样本预测概率最高的类别索引pred。
      • num_correct = (pred == label).sum().item() 计算预测正确的样本数量。
      • acc = num_correct / img.shape[0] 计算当前批次的准确率。
      • eval_acc += acc 累计当前批次的准确率。
  4. 记录评估结果
    • eval_losses.append(eval_loss / len(test_loader)) 将测试集的平均损失添加到eval_losses列表中。
    • eval_accs.append(eval_acc / len(test_loader)) 将测试集的平均准确率添加到eval_accs列表中。

  1. plt.title('train loss'):设置绘制图形的标题为 “train loss”,用于表明该图展示的是训练损失相关内容。
  2. plt.plot(np.arange(len(losses)), losses):使用plot函数绘制折线图。np.arange(len(losses))生成一个从 0 到losses列表长度减 1 的数组,作为横坐标,表示训练的轮次;losses列表存储了每一轮训练的平均损失值,作为纵坐标。这样就将训练轮次和对应的损失值进行了可视化展示。
  3. plt.legend(['Train Loss'], loc='upper right'):为图形添加图例,图例内容为 “Train Loss”,并将其放置在图形的右上角(loc='upper right'),方便读者识别曲线代表的含义。

http://www.ppmy.cn/news/1579692.html

相关文章

虚幻基础:动画层接口

文章目录 动画层:动画图表中的函数接口:名字,没有实现。动画层接口:由动画蓝图实现1.动画层可直接调用实现功能2.动画层接口必须安装3.动画层默认使用本身实现4.动画层也可使用其他动画蓝图实现,但必须在角色蓝图中关联…

蓝桥杯刷题——第十五届蓝桥杯大赛软件赛省赛C/C++ 大学 B 组

一、0握手问题 - 蓝桥云课 算法代码&#xff1a; #include <iostream> using namespace std; int main() {int sum0;for(int i49;i>7;i--)sumi;cout<<sum<<endl;return 0; } 直接暴力&#xff0c;题意很清晰&#xff0c;累加即可。 二、0小球反弹 - 蓝…

我的创作纪念日:730天的技术写作之旅

我的创作纪念日&#xff1a;730天的技术写作之旅 机缘 从一篇案例分析开始 2023年3月13日&#xff0c;我写下了第一篇技术博客《软考高级-系统分析师-案例分析-系统维护与设计模式》。那时的初心很简单&#xff1a; 沉淀实战经验——在备考软考系统分析师时&#xff0c;发现…

Docker-compose一键部署Zabbix监控平台

1. 环境准备 1.1 系统版本 [rootmonitor ~]# cat /etc/redhat-release CentOS Linux release 7.9.2009 (Core) [rootmonitor ~]# uname -a Linux monitor 3.10.0-1160.108.1.el7.x86_64 #1 SMP Thu Jan 25 16:17:31 UTC 2024 x86_64 x86_64 x86_64 GNU/Linux 1.2 Docker版本…

深搜专题11:分数字

描述 将整数N分成K个整数的和且每个数大于等于A小于等于B&#xff0c;求有多少种分法 注意&#xff1a;5 0 0 0 和 0 5 0 0被视为一种方法 输入描述 输入只有一行&#xff0c;分别输入N,K,A,B (所有数字均为不大于30的非负整数) 输出描述 输出只有一行&#xff0c;即多少种分法…

英语学习(GitHub学到的分享)

【英语语法&#xff1a;https://github.com/hzpt-inet-club/english-note】 【离谱的英语学习指南&#xff1a;https://github.com/byoungd/English-level-up-tips/tree/master】 【很喜欢文中的一句话&#xff1a;如果我轻轻松松的学习&#xff0c;生活的幸福指数会提高很多…

SSL 原理及实验

引言 为了实现远程办公或者远程客户访问内网的资源 &#xff08;1&#xff09;回顾历史&#xff1a; 起初先出现SSL(Secure Sockets Layer&#xff09;&#xff0d;安全套接层协议。 美国网景Netscape公司1994年研发&#xff0c;介于传输层TCP协议和应用层协议之间的一种协议…

【SpringMVC】常用注解:@ModelAttribute

1.作用 该注解是在SpringMVC4.3版本后新加入的。它可以修饰方法和参数。出现在方法上&#xff0c;表示当前方法会在控制器的方法之前执行。它可以修饰 没有返回值的方法&#xff0c;也可以修饰没有返回值的方法。它修饰参数&#xff0c;获取指定 的数据给参数赋值。 当表单提…