李沐老师动手深度学习pytorch版本的读取fashion_mnist数据并用AlexNet模型训练,其中修改为利用本地的数据集训练

news/2024/10/19 23:24:36/

李沐老师的d2l.load_data_fashion_mnist里面没有root参数,所以只会下载,不能利用本地的fashion_mnist数据。所以我使用torchvision 的datasets里面FashionMNIST方法,又由于李沐老师此处是利用AlexNet模型来训练fashion_mnist数据,所以我们需要调整数据集的大小

导入必要的库和模块

import torch 
from torch import nn 
from d2l import torch as d2l
import numpy as np  
from torch.utils.data import Dataset, DataLoader  
import torchvision
import torchvision.transforms as transforms

转换数据

由于我们需要在加载数据同时定义数据转换,可以使用transforms.Compose来组合多个转换操作,使用Resize方法来调整图片大小,使其可以符合AlexNet的输入尺寸

 transform = transforms.Compose([  transforms.Resize((224, 224)),  # 将图片调整为224x224的大小  transforms.ToTensor(),  # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0., 1.]  # 可以选择添加transforms.Normalize进行进一步的标准化操作  
])

加载本地的fashion_mnist数据集

注意root参数,PyTorch 期望 FashionMNIST 数据集在指定的根目录下以特定的方式组织,确保下载好的文件存在于你指定的 root 路径下的 FashionMNIST/raw/ 文件夹中。可以参考我这篇博客:pytorch加载本地文件的root设置

通过train=False/True来设置训练集和测试集
通过设置download=True/False来确定找不到本地数据集的时候是否从网络下载
通过transform 指定特征和标签转换

#加载本地数据集,注意root参数
minist_train = torchvision.datasets.FashionMNIST(root='F:\\deeplearning\\fashion_mnist',train=True,download=False,transform=transform)
minist_test = torchvision.datasets.FashionMNIST(root='F:\\deeplearning\\fashion_mnist',train=False,download=False,transform=transform)print(type(minist_train))
print(len(minist_train),len(minist_test))

创建数据加载器,方便批次化处理

# 创建数据加载器  
batch_size = 128  
train_iter = DataLoader(minist_train, batch_size=batch_size, shuffle=True)  
test_iter = DataLoader(minist_test, batch_size=batch_size, shuffle=False) 

大致实现AlexNet网络架构

李沐老师在pytorch版本的动手深度学习中实现的模型

net = nn.Sequential(
# 这里使用一个11*11的更大窗口来捕捉对象。
# 同时,步幅为4,以减少输出的高度和宽度。
# 另外,输出通道的数目远大于LeNet
nn.Conv2d(1, 96, kernel_size=11, stride=4, padding=1), nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2),
# 减小卷积窗口,使用填充为2来使得输入与输出的高和宽一致,且增大输出通道数
nn.Conv2d(96, 256, kernel_size=5, padding=2), nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2),
# 使用三个连续的卷积层和较小的卷积窗口。
# 除了最后的卷积层,输出通道的数量进一步增加。
# 在前两个卷积层之后,汇聚层不用于减少输入的高度和宽度
nn.Conv2d(256, 384, kernel_size=3, padding=1), nn.ReLU(),
nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.ReLU(),
nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Flatten(),
# 这里,全连接层的输出数量是LeNet中的好几倍。使用dropout层来减轻过拟合
nn.Linear(6400, 4096), nn.ReLU(),
nn.Dropout(p=0.5),
nn.Linear(4096, 4096), nn.ReLU(),
nn.Dropout(p=0.5),
# 最后是输出层。由于这里使用Fashion-MNIST,所以用类别数为10,而非论文中的1000
nn.Linear(4096, 10))

设置学习率,epoch并在GPU上训练

lr, num_epochs = 0.01, 10
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

结果如下

在这里插入图片描述


http://www.ppmy.cn/news/1509398.html

相关文章

【区块链教程】如何使用自动化脚本创建小狐狸地址

在很多场景下,不管是撸毛也好,批量操作也好,都需要使用到大量的qianbao地址。 如何一键创建,也成为了很多人想学的技术。 创建地址的自动化脚本可以极大地简化在区块链开发和测试中的管理流程。以下是一个基本的流程和示例脚本&am…

msgqueue.hpp队列模块

目录 一.MsgQueue模块介绍 二.MsgQueue类的实现 成员变量 构造函数与析构函数 成员函数 参数设置函数 setArgs 参数获取函数 getArgs 三.MsgQueueMapper类的实现 成员变量 构造函数 成员函数 创建表格函数 createTable 删除表格函数 dropTable 插入数据函数 inse…

Stable Diffusion绘画 | 提示词基础原理

提示词之间使用英文逗号“,”分割 例如:1girl,black long hair, sitting in office 提示词之间允许换行 但换行时,记得在结尾添加英文逗号“,”来进行区分 权重默认为1,越靠前权重越高 每个提示词自身的权重默认值为1,但越靠…

数据预处理和探索性数据分析(上)

目录 数据预处理 数据清洗 处理缺失值: 异常值检测与处理: 类别特征编码: 特征工程 创建新特征: 特征缩放: 探索性数据分析 (EDA) 使用Matplotlib进行可视化 绘制直方图: 绘制箱线图&#xff1…

【网络编程】组播的实现(C语言,linux,Ubuntu)

组播 1> 组播也是实现一对多的通信方式,对于广播而言,网络需要对每个消息进行复制转发,会占用大量的带宽,导致网络拥塞 2> 组播可以实现小范围的数据传播:将需要接收数据的接收端加入多播组,发送端…

LeetCode 45. 跳跃游戏 II 题解

引言 在LeetCode的算法题库中,“跳跃游戏 II”是一个经典的贪心算法问题。这个问题不仅考验了我们对数组操作的理解,还锻炼了我们如何利用贪心策略来优化问题求解。本文将详细解析这个问题,并提供Java语言的解决方案。 问题描述 给定一个非…

Java语言程序设计基础篇_编程练习题16.22(播放、循环播放和停止播放一个音频剪辑)

题目:16.22(播放、循环播放和停止播放一个音频剪辑) 编写一个满足下面要求的程序: 使用AudioClip获取一个音频文件,该文件存放在类目录下。放置三个标记为Play、Loop和Stop的按钮,如图16-46a所示。单击Pla…

F.Enchanted

https://codeforces.com/gym/105139/problem/F24湖北省赛F 看了一下前面两种操作,做法不是很明显 后面两种操作,一看就是可持久化线段树,单点修改,版本复制 接下来解决前面的两种操作 第一个操作 两个相同的合成一个新的(33-&…