【深度学习】经典的深度学习模型-01 开山之作:CNN卷积神经网络LeNet-5

server/2024/10/19 2:39:45/

深度学习】经典的深度学习模型-01 开山之作:CNN卷积神经网络LeNet-5

Note: 草稿状态,持续更新中,如果有感兴趣,欢迎关注。。。

0. 论文信息

@article{lecun1998gradient,
title={Gradient-based learning applied to document recognition},
author={LeCun, Yann and Bottou, L{'e}on and Bengio, Yoshua and Haffner, Patrick},
journal={Proceedings of the IEEE},
volume={86},
number={11},
pages={2278–2324},
year={1998},
publisher={Ieee}
}

基于梯度的学习在文档识别中的应用
在这里插入图片描述
LeNet-5 是一个经典的卷积神经网络(CNN)架构,由 Yann LeCun 等人在 1998 年提出,主要用于手写数字识别任务,特别是在 MNIST 数据集上。
在这里插入图片描述
LeNet-5 的设计对后来的卷积神经网络研究产生了深远影响,该模型具有以下几个特点:

  1. 卷积层:LeNet-5 包含多个卷积层,每个卷积层后面通常会跟一个池化层(Pooling Layer),用于提取图像特征并降低特征图的空间维度。

  2. 池化层:在卷积层之后,LeNet-5 使用池化层来降低特征图的空间分辨率,减少计算量,并增加模型的抽象能力。

  3. 全连接层:在卷积和池化层之后,LeNet-5 包含几个全连接层,用于学习特征之间的复杂关系。

  4. 激活函数:LeNet-5 使用了 Sigmoid 激活函数,这是一种早期的非线性激活函数,用于引入非线性,使得网络可以学习复杂的模式。

  5. Dropout:尽管原始的 LeNet-5 并没有使用 Dropout,但后来的研究者在改进模型时加入了 Dropout 技术,以减少过拟合。

  6. 输出层:LeNet-5 的输出层通常使用 Softmax 激活函数,用于进行多分类任务,输出每个类别的概率。

虽然站在2024年看LeNet-5 的模型结构相对简单,但是时间回拨到1998年,彼时SVM这类算法为主的时代,LeNet-5的出现,不仅证明了卷积神经网络在图像识别任务中的有效性,而且为后续深度神经网络研究的发展带来重要启迪作用,使得我们有幸看到诸如 AlexNet、VGGNet、ResNet 等模型的不断推成出新。

2. 论文摘要

3. 研究背景

4. 算法模型

5. 实验效果

6. 代码实现

以MNIST手写字图像识别问题为例子,采用LeNet5模型进行分类,代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transformsdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")# Define the LeNet-5 model
class LeNet5(nn.Module):def __init__(self):super(LeNet5, self).__init__()self.conv1 = nn.Conv2d(1, 6, 5)  # 1 input image channel, 6 output channels, 5x5 kernelself.pool = nn.MaxPool2d(2, 2)  # pool with window 2x2, stride 2self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 4 * 4, 120)  # 16*4*4 = 256self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 4 * 4)  # flatten the tensorx = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# # Initialize the network
# net = LeNet5()# Initialize the network on GPU
net = LeNet5().to(device)# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)# Data loading
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1000, shuffle=False)# Train the network
for epoch in range(10):  # loop over the dataset multiple timesrunning_loss = 0.0for i, data in enumerate(train_loader, 0):# for cpu# inputs, labels = data# for gpuinputs, labels = data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 2000 == 1999:  # print every 2000 mini-batchesprint(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 2000:.3f}')running_loss = 0.0print('Finished Training')# Test the network on the test data
correct = 0
total = 0
with torch.no_grad():for data in test_loader:# # for cpu# images, labels = data# for gpuimages, labels = data[0].to(device), data[1].to(device)outputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy of the network on the 10000 test images: {100 * correct / total:.2f}%')

注意:这里使用GPU做简单加速。如果没有GPU,可以关闭对应代码,替换为相应的CPU代码即可。
程序运行后结果如下:
在这里插入图片描述
可以看到,在测试数据上的准确率为98.33%!

7. 问题及优化


http://www.ppmy.cn/server/131139.html

相关文章

已经30岁了,想转行从头开始现实吗?什么样的工作算好工作?

我是29岁那年,完成从转行裸辞副业的职业转型。 如果你把职业生涯看成是从现在开始30岁,到你退休那年,中间这么漫长的30年,那么30岁转行完全来得及; 如果你觉得必须在什么年纪,什么时间内必须完成赚到几十…

Facebook减肥产品广告投放攻略

有不少刚开始投放facebook广告的小伙伴会感到疑惑,为什么别人的减肥产品跑的风生水起,销量羡煞旁人,自己的广告要不就是被拒要不就是没有流量,甚至还可能被封号,如果你也有这样的困扰,那一定要看完这篇文章…

开局一个登录框,密码重置全靠翻

一、开局获取登录框 挑选一个目标,直接来到它的统一信息门户 可以看到,框里直接提示默认用户名和密码,这不得来全部费功夫,只要找到个学号和身份证就能进到里面去特 二、震惊!某校竟在公网放出学生这种信息 还是直接…

垂直领域的大模型应该如何构建?RAG还是微调呢?

垂直领域的大模型应该如何构建?RAG还是微调呢? 垂直领域的大模型应该是2024年乃至未来五年内人工智能发展的热门所在。那么该如何构建?是RAG(Retrieval Augmentation Generation,检索增强生成)还是微调&am…

选对软件,音乐剪辑事半功倍!2024年Top 4免费神器推

咱们现在可是生活在一个数字时代,搞音乐创作不用非得去那些贵的要命的录音棚,也不用一大堆复杂的设备了。科技这么发达,各种音频编辑软件多得是,就像雨后春笋一样,给喜欢音乐的人和专业人士带来了超级方便的创作工具。…

苍穹外卖学习笔记(二十)

文章目录 用户端历史订单模块:查询历史订单OrderControllerOrderServiceOrderServiceImpl 查询订单详情OrderControllerOrderServiceOrderServiceImpl 用户端历史订单模块: 查询历史订单 OrderController /*** 历史订单*/GetMapping("/historyOrd…

AI工具 | Notion全新AI集成:搜索、内容生成、数据分析与智能聊天功能发布

新的 Notion AI 集成了搜索、生成内容、分析数据和智能聊天等功能,所有操作都可以在 Notion 内完成。依托于 GPT-4 和 Claude 等先进的 AI 模型,用户可以与 AI 聊天并获取针对各种话题的答案。 随时使用 在 Notion 页面右下角找到 AI 图标,点…

Python爬虫-电影天堂数据阅览

爬取电影的名称与下载地址&#xff1a; import re import requestsdomain "https://www.dytt89.com/" response requests.get(domain,verifyFalse) response.encoding gb2312 # print(response.text)object1 re.compile(r"2024必看热片.*?<ul>(?P&…