CNN-day5-经典神经网络LeNets5

devtools/2025/2/11 7:53:53/

经典神经网络-LeNets5

1998年Yann LeCun等提出的第一个用于手写数字识别问题并产生实际商业(邮政行业)价值的卷积神经网络

参考:论文笔记:Gradient-Based Learning Applied to Document Recognition-CSDN博客

1 网络模型结构

整体结构解读:

输入图像:32×32×1

三个卷积层:

C1:输入图片32×32,6个5×5卷积核 ,输出特征图大小28×28(32-5+1)=28,一个bias参数;

可训练参数一共有:(5×5+1)×6=156

C3 :输入图片14×14,16个5×5卷积核,有6×3+6×4+3×4+1×6=60个通道,输出特征图大小10×10((14-5)/1+1),一个bias参数;

可训练参数一共有:6(3×5×5+1)+6×(4×5×5+1)+3×(4×5×5+1)+1×(6×5×5+1)=1516

C3的非密集的特征图连接:

C3的前6个特征图与S2层相连的3个特征图相连接,后面6个特征图与S2层相连的4个特征图相连 接,后面3个特征图与S2层部分不相连的4个特征图相连接,最后一个与S2层的所有特征图相连。 采用非密集连接的方式,打破对称性,同时减少计算量,共60组卷积核。主要是为了节省算力。

C5:输入图片5×5,16个5×5卷积核,包括120×16个5×5卷积核 ,输出特征图大小1×1(5-5+1),一个bias参数;

可训练参数一共有:120×(16×5×5+1)=48120

两个池化层S2和S4:

都是2×2的平均池化,并添加了非线性映射

S2(下采样层):输入28×28,采样区域2×2,输入相加,乘以一个可训练参数, 再加上一个可训练偏置,使用sigmoid激活,输出特征图大小:14×14(28/2)

S4(下采样层):输入10×10,采样区域2×2,输入相加,乘以一个可训练参数, 再加上一个可训练偏置,使用sigmoid激活,输出特征图大小:5×5(10/2)

两个全连接层:

第一个全连接层:输入120维向量,输出84个神经元,计算输入向量和权重向量之间的点积,再加上一个偏置,结果通过sigmoid函数输出。84的原因是:字符编码是ASCII编码,用7×12大小的位图表示,-1白色1黑色,84可以用于对每一个像素点的值进行估计。

第二个全连接层(Output层-输出层):输出 10个神经元 ,共有10个节点,代表数字0-9。

所有激活函数采用Sigmoid

2 网络模型实现

2.1模型定义

import torch
import torch.nn as nn
​
​
class LeNet5s(nn.Module):def __init__(self):super(LeNet5s, self).__init__()  # 继承父类# 第一个卷积层self.C1 = nn.Sequential(nn.Conv2d(in_channels=1,  # 输入通道out_channels=6,  # 输出通道kernel_size=5,  # 卷积核大小),nn.ReLU(),)# 池化:平均池化self.S2 = nn.AvgPool2d(kernel_size=2)
​# C3:3通道特征融合单元self.C3_unit_6x3 = nn.Conv2d(in_channels=3,out_channels=1,kernel_size=5,)# C3:4通道特征融合单元self.C3_unit_6x4 = nn.Conv2d(in_channels=4,out_channels=1,kernel_size=5,)
​# C3:4通道特征融合单元,剔除中间的1通道self.C3_unit_3x4_pop1 = nn.Conv2d(in_channels=4,out_channels=1,kernel_size=5,)
​# C3:6通道特征融合单元self.C3_unit_1x6 = nn.Conv2d(in_channels=6,out_channels=1,kernel_size=5,)
​# S4:池化self.S4 = nn.AvgPool2d(kernel_size=2)# 全连接层self.fc1 = nn.Sequential(nn.Linear(in_features=16 * 5 * 5, out_features=120), nn.ReLU())self.fc2 = nn.Sequential(nn.Linear(in_features=120, out_features=84), nn.ReLU())self.fc3 = nn.Linear(in_features=84, out_features=10)
​def forward(self, x):# 训练数据批次大小batch_sizenum = x.shape[0]
​x = self.C1(x)x = self.S2(x)# 生成一个empty张量outchannel = torch.empty((num, 0, 10, 10))# 6个3通道的单元for i in range(6):# 定义一个元组:存储要提取的通道特征的下标channel_idx = tuple([j % 6 for j in range(i, i + 3)])x1 = self.C3_unit_6x3(x[:, channel_idx, :, :])outchannel = torch.cat([outchannel, x1], dim=1)
​# 6个4通道的单元for i in range(6):# 定义一个元组:存储要提取的通道特征的下标channel_idx = tuple([j % 6 for j in range(i, i + 4)])x1 = self.C3_unit_6x4(x[:, channel_idx, :, :])outchannel = torch.cat([outchannel, x1], dim=1)
​# 3个4通道的单元,先拿五个,干掉中那一个for i in range(3):# 定义一个元组:存储要提取的通道特征的下标channel_idx = tuple([j % 6 for j in range(i, i + 5)])# 删除第三个元素channel_idx = channel_idx[:2] + channel_idx[3:]print(channel_idx)x1 = self.C3_unit_3x4_pop1(x[:, channel_idx, :, :])outchannel = torch.cat([outchannel, x1], dim=1)
​x1 = self.C3_unit_1x6(x)# 平均池化outchannel = torch.cat([outchannel, x1], dim=1)outchannel = nn.ReLU()(outchannel)
​x = self.S4(outchannel)# 对数据进行变形x = x.view(x.size(0), -1)# 全连接层x = self.fc1(x)x = self.fc2(x)# TODO:SOFTMAXoutput = self.fc3(x)
​return output
​
​
def test001():net = LeNet5s()# 随机一个测试数据input = torch.randn(128, 1, 32, 32)output = net(input)print(output.shape)pass
​
​
if __name__ == "__main__":test001()

2.2全局变量

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import os
​
dir = os.path.dirname(__file__)
modelpath = os.path.join(dir, "weight/model.pth")
datapath = os.path.join(dir, "data")
​
# 数据预处理和加载
transform = transforms.Compose([transforms.Resize((32, 32)),  # 调整输入图像大小为32x32transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,)),]
)
​

2.3模型训练

def train():
​trainset = torchvision.datasets.MNIST(root=datapath, train=True, download=True, transform=transform)trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)
​# 实例化模型net = LeNet5()
​# 使用MSELoss作为损失函数criterion = nn.MSELoss()
​# 使用SGD优化器optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
​# 训练模型num_epochs = 10for epoch in range(num_epochs):running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = data
​# 将labels转换为one-hot编码labels_one_hot = torch.zeros(labels.size(0), 10).scatter_(1, labels.view(-1, 1), 1.0)labels_one_hot = labels_one_hot.to(torch.float32)optimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels_one_hot)loss.backward()optimizer.step()
​running_loss += loss.item()if i % 100 == 99:print(f"[{epoch + 1}, {i + 1}] loss: {running_loss / 100:.3f}")running_loss = 0.0# 保存模型参数torch.save(net.state_dict(), modelpath)print("Finished Training")

2.4验证

def vaild():
​testset = torchvision.datasets.MNIST(root=datapath, train=False, download=True, transform=transform)testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)# 实例化模型net = LeNet5()net.load_state_dict(torch.load(modelpath))# 在测试集上测试模型correct = 0total = 0with torch.no_grad():for data in testloader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()
​print(f"验证集: {100 * correct / total:.2f}%")

http://www.ppmy.cn/devtools/157869.html

相关文章

Django框架丨从零开始的Django入门学习

Django 是一个用于构建 Web 应用程序的高级 Python Web 框架,Django是一个高度模块化的框架,使用 Django,只要很少的代码,Python 的程序开发人员就可以轻松地完成一个正式网站所需要的大部分内容,并进一步开发出全功能…

本地大模型编程实战(11)与外部工具交互(2)

文章目录 准备定义工具方法创建提示词生成工具方法实参以 json 格式返回实参自定义 JsonOutputParser返回 json 调用工具方法定义通用方法用 链 返回结果返回结果中包含工具输入 总结代码 在使用 LLM(大语言模型) 时,经常需要调用一些自定义的工具方法完成特定的任务…

【蓝耕元生代智算云平台】一键部署 DeepSeek人工智能模型

欢迎来到ZyyOvO的博客✨,一个关于探索技术的角落,记录学习的点滴📖,分享实用的技巧🛠️,偶尔还有一些奇思妙想💡 本文由ZyyOvO原创✍️,感谢支持❤️!请尊重原创&#x1…

Http和Socks的区别?

HTTP 和 SOCKS 的区别 HTTP 和 SOCKS 都是用于网络通信的协议,但它们在工作原理、应用场景和实现方式上有显著的区别。以下是详细的对比和说明。 一、HTTP 协议 1. 定义 HTTP(HyperText Transfer Protocol)是用于传输超文本数据的应用层协…

MySQL——CRUD

一、Create新增数据行 语法: 其中,column表示要插入的列的列名,value_list表示列名对应的数据,且可以同时插入多行 一、单行数据 全列插入 全列插入可以不在表名student后面指定列名,也就是说,如果没有指定…

如何清理浏览器一段时间以前的缓存

浏览器缓存是浏览器为了提高网页加载速度而自动存储的数据,包括图片、脚本、样式表等文件。然而,随着时间的推移,这些缓存数据可能会占用大量磁盘空间,甚至影响浏览器的性能和稳定性。因此,定期清理浏览器缓存是一个良…

《IP-Adapter: 适用于文本到图像扩散模型的文本兼容图像提示适配器》学习笔记

paper:2308.06721 GitHub:tencent-ailab/IP-Adapter: The image prompt adapter is designed to enable a pretrained text-to-image diffusion model to generate images with image prompt. 目录 摘要 1、介绍 2、相关工作 2.1 文本到图像扩散模型 2.2 大模…

【vscode源码】如何编译运行vscode及过程中问题解决

Visual Studio Code(VSCode)作为一款流行的开源编辑器,市面上很多基于vscode的套壳APP,本文将详细介绍如何编译和运行VSCode的源码,并总结一些常见问题以及解决方案,帮助开发者顺利二次开发。 1. 准备工作(…