深度学习:完整的模型训练流程

news/2024/11/29 17:10:11/

深度学习:完整的模型训练流程

为了确保我们提供一个彻底和清晰的指导,让我们深入分析在model.pytrain.py文件中定义的模型训练和验证流程。以下部分将详细讨论模型结构的定义、数据的加载与预处理、训练参数的配置、训练与测试循环,以及模型的保存和性能监控。此外,我们将通过具体代码示例,详尽解释每个环节的执行逻辑和目的,从而确保您能够有效地理解并应用这些步骤以优化您的深度学习项目。

1. 模型结构定义(model.py

目的和结构

model.py中定义的My_Network类基于PyTorch的torch.nn.Module。这个类的设计使模型能够集成并利用PyTorch提供的多种深度学习功能,从而实现有效的图像分类。

class My_Network(nn.Module):def __init__(self):super(My_Network, self).__init__()# 创建序列模型self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),  # 第一层卷积层nn.ReLU(),                  # 激活函数nn.MaxPool2d(2),            # 池化层nn.Conv2d(32, 32, 5, 1, 2), # 第二层卷积层nn.ReLU(),                  nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2), # 第三层卷积层nn.ReLU(),nn.MaxPool2d(2),nn.Flatten(),               # 展平层nn.Linear(64*4*4, 64),      # 全连接层nn.ReLU(),nn.Linear(64, 10)           # 输出层)def forward(self, x):return self.model(x)
注释解释
  • 卷积层:使用3通道输入,利用5x5的卷积核来提取特征,步长为1,填充为2以保持图像尺寸。
  • ReLU激活函数:引入非线性,帮助网络学习复杂的模式。
  • 池化层:使用2x2的窗口减小特征维度,同时保留重要的特征信息,减少计算量并防止过拟合。
  • 展平层与全连接层:将二维特征图转换为一维特征向量,通过全连接层进行分类。

2. 数据加载与预处理

实现细节

使用PyTorch的torchvision.datasets库来加载和预处理CIFAR-10数据集。

train_data = torchvision.datasets.CIFAR10(root="../dataset", train=True, transform=torchvision.transforms.ToTensor(), download=True)
test_data = torchvision.datasets.CIFAR10(root="../dataset", train=False, transform=torchvision.transforms.ToTensor(), download=True)
预处理功能
  • ToTensor():将图片数据转换为Tensor,并归一化到[0,1]范围。
  • Normalize():通常在此步骤中加入标准化处理,但在这里简化为基本的Tensor转换。

3. 训练和测试循环

配置和流程

设置训练环境,包括数据加载器、损失函数和优化器,定义训练和测试过程。

# 设置数据加载器
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)# 定义损失函数和优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(my_network.parameters(), lr=0.01)
训练和测试代码
for epoch in range(10):  # 进行10个训练周期my_network.train()  # 设置为训练模式for imgs, targets in train_loader:outputs = my_network(imgs)loss = loss_fn(outputs, targets)optimizer.zero_grad()loss.backward()optimizer.step()# 测试模式my_network.eval()total_accuracy = 0with torch.no_grad():for imgs, targets in test_loader:outputs = my_network(imgs)total_accuracy += (outputs.argmax(1) == targets).sum().item()print(f"Epoch {epoch+1}: Accuracy = {total_accuracy / len(test_data)}")

4. 模型保存和日志记录

保存策略和性能监控

在每个训练周期后保存模型的状态,并使用TensorBoard来记录关键的训练和测试指标。

torch.save(my_network.state_dict(), f"my_network_epoch_{epoch}.pth")
writer.add_scalar("Loss/train", loss.item(), epoch)
writer.add_scalar("Accuracy/test", total_accuracy / len(test_data), epoch)
解释
  • 模型保存:定期保存训练后的模型状态,以便进行未来的训练或评估。
  • 性能监控:使用TensorBoard记录训练损失和测试准确率,帮助监控模型的学习进度和性能。

通过这种详尽的分析,每个步骤的实现和逻辑都被清晰地展示和解释,确保模型训练和验证的每个关键考量都被适当地处理。这为深入理解和优化深度学习模型提供了坚实的基础。


http://www.ppmy.cn/news/1550949.html

相关文章

如何通过PHP爬虫模拟表单提交,抓取隐藏数据

引言 在网络爬虫技术中,模拟表单提交是一项常见的任务,特别是对于需要动态请求才能获取的隐藏数据。在电商双十一、双十二等促销活动期间,商品信息的实时获取尤为重要,特别是针对不断变化的价格和库存动态。为了满足这种需求&…

泷羽sec-蓝队基础之网络七层杀伤链(上) 学习笔记

声明! 学习视频来自B站up主 **泷羽sec** 有兴趣的师傅可以关注一下,如涉及侵权马上删除文章,笔记只是方便各位师傅的学习和探讨,文章所提到的网站以及内容,只做学习交流,其他均与本人以及泷羽sec团队无关&a…

界面控件DevExpress Blazor UI v24.1亮点:全新的渲染引擎和项目模板等

DevExpress Blazor UI组件使用了C#为Blazor Server和Blazor WebAssembly创建高影响力的用户体验,这个UI自建库提供了一套全面的原生Blazor UI组件(包括Pivot Grid、调度程序、图表、数据编辑器和报表等)。 DevExpress Blazor控件目前已经升级…

算力100问☞第28问:智算中心的软件基础设施有哪些?

1、智算操作系统 作为智算中心的核心软件,智算操作系统负责对计算、存储、网络等硬件资源进行统一管理和调度,实现资源的灵活分配与高效利用。例如,九章云极 DataCanvas 的 Alaya NeW 智算操作系统,能够纳管智算资源、输出智算服…

nginx动静分离和rewrite重写和https和keepalived

动静分离,通过中间件将动态请求和静态请求分离,可以减少不必要的消耗,同时减少请求延迟 动静分离只有好处:动静分离后,即使动态资源不可用,但静态资源不受影响单台实现动静分离 1.部署java yum install ja…

电池建模 003- Behavioral battery mode行为电池模型入门学习

1、概要 库文件位置: Simscape / Battery / Cells 行为电池模型 电池块表示一个简单的电池模型。您可以选择暴露充电输出端口和电池的热端口。 要测量电池的内部电荷水平,在主菜单中,将“暴露充电测量端口”设置为“是”。此操作会暴露一个额外的物理信…

【Linux课程学习】:《简易版shell实现和原理》 《哪些命令可以让子进程执行,哪些命令让shell执行(内键命令)?为什么?》

🎁个人主页:我们的五年 🔍系列专栏:Linux课程学习 🌷追光的人,终会万丈光芒 🎉欢迎大家点赞👍评论📝收藏⭐文章 目录 打印命令行提示符(PrintCommandLin…

Apache-maven在Windows中的安装配置及Eclipse中的使用

Apache Maven 是一个自动化项目管理工具,用于构建,报告和文档的项目管理工具。以下是在不同操作系统上安装和配置 Maven 的基本步骤: 安装 Maven 下载 Maven: apache-maven-3.9.9下载地址,也可访问 Apache Maven 官方网站 下载最…