基于Python的人工智能应用案例系列(14):Fashion MNIST图像分类CNN

server/2024/10/19 3:25:01/

        在这一篇文章中,我们将使用PyTorch来实现卷积神经网络(CNN),对Fashion MNIST数据集进行图像分类任务。Fashion MNIST数据集是MNIST的升级版,包含各种服装、鞋类和配饰的灰度图像,非常适合作为深度学习的入门数据集。

数据集概述

        Fashion MNIST数据集由60,000张训练图像和10,000张测试图像组成,每张图像的尺寸为28x28像素,且为灰度图像。数据集共包含10个类别,分别为:

  1. T-shirt/top(T恤/上衣)
  2. Trouser(裤子)
  3. Pullover(套衫)
  4. Dress(连衣裙)
  5. Coat(外套)
  6. Sandal(凉鞋)
  7. Shirt(衬衫)
  8. Sneaker(运动鞋)
  9. Bag(包)
  10. Ankle boot(短靴)

        接下来,我们将从数据的加载开始,一步步构建并训练CNN模型。

1. ETL(提取、转换、加载)

        首先,我们通过torchvision中的datasets模块加载Fashion MNIST数据集,并使用DataLoader进行批量处理。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt# 数据转换
transform = transforms.ToTensor()# 加载训练和测试数据集
train_data = datasets.FashionMNIST(root='../Data', train=True, download=True, transform=transform)
test_data = datasets.FashionMNIST(root='../Data', train=False, download=True, transform=transform)# 定义类别名称
class_names = ['T-shirt', 'Trouser', 'Sweater', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Boot']# 使用DataLoader加载数据
train_loader = DataLoader(train_data, batch_size=10, shuffle=True)
test_loader = DataLoader(test_data, batch_size=10, shuffle=False)

2. EDA(探索性数据分析)

        在训练模型之前,我们可以通过批量可视化数据来对数据集有一个直观的认识。

for images, labels in train_loader: breakprint('标签: ', labels.numpy())
print('类别: ', *[class_names[i] for i in labels])# 显示图像
im = make_grid(images, nrow=10)
plt.figure(figsize=(12,4))
plt.imshow(np.transpose(im.numpy(), (1, 2, 0)))

3. 模型训练

        接下来我们定义一个卷积神经网络模型,该模型包含两个卷积层、两个池化层和两个全连接层。

class ConvolutionalNetwork(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(1, 6, 3, 1)  # 输入通道数为1(灰度图),输出通道数为6,卷积核大小为3x3self.conv2 = nn.Conv2d(6, 16, 3, 1)  # 第二个卷积层,输出通道数为16self.fc1 = nn.Linear(5*5*16, 100)  # 全连接层,输入尺寸根据卷积结果计算self.fc2 = nn.Linear(100, 10)  # 输出层,10个类别def forward(self, X):X = F.relu(self.conv1(X))X = F.max_pool2d(X, 2, 2)X = F.relu(self.conv2(X))X = F.max_pool2d(X, 2, 2)X = X.view(-1, 5*5*16)X = F.relu(self.fc1(X))X = self.fc2(X)return Xtorch.manual_seed(101)
model = ConvolutionalNetwork()

模型参数统计

        我们可以通过函数统计模型的可训练参数数量。

def count_parameters(model):params = [p.numel() for p in model.parameters() if p.requires_grad]for item in params:print(f'{item:>6}')print(f'总参数数量: {sum(params):>6}')count_parameters(model)

定义损失函数与优化器

        我们使用交叉熵损失函数(CrossEntropyLoss)和Adam优化器。

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

模型训练

        接下来我们开始训练模型,设定训练5个epoch。

epochs = 5for i in range(epochs):for X_train, y_train in train_loader:# 应用模型y_pred = model(X_train)loss = criterion(y_pred, y_train)# 更新参数optimizer.zero_grad()loss.backward()optimizer.step()print(f'{i+1} / {epochs} 轮训练完成')

4. 模型测试

        训练完成后,我们对测试数据进行推理并计算准确率。

model.eval()with torch.no_grad():correct = 0for X_test, y_test in test_loader:y_val = model(X_test)predicted = torch.max(y_val,1)[1]correct += (predicted == y_test).sum()print(f'测试准确率: {correct.item()}/{len(test_data)} = {correct.item()*100/(len(test_data)):.3f}%')

结语

        在本文中,我们使用PyTorch搭建了一个简单的卷积神经网络(CNN)来处理Fashion MNIST图像分类任务。通过应用CNN模型,我们能够成功地对服装图像进行分类,并取得了不错的准确率。尽管这是一个入门级别的模型,且有提升空间,但它展示了CNN在图像分类任务中的强大性能。

        接下来,读者可以尝试调整模型结构,增加或减少卷积层、全连接层的数量,或尝试使用不同的优化器和学习率,看看这些变化如何影响模型的准确性。此外,使用更复杂的深度学习技术(如迁移学习或更深的网络)也能进一步提升模型的表现。

        通过这个案例,相信大家对卷积神经网络的基本结构和应用有了更直观的了解。接下来的篇章中,我们将继续探索更多深度学习应用,带领大家深入了解人工智能的实际应用场景。敬请期待!

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!


http://www.ppmy.cn/server/124428.html

相关文章

Java项目实战II基于Java+Spring Boot+MySQL的新闻稿件管理系统(源码+数据库+文档)

目录 一、前言 二、技术介绍 三、系统实现 四、文档参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发,CSDN平台Java领域新星创作者,专注于大学生项目实战开发、讲解和毕业答疑辅导。获取源码联系方式请查看文末 一、前言 前在信息爆…

mysql数据库设置主从同步

mysql数据库设置主从同步 环境 mysql主库版本MySQL-5.6.40-2.sles12.x86_64 mysql从库版本mysql-5.7.21-linux-glibc2.12-x86_64 一、主库配置 修改主库my.cnf配置 [mysqld] #server_id 1 #唯一标识,主库从库不能重复 #log_bin mysql-bin …

jdk1.6版本发送HTTPS请求,报错Could not generate DH keypair问题解决

Could not generate DH keypair问题 这个问题一般出现在因为jdk版本过低,而接收请求的服务器设置接收的加密算法不持支这个从而导致的,解决方式有多个: 直接了当更新jdk版本,更新到服务器所支持的jdk版本很多时候,更新jdk版本会…

COSCon'24 第九届中国开源年会议题征集正式启动

一年一度的开源盛会,COSCon24 第九届中国开源年会暨开源社十周年嘉年华将于2024年11月2-3日在中关村国家自主创新示范区会议中心举办。在为期2天的大会中,我们将为大家带来精彩纷呈的 Keynote 主题演讲(上午),和百花齐…

Java项目实战II基于Java+Spring Boot+MySQL的植物健康系统(开发文档+源码+数据库)

目录 目录 一、前言 二、技术介绍 三、系统实现 四、文档参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发,CSDN平台Java领域新星创作者,专注于大学生项目实战开发、讲解和毕业答疑辅导。获取源码联系方式请查看文末 一、前言 随着…

MacOS上安装MiniConda的详细步骤

前言 MiniConda是一种环境配置工具。在不同的开发项目中,我们会使用到不同版本的Python和第三方库(例如Numpy、Pandas)。如果不使用环境配置工具,每次开发都需要清除电脑里上一次开发的环境和配置文件。为了在同一台机器上同时开发多个项目&…

【C语言从不挂科到高绩点】23-指针05-结构体指针【重点知识】

Hello!彦祖们,俺又回来了!!!,继续给大家分享 《C语言从不挂科到高绩点》课程!! 本节将为大家讲解C语言中非常重要的知识点-指针: 本套课程将会从0基础讲解C语言核心技术,适合人群: 大学中开设了C语言课程的同学想要专升本或者考研的同学想要考计算机等级证书的同学想…

SpinalHDL之结构(七)

本文作为SpinalHDL学习笔记第六十七篇,介绍SpinalHDL的保留名称(Preserving names)。 目录: 1.简介(Introduction) 2.可命名的基础类(Nameable base class) 3.从Scala中提取名字(Name extraction from Scala) 4.模块中的区域(Area in a Component) 5.函数中的区域(Area …