从零构建CNN:框架与自定义实现对比

embedded/2025/3/13 18:38:05/

文章目录

    • 引言
    • 项目结构
    • 一、代码结构解析
      • 1.1 训练流程控制 (main.py)
      • 1.2 PyTorch实现的CNN模型 (cnn_pytorch.py)
      • 1.3 自定义实现CNN模型 (cnn_custom.py)
    • 二、关键算法细节剖析
      • 2.1 卷积操作
      • 2.2 自定义实现卷积层
      • 2.3 ReLU与池化
      • 2.4 全连接层
    • 总结

引言

卷积神经网络 (Convolutional Neural Network, CNN) 是图像识别和处理中的核心技术,特别在计算机视觉任务中广泛应用。本文通过在一个简单的图像分类任务中对比PyTorch实现自定义实现两种方案,解析CNN的关键技术细节。

项目结构

.
├── README.md
├── cnn_pytorch.py
├── cnn_custom.py
└── main.py

项目地址:https://github.com/tangpan360/cnn-from-scratch.git

一、代码结构解析

1.1 训练流程控制 (main.py)

# main.py
import torch
import torch.optim as optim
from cnn_pytorch import SimpleCNN  # 导入模型类
from cnn_custom import SimpleCNNCustom  # 导入自定义模型# 超参数设定
batch_size = 4  # 一次处理 4 张图片
channels = 2  # 2通道
height = 32  # 图片高度
width = 32  # 图片宽度
num_classes = 10  # 假设有 10 个分类
epochs = 10  # 训练轮数
learning_rate = 0.001  # 学习率# 生成一个随机数据集,模拟训练集
train_data = torch.randn(100, channels, height, width)
train_labels = torch.randint(0, num_classes, (100,))# 初始化 pytorch 模型
# model = SimpleCNN(num_classes=num_classes)
# 初始化自定义模型
model = SimpleCNNCustom(num_classes=num_classes)# 定义损失函数(交叉熵损失)
criterion = torch.nn.CrossEntropyLoss()# 选择优化器(Adam)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# 训练过程
for epoch in range(epochs):model.train()  # 切换到训练模式# 生成当前 batch 的数据inputs = train_datalabels = train_labels# 清空之前的梯度optimizer.zero_grad()# 前向传播outputs = model(inputs)# 计算损失loss = criterion(outputs, labels)# 反向传播loss.backward()# 更新模型参数optimizer.step()# 每个 epoch 打印一次损失print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}")

cnn_pytorchpy_77">1.2 PyTorch实现的CNN模型 (cnn_pytorch.py)

# cnn_pytorch.py
import torch
import torch.nn as nn# 定义包含两层卷积的 CNN 模型
class SimpleCNN(nn.Module):def __init__(self, num_classes=10):super(SimpleCNN, self).__init__()# 第一层卷积:2 通道 -> 16 通道,卷积核大小 3x3,步长 1,填充 1self.conv1 = nn.Conv2d(in_channels=2, out_channels=16, kernel_size=3, stride=1, padding=1)# 第二层卷积:16 通道 -> 32 通道,卷积核大小 3x3,步长 1,填充 1self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)self.relu = nn.ReLU()  # ReLU 激活函数self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  # 2x2 最大池化# 全连接层:将特征展平后,映射到 num_classes 类别self.fc1 = nn.Linear(32 * 8 * 8, num_classes)def forward(self, x):# 第一层卷积 -> ReLU -> 池化x = self.conv1(x)x = self.relu(x)x = self.pool(x)# 第二层卷积 -> ReLU -> 池化x = self.conv2(x)x = self.relu(x)x = self.pool(x)# 展平x = torch.flatten(x, start_dim=1)# 全连接层x = self.fc1(x)return x

cnn_custompy_118">1.3 自定义实现CNN模型 (cnn_custom.py)

# cnn_custom.py
import torch
import torch.nn as nn
import torch.nn.functional as F# 手动实现卷积操作
class Conv2dCustom(nn.Module):def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):super(Conv2dCustom, self).__init__()# 卷积核的初始化使用 Xavier 均匀分布self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))  # 初始化为随机值nn.init.xavier_uniform_(self.weight)  # 使用 Xavier 均匀分布初始化权重self.bias = nn.Parameter(torch.zeros(out_channels))  # 偏置初始化为零self.stride = strideself.padding = paddingdef forward(self, x):# 输入的维度是 (batch_size, in_channels, height, width)batch_size, in_channels, height, width = x.size()# 计算输出的尺寸kernel_size = self.weight.size(2)out_height = (height + 2 * self.padding - kernel_size) // self.stride + 1out_width = (width + 2 * self.padding - kernel_size) // self.stride + 1# 扩展输入 x 到与卷积核对齐的形式x_padded = F.pad(x, (self.padding, self.padding, self.padding, self.padding))# 进行卷积计算out = torch.zeros(batch_size, self.weight.size(0), out_height, out_width).to(x.device)for i in range(out_height):for j in range(out_width):# 获取当前卷积窗口h_start = i * self.strideh_end = h_start + kernel_sizew_start = j * self.stridew_end = w_start + kernel_size# 提取当前窗口的数据x_slice = x_padded[:, :, h_start:h_end, w_start:w_end]  # shape: (batch_size, in_channels, kernel_size, kernel_size)# 调整 x_slice 和 self.weight 的形状以便广播x_slice = x_slice.unsqueeze(1)  # shape: (batch_size, 1, in_channels, kernel_size, kernel_size)weight = self.weight.unsqueeze(0)  # shape: (1, out_channels, in_channels, kernel_size, kernel_size)# 计算卷积结果element_wise = x_slice * weight  # 逐元素相乘conv_result = element_wise.sum(dim=(2, 3, 4))  # 在指定维度求和out[:, :, i, j] = conv_result + self.biasreturn out# 自定义 ReLU 激活函数
class ReLUCustom(nn.Module):def forward(self, x):result = torch.max(x, torch.tensor(0.0).to(x.device))return result# 自定义最大池化层
class MaxPool2dCustom(nn.Module):def __init__(self, kernel_size, stride=None, padding=0):super(MaxPool2dCustom, self).__init__()self.kernel_size = kernel_sizeself.stride = stride if stride is not None else kernel_sizeself.padding = paddingdef forward(self, x):batch_size, in_channels, height, width = x.size()# 计算输出的尺寸out_height = (height + 2 * self.padding - self.kernel_size) // self.stride + 1out_width = (width + 2 * self.padding - self.kernel_size) // self.stride + 1# 对输入进行 paddingx_padded = F.pad(x, (self.padding, self.padding, self.padding, self.padding))# 最大池化操作out = torch.zeros(batch_size, in_channels, out_height, out_width).to(x.device)for i in range(out_height):for j in range(out_width):# 计算窗口的起始和结束位置h_start = i * self.strideh_end = h_start + self.kernel_sizew_start = j * self.stridew_end = w_start + self.kernel_size# 提取当前窗口的数据x_slice = x_padded[:, :, h_start:h_end, w_start:w_end]# 对窗口进行最大池化max_values = x_slice.amax(dim=(2, 3))out[:, :, i, j] = max_valuesreturn out# 自定义全连接层
class LinearCustom(nn.Module):def __init__(self, in_features, out_features):super(LinearCustom, self).__init__()# 使用 Kaiming 初始化权重,适用于 ReLU 激活函数self.weight = nn.Parameter(torch.empty(in_features, out_features))nn.init.kaiming_normal_(self.weight, mode='fan_out', nonlinearity='relu')# 初始化偏置为 0self.bias = nn.Parameter(torch.zeros(out_features))def forward(self, x):# 将输入展平x = x.view(x.size(0), -1)# 计算线性变换结果output = torch.matmul(x, self.weight) + self.biasreturn output# 定义自定义 CNN 模型
class SimpleCNNCustom(nn.Module):def __init__(self, num_classes=10):super(SimpleCNNCustom, self).__init__()# 第一层卷积self.conv1 = Conv2dCustom(in_channels=2, out_channels=16, kernel_size=3, padding=1)# 第二层卷积self.conv2 = Conv2dCustom(in_channels=16, out_channels=32, kernel_size=3, padding=1)# 激活函数self.relu = ReLUCustom()# 最大池化self.pool = MaxPool2dCustom(kernel_size=2, stride=2)# 全连接层self.fc1 = LinearCustom(32 * 8 * 8, num_classes)def forward(self, x):# 第一层卷积 -> 激活 -> 池化x = self.conv1(x)x = self.relu(x)x = self.pool(x)# 第二层卷积 -> 激活 -> 池化x = self.conv2(x)x = self.relu(x)x = self.pool(x)# 展平x = x.view(x.size(0), -1)# 全连接层x = self.fc1(x)return x

二、关键算法细节剖析

2.1 卷积操作

卷积操作通过滑动卷积核进行局部特征的提取,图像的每个区域与卷积核进行点乘,并通过步长和填充控制卷积的输出尺寸。

2.2 自定义实现卷积层

我们手动实现了卷积层 Conv2dCustom,它通过对输入图像进行切片和点乘操作模拟了传统的卷积操作。此外,我们还实现了自定义的激活函数和池化层,用于更好地理解底层实现。

2.3 ReLU与池化

我们使用了自定义的 ReLU 激活函数 ReLUCustom 和池化层 MaxPool2dCustom,其中池化操作用于降低特征图的空间尺寸,同时保留重要的空间信息。

2.4 全连接层

全连接层用于将卷积层提取到的特征映射到最终的类别空间,确保每个特征都对分类决策产生影响。

总结

通过对比 PyTorch 实现与自定义实现的 CNN 模型,我们能够更好地理解 CNN 的各个构成部分。自定义实现展示了如何从零开始构建卷积神经网络并理解其中的关键操作,而 PyTorch 提供了高度优化的卷积层,可以快速构建高效的模型。


http://www.ppmy.cn/embedded/172321.html

相关文章

AJAX的作用

AJAX(Asynchronous JavaScript And XML)的工作原理基于浏览器与服务器的异步通信,其核心细节可分为以下几个关键步骤: 1. 事件触发与请求创建 触发源:用户操作(点击按钮、输入文本等)或定时事件…

【每日八股】Redis篇(七):集群

目录 Redis 集群模式有哪些?Redis 切片集群的工作原理?哈希槽和 Redis 节点如何对应?主从模式的同步过程?全量同步增量同步 主服务器如何知道要将哪些增量数据发送给从服务器?如何避免主从数据不一致?主从架…

每日算法:力扣343.整数差分(动态规划)

题目: 给定一个正整数 n ,将其拆分为 k 个 正整数 的和( k > 2 ),并使这些整数的乘积最大化。 返回 你可以获得的最大乘积 。 示例 1: 输入: n 2 输出: 1 解释: 2 1 1, 1 1 1。 示例 2: 输入: n 10 输出…

3.11记录

leetcode刷题: 1. 334. 递增的三元子序列 - 力扣(LeetCode) 方法一:使用贪心算法求解 class Solution(object):def increasingTriplet(self, nums):first nums[0]second float(inf)for i in nums:if i>second:return Truee…

c++20 Concepts的简写形式与requires 从句形式

c20 Concepts的简写形式与requires 从句形式 原始写法(简写形式)等效写法(requires 从句形式)关键区别说明:组合多个约束的示例:两种形式的编译结果:更复杂的约束示例:标准库风格的约…

【网络编程】WSAAsyncSelect 模型

十、基于I/O模型的网络开发 接着上次的博客继续分享:select模型 10.8 异步选择模型WSAAsyncSelect 10.8.1 基本概念 WSAAsyncSelect模型是Windows socket的一个异步I/O 模型,利用这个模型,应用程序 可在一个套接字上接收以Windows 消息为基…

力扣-哈希表-844 比较含退格的字符串

思路和时间复杂度 思路&#xff1a;利用栈完成出栈操作时间复杂度&#xff1a; 代码 class Solution { public:bool backspaceCompare(string s, string t) {stack<char> ss;stack<char> tt;for(int i 0; i < s.size(); i){if(s[i] ! #){ss.push(s[i…

Redis相关面试题

以下是150道Redis相关面试题&#xff1a; Redis基础概念 1. Redis是什么&#xff1f; Redis是一个开源的、基于内存的高性能键值存储数据库&#xff0c;常用于缓存、消息队列等场景。 2. Redis的特点有哪些&#xff1f; • 高性能&#xff0c;读写速度快。 • 支持多种数据…