cnn以及例子

news/2025/2/16 5:52:03/

cnn_0">cnn

CNN 即卷积神经网络(Convolutional Neural Network),是一种专门为处理具有网格结构数据(如图像、音频)而设计的深度学习模型,在计算机视觉、语音识别等诸多领域都有广泛应用。以下是 CNN 的详细介绍:
基本原理
卷积层:是 CNN 的核心组成部分,通过卷积核在数据上滑动进行卷积操作,自动提取数据中的局部特征。例如,在处理图像时,卷积核可以检测图像中的边缘、线条等简单特征。卷积操作大大减少了模型的参数数量,降低计算量,同时也能有效地捕捉数据的空间相关性。
池化层:通常接在卷积层之后,主要作用是对数据进行下采样,减少数据的维度,降低计算量,同时保留数据的主要特征。常见的池化方法有最大池化和平均池化。最大池化是取卷积核覆盖区域内的最大值,平均池化则是取平均值。
全连接层:一般位于 CNN 的最后部分,将池化层输出的特征图展开成一维向量,然后与全连接的神经元进行连接,用于对提取到的特征进行综合判断和分类。全连接层的神经元之间是完全连接的,其权重矩阵包含了模型对数据整体特征的学习结果。
网络结构
输入层:负责接收原始数据,对于图像数据,通常是一个三维的张量,维度分别代表图像的高度、宽度和通道数(如 RGB 图像通道数为 3)。
隐藏层:包含多个卷积层、池化层和可能的其他类型的层(如批量归一化层、激活函数层等)。这些层通过不断地卷积、池化等操作,逐步提取数据的高级特征。
输出层:根据具体任务的不同,输出层的形式也有所不同。在图像分类任务中,输出层通常是一个具有 softmax 激活函数的全连接层,用于输出各类别的概率分布;在目标检测任务中,输出层可能包含边界框的坐标信息和类别信息等。

反向传播算法

误差计算:在训练过程中,首先计算模型输出与真实标签之间的误差,常用的损失函数有交叉熵损失函数、均方误差损失函数等。以交叉熵损失函数为例,它衡量的是模型预测的概率分布与真实标签的概率分布之间的差异。
误差反向传播:将误差从输出层反向传播到输入层,通过链式求导法则计算每一层的参数梯度。在卷积层中,需要计算卷积核的梯度;在全连接层中,需要计算权重矩阵和偏置项的梯度。
参数更新:根据计算得到的梯度,使用优化算法(如随机梯度下降、Adagrad、Adadelta、Adam 等)来更新模型的参数,使得损失函数逐渐减小,模型的性能不断提升。

应用领域

计算机视觉:在图像分类、目标检测、图像分割、人脸识别等任务中取得了巨大成功。例如,在图像分类中,CNN 可以准确地识别出图像中的物体类别;在目标检测中,能够定位并识别图像中的多个目标物体。
语音识别:用于对语音信号进行特征提取和分类,将语音转换为文字。CNN 可以有效地捕捉语音信号中的声学特征,提高语音识别的准确率。
自然语言处理:在文本分类、情感分析、机器翻译等任务中也有应用。通过将文本数据转换为向量表示,然后利用 CNN 提取文本中的局部特征,进行语义理解和分类。

识别猫狗的例子

# 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt# 2. 数据加载和预处理
# 定义数据预处理操作的组合
# Compose 函数用于将多个数据预处理操作组合在一起,按顺序依次执行
transform = transforms.Compose([# 将图片的大小调整为 224x224 像素# 这是因为许多预训练的模型输入要求为 224x224 大小transforms.Resize((224, 224)),  # 将图片从 PIL 图像格式转换为 PyTorch 张量格式# 张量是 PyTorch 中用于存储和处理数据的基本数据结构transforms.ToTensor(),  # 对图像进行归一化处理# 这里使用的均值和标准差是在 ImageNet 数据集上统计得到的# 归一化有助于模型更快收敛和提高稳定性transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  
])# 加载训练集数据
# ImageFolder 函数用于从指定的根目录加载图像数据集
# 根目录下的每个子文件夹代表一个类别,文件夹名即为类别名
train_dataset = datasets.ImageFolder(root='train', transform=transform)
# 加载测试集数据
test_dataset = datasets.ImageFolder(root='test', transform=transform)# 创建训练集的数据加载器
# DataLoader 用于将数据集封装成可迭代的数据加载对象
# batch_size 表示每次从数据集中取出的样本数量
# shuffle=True 表示在每个训练周期开始时对数据进行打乱,增加数据的随机性
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# 创建测试集的数据加载器
# 测试集不需要打乱数据,所以 shuffle=False
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)# 3. 定义神经网络模型
# 定义一个简单的卷积神经网络类,继承自 nn.Module
class SimpleCNN(nn.Module):def __init__(self):# 调用父类的构造函数super(SimpleCNN, self).__init__()# 第一个卷积层# 输入通道数为 3(对应 RGB 三个通道)# 输出通道数为 16,表示该层会提取 16 种不同的特征# 卷积核大小为 3x3,padding=1 表示在图像边缘填充 1 个像素,以保持输出图像大小不变self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)# 激活函数 ReLU,用于引入非线性因素self.relu1 = nn.ReLU()# 第一个最大池化层# 池化核大小为 2x2,用于减小特征图的尺寸,同时保留重要特征self.pool1 = nn.MaxPool2d(2)# 第二个卷积层# 输入通道数为 16,与上一层的输出通道数一致# 输出通道数为 32,表示该层会提取 32 种不同的特征self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)# 激活函数 ReLUself.relu2 = nn.ReLU()# 第二个最大池化层self.pool2 = nn.MaxPool2d(2)# 第一个全连接层# 输入特征数为 32 * 56 * 56,这是经过两次池化后特征图的尺寸和通道数的乘积# 输出特征数为 128self.fc1 = nn.Linear(32 * 56 * 56, 128)# 激活函数 ReLUself.relu3 = nn.ReLU()# 第二个全连接层# 输入特征数为 128,与上一层的输出特征数一致# 输出特征数为 2,对应猫和狗两个类别self.fc2 = nn.Linear(128, 2)  def forward(self, x):# 前向传播过程# 输入数据经过第一个卷积层、ReLU 激活函数和最大池化层x = self.pool1(self.relu1(self.conv1(x)))# 再经过第二个卷积层、ReLU 激活函数和最大池化层x = self.pool2(self.relu2(self.conv2(x)))# 将多维的特征图展平为一维向量,以便输入到全连接层x = x.view(-1, 32 * 56 * 56)# 经过第一个全连接层和 ReLU 激活函数x = self.relu3(self.fc1(x))# 经过第二个全连接层,得到最终的输出x = self.fc2(x)return x# 创建模型实例
model = SimpleCNN()# 4. 定义损失函数和优化器
# 定义交叉熵损失函数,用于分类问题
# 交叉熵损失函数可以衡量模型预测结果与真实标签之间的差异
criterion = nn.CrossEntropyLoss()
# 定义 Adam 优化器,用于更新模型的参数
# lr 表示学习率,控制参数更新的步长
optimizer = optim.Adam(model.parameters(), lr=0.001)# 5. 训练模型
# 定义训练的总周期数
num_epochs = 10
# 用于记录每个周期的训练损失
train_losses = []# 开始训练循环,遍历每个周期
for epoch in range(num_epochs):# 初始化每个周期的累计损失running_loss = 0.0# 遍历训练集中的每个批次数据for i, (images, labels) in enumerate(train_loader):# 清空优化器中的梯度信息# 因为 PyTorch 会累积梯度,所以每个批次都需要清空optimizer.zero_grad()# 将输入图像数据输入到模型中,得到模型的预测输出outputs = model(images)# 计算预测输出与真实标签之间的损失loss = criterion(outputs, labels)# 进行反向传播,计算损失函数关于模型参数的梯度loss.backward()# 根据计算得到的梯度更新模型的参数optimizer.step()# 累加当前批次的损失running_loss += loss.item()# 计算当前周期的平均损失epoch_loss = running_loss / len(train_loader)# 将当前周期的平均损失添加到损失记录列表中train_losses.append(epoch_loss)# 打印当前周期的损失信息print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}')# 绘制训练损失曲线
# 以周期数为横坐标,训练损失为纵坐标绘制曲线
plt.plot(train_losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.show()# 6. 评估模型
# 将模型设置为评估模式
# 在评估模式下,一些层(如 Dropout、BatchNorm 等)的行为会发生变化
model.eval()
# 初始化正确预测的样本数
correct = 0
# 初始化总的样本数
total = 0
# 关闭梯度计算,因为在评估阶段不需要计算梯度
with torch.no_grad():# 遍历测试集中的每个批次数据for images, labels in test_loader:# 将输入图像数据输入到模型中,得到模型的预测输出outputs = model(images)# 找到每个样本预测概率最大的类别索引# _ 表示忽略的变量,predicted 为预测的类别索引_, predicted = torch.max(outputs.data, 1)# 累加当前批次的样本数total += labels.size(0)# 累加当前批次中预测正确的样本数correct += (predicted == labels).sum().item()# 计算并打印模型在测试集上的准确率
print(f'Accuracy on test set: {100 * correct / total:.2f}%')

代码解释

数据加载和预处理:使用 torchvision.datasets.ImageFolder 加载图片数据集,并使用 transforms 对图片进行预处理,包括调整大小、转换为张量和归一化。
模型定义:定义了一个简单的卷积神经网络 SimpleCNN,包含两个卷积层、两个池化层和两个全连接层。
损失函数和优化器:使用交叉熵损失函数 nn.CrossEntropyLoss 和 Adam 优化器 optim.Adam。
模型训练:在训练过程中,遍历训练集,计算损失并进行反向传播更新模型参数。
模型评估:在测试集上评估模型的准确率。
请确保你已经安装了 PyTorch 和 torchvision 库,并且将猫和狗的图片按照指定的文件夹结构放置。如果需要提高模型的性能,可以尝试使用更复杂的预训练模型(如 ResNet、VGG 等)进行微调。

http://www.ppmy.cn/news/1572110.html

相关文章

React源码揭秘 | scheduler 并发更新原理

React 18增加了并发更新特性,开发者可以通过useTransition等hooks延迟执行优先级较低的更新任务,以达到页面平滑切换,不阻塞用户时间的目的。其实现正是依靠scheduler库。 scheduler是一个依赖时间片分片的任务调度器,React团队将…

财务主题数据分析-企业盈利能力分析

企业盈利能力数据主要体现在财务三张表中的利润表里面,盈利能力需要重点需要关注的指标有:毛利率、净利率、净利润增长率、营业成本增长率等; 接下来我们分析一下某上市公司披露的财务数据,看看该企业盈利能力如何: …

基于深度学习的半导体算法原理及应用

摘要 随着半导体产业的持续发展,深度学习技术在该领域的应用日益广泛且深入。本文全面阐述了基于深度学习的半导体算法原理,涵盖卷积神经网络(CNN)、循环神经网络(RNN)及其变体长短时记忆网络(…

SQL入门到精通 理论+实战 -- 在 MySQL 中学习SQL语言

目录 一、环境准备 1、MySQL 8.0 和 Navicat 下载安装 2、准备好的表和数据文件: 二、SQL语言简述 1、数据库基础概念 2、什么是SQL 3、SQL的分类 4、SQL通用语法? 三、DDL(Data Definition Language):数据定义语言 1、…

蓝桥杯之BF算法

算法思想&#xff1a; 找一个子串是否在另一个串中出现 #include<string> string s1 "abcdefghi"; string s2 "def"; int BF() {int i 0;int j 0;while (i < s1.size() && j < s2.size()){if (s1[i] s2[j]){i;j;}else{i i - j …

前端快速生成接口方法

大家好&#xff0c;我是苏麟&#xff0c;今天聊一下OpenApi。 官网 &#xff1a; umijs/openapi - npm 安装命令 npm i --save-dev umijs/openapi 在根目录&#xff08;项目目录下&#xff09;创建文件 openapi.config.js import { generateService } from umijs/openapi// 自…

2.11 sqlite3数据库【数据库的相关操作指令、函数】

练习&#xff1a; 将 epoll 服务器 客户端拿来用 客户端&#xff1a;写一个界面&#xff0c;里面有注册登录 服务器&#xff1a;处理注册和登录逻辑&#xff0c;注册的话将注册的账号密码写入数据库&#xff0c;登录的话查询数据库中是否存在账号&#xff0c;并验证密码是否正确…

Git | 相关命令

相关资料 官网Git 学习教程Git 入门指南Git 的奇技淫巧Git Extras git 命令行扩展工具配置 Git 处理行结束符Git 配置多个 SSH-Key下载相关 Windows 版下载镜像使用 jsdelivr 加速 Github 仓库资源 commit 常用的 type 常用 Git 命令 [xxx] 均为可选参数 git clone # 拷贝一…