ResNet改进(4):添加Inception 结构

devtools/2025/3/14 12:17:06/

1.Inception 结构

Inception 结构是由 Google 提出的经典卷积神经网络架构,首次出现在 2014 年的论文《Going Deeper with Convolutions》中,并在 ImageNet 图像分类竞赛中取得了优异成绩。Inception 结构的目标是通过多尺度卷积和高效计算来提升网络性能,同时减少参数数量。

Inception 的核心思想

Inception 的核心思想是通过并行使用不同大小的卷积核(如 1x1、3x3、5x5)和池化操作,捕捉图像中不同尺度的特征,并将这些特征在通道维度上拼接起来。这种设计能够在不显著增加计算量的情况下提升网络的表达能力。

Inception 模块

Inception 模块是 Inception 网络的基本构建单元,其结构如下:

  1. 1x1 卷积

    • 用于降维或升维,减少计算量。

    • 通过 1x1 卷积调整通道数,减少后续卷积的计算复杂度。

  2. 3x3 卷积

    • 捕捉中等尺度的特征。

  3. 5x5 卷积

    • 捕捉更大尺度的特征。

  4. Max Pooling

    • 通过最大池化提取空间特征,通常后接 1x1 卷积以调整通道数。

  5. 特征拼接

    • 将上述所有分支的输出在通道维度上拼接,形成最终输出。

Inception 的变体

  1. Inception v1(GoogLeNet)

    • 最早的 Inception 结构,引入了 Inception 模块和辅助分类器。

  2. Inception v2

    • 加入了 Batch Normalization,加速训练并提升性能。

  3. Inception v3

    • 进一步优化,将大卷积核分解为多个小卷积核(如用两个 3x3 卷积代替 5x5 卷积),减少计算量。

  4. Inception v4

    • 结合了 Inception 和 ResNet 的思想,引入了残差连接。

  5. Inception-ResNet

    • 在 Inception 模块中加入了残差连接,进一步提升性能。

Inception 的优势

  1. 多尺度特征提取

    • 通过并行卷积核捕捉不同尺度的特征。

  2. 计算效率高

    • 使用 1x1 卷积降维,减少计算量。

  3. 性能优异

    • 在 ImageNet 等数据集上表现突出。

代码示例(Inception 模块)

以下是一个简化版的 Inception 模块实现(基于 PyTorch):

import torch
import torch.nn as nn
import torch.nn.functional as Fclass InceptionModule(nn.Module):def __init__(self, in_channels, out_1x1, out_3x3_reduce, out_3x3, out_5x5_reduce, out_5x5, out_pool):super(InceptionModule, self).__init__()# 1x1 卷积分支self.branch1x1 = nn.Conv2d(in_channels, out_1x1, kernel_size=1)# 3x3 卷积分支self.branch3x3 = nn.Sequential(nn.Conv2d(in_channels, out_3x3_reduce, kernel_size=1),nn.Conv2d(out_3x3_reduce, out_3x3, kernel_size=3, padding=1))# 5x5 卷积分支self.branch5x5 = nn.Sequential(nn.Conv2d(in_channels, out_5x5_reduce, kernel_size=1),nn.Conv2d(out_5x5_reduce, out_5x5, kernel_size=5, padding=2))# 池化分支self.branch_pool = nn.Sequential(nn.MaxPool2d(kernel_size=3, stride=1, padding=1),nn.Conv2d(in_channels, out_pool, kernel_size=1))def forward(self, x):branch1x1 = self.branch1x1(x)branch3x3 = self.branch3x3(x)branch5x5 = self.branch5x5(x)branch_pool = self.branch_pool(x)# 在通道维度上拼接outputs = [branch1x1, branch3x3, branch5x5, branch_pool]return torch.cat(outputs, 1)# 示例
inception = InceptionModule(in_channels=192, out_1x1=64, out_3x3_reduce=96, out_3x3=128, out_5x5_reduce=16, out_5x5=32,out_pool=32)
input_tensor = torch.randn(1, 192, 28, 28)
output = inception(input_tensor)
print(output.shape)  # 输出形状

2.ResNet + Inception

将Inception模块集成到ResNet中,通常是为了结合卷积神经网络(CNN)的局部特征提取能力和Inception的全局建模能力。

这里添加的位置在每个残差块内部

将 Inception 模块加入 ResNet 中是一种常见的网络设计思路,通常称为 Inception-ResNet。这种设计结合了 Inception 的多尺度特征提取能力和 ResNet 的残差连接,能够进一步提升网络的性能。

以下是如何将 Inception 模块嵌入 ResNet 的实现示例(基于 PyTorch):


实现步骤

  1. 定义 Inception 模块:使用与之前类似的 Inception 模块,但加入残差连接。

  2. 定义 ResNet 块:在 ResNet 的残差块中嵌入 Inception 模块。

  3. 构建完整的网络:将 Inception-ResNet 块堆叠起来,构建完整的网络。

import torch
import torch.nn as nn
import torch.nn.functional as F# 定义 Inception 模块
class InceptionModule(nn.Module):def __init__(self, in_channels, out_1x1, out_3x3_reduce, out_3x3, out_5x5_reduce, out_5x5, out_pool):super(InceptionModule, self).__init__()# 1x1 卷积分支self.branch1x1 = nn.Conv2d(in_channels, out_1x1, kernel_size=1)# 3x3 卷积分支self.branch3x3 = nn.Sequential(nn.Conv2d(in_channels, out_3x3_reduce, kernel_size=1),nn.Conv2d(out_3x3_reduce, out_3x3, kernel_size=3, padding=1))# 5x5 卷积分支self.branch5x5 = nn.Sequential(nn.Conv2d(in_channels, out_5x5_reduce, kernel_size=1),nn.Conv2d(out_5x5_reduce, out_5x5, kernel_size=5, padding=2))# 池化分支self.branch_pool = nn.Sequential(nn.MaxPool2d(kernel_size=3, stride=1, padding=1),nn.Conv2d(in_channels, out_pool, kernel_size=1))def forward(self, x):branch1x1 = self.branch1x1(x)branch3x3 = self.branch3x3(x)branch5x5 = self.branch5x5(x)branch_pool = self.branch_pool(x)# 在通道维度上拼接outputs = [branch1x1, branch3x3, branch5x5, branch_pool]return torch.cat(outputs, 1)# 定义 Inception-ResNet 块
class InceptionResNetBlock(nn.Module):def __init__(self, in_channels, out_1x1, out_3x3_reduce, out_3x3, out_5x5_reduce, out_5x5, out_pool):super(InceptionResNetBlock, self).__init__()# Inception 模块self.inception = InceptionModule(in_channels, out_1x1, out_3x3_reduce, out_3x3, out_5x5_reduce, out_5x5,out_pool)# 1x1 卷积用于调整残差连接的通道数self.residual_conv = nn.Conv2d(in_channels, out_1x1 + out_3x3 + out_5x5 + out_pool, kernel_size=1)def forward(self, x):# Inception 模块的输出inception_output = self.inception(x)# 残差连接residual = self.residual_conv(x)# 将 Inception 输出与残差连接相加output = inception_output + residualreturn F.relu(output)# 定义完整的 Inception-ResNet
class InceptionResNet(nn.Module):def __init__(self, num_classes=1000):super(InceptionResNet, self).__init__()# 初始卷积层self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)# Inception-ResNet 块self.inception_resnet_block1 = InceptionResNetBlock(in_channels=64,out_1x1=64,out_3x3_reduce=96,out_3x3=128,out_5x5_reduce=16,out_5x5=32,out_pool=32)self.inception_resnet_block2 = InceptionResNetBlock(in_channels=256,  # 64 + 128 + 32 + 32out_1x1=128,out_3x3_reduce=128,out_3x3=192,out_5x5_reduce=32,out_5x5=96,out_pool=64)# 全局平均池化self.avgpool = nn.AdaptiveAvgPool2d((1, 1))# 全连接层self.fc = nn.Linear(480, num_classes)  # 128 + 192 + 96 + 64 = 480def forward(self, x):x = self.conv1(x)x = self.maxpool1(x)x = self.inception_resnet_block1(x)x = self.inception_resnet_block2(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.fc(x)return x# 示例
model = InceptionResNet(num_classes=5)
print(model)input_tensor = torch.randn(1, 3, 224, 224)
output = model(input_tensor)
print(output.shape)  # 输出形状: [1, 1000]

 

关键点说明

  1. Inception 模块:使用多尺度卷积提取特征,并在通道维度上拼接。

  2. 残差连接:在 Inception 模块的输出上加入残差连接,通过 1x1 卷积调整通道数。

  3. 网络结构

    • 初始卷积层用于提取低级特征。

    • 堆叠多个 Inception-ResNet 块以提取高级特征。

    • 使用全局平均池化和全连接层进行分类。


网络结构如下:

InceptionResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (inception_resnet_block1): InceptionResNetBlock(
    (inception): InceptionModule(
      (branch1x1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
      (branch3x3): Sequential(
        (0): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (branch5x5): Sequential(
        (0): Conv2d(64, 16, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      )
      (branch_pool): Sequential(
        (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
        (1): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (residual_conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
  )
  (inception_resnet_block2): InceptionResNetBlock(
    (inception): InceptionModule(
      (branch1x1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
      (branch3x3): Sequential(
        (0): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(128, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (branch5x5): Sequential(
        (0): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(32, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      )
      (branch_pool): Sequential(
        (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
        (1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (residual_conv): Conv2d(256, 480, kernel_size=(1, 1), stride=(1, 1))
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=480, out_features=5, bias=True)
)
torch.Size([1, 5])


http://www.ppmy.cn/devtools/167019.html

相关文章

《苍穹外卖》SpringBoot后端开发项目核心知识点与常见问题整理(DAY1 to DAY3)

目录 一、在本地部署并启动Nginx服务1. 解压Nginx压缩包2. 启动Nginx服务3. 验证Nginx是否启动成功: 二、导入接口文档1. 黑马程序员提供的YApi平台2. YApi Pro平台3. 推荐工具:Apifox 三、Swagger1. 常用注解1.1 Api与ApiModel1.2 ApiModelProperty与Ap…

e2studio开发RA4L1(1)---开发板测试

e2studio开发RA4L1.1-- 开发板测试 概述视频教学样品申请产品特性参考程序源码下载硬件准备新建工程工程模板保存工程路径芯片配置工程模板选择时钟设置GPIO口配置UART配置UART属性配置设置e2studio堆栈e2studio的重定向printf设置主程序 概述 RA4L1 评估套件可以使用户能够无…

计算机视觉算法实战——手势识别(主页有源码)

✨个人主页欢迎您的访问 ✨期待您的三连 ✨ ✨个人主页欢迎您的访问 ✨期待您的三连 ✨ ✨个人主页欢迎您的访问 ✨期待您的三连✨ ​ ​​​ 1. 领域简介:手势识别的价值与挑战 手势识别是连接人类自然行为与数字世界的核心交互技术,在智能设备控制、…

docker安装及使用介绍

文章目录 docker安装及使用安装 dockerdocker 常用命令docker 基本命令容器镜像卷网络命令Docker ComposeDocker 系统命令 在docker中安装ROS2 humble拉取 ROS 2 Docker 镜像运行 ROS 2 Docker 容器配置 ROS 2 环境接收外部 ROS 2 话题注意事项 ros1 和docker ros2通信使用 ros…

【贪心算法5】

力扣738.单调递增的数字 链接: link 思路 遇到c[i]>c[i1]则c[i]–,然后就是给c[i1]赋值‘9’;需要注意的是star初值问题,可见注释部分。 class Solution {public int monotoneIncreasingDigits(int n) {String s String.valueOf(n);char[] c s.…

接口测试笔记

7、Mock接口框架 Mock介绍 mock用来模拟接口,这里mock用的是moco框架,moco框架是github上的一个开源项目,可模拟HTTP、HTTPS、Socket协议。 工作原理 Moco的启动及第一个Demo 创建配置文件startup.json启动服务器 java -jar moco-runner…

第2章、WPF窗体及其属性

1、窗体的宽与高。 2、启动窗体设置 3、窗体的启动位置设置 4、窗体图标更换 5、应用程序的图标更改 6、 7、窗体属性汇总: AllowsTransparency 类型: bool 描述: 该属性决定窗口是否可以有透明效果。如果设置为true,窗口的背景必须设置为Transpar…

【反无人机目标检测】DRBD-YOLOv8

DRBD-YOLOv8:A Lightweight and Efficient Anti-UAV Detection Model DRBD-YOLOv8:一种轻量高效的无人机检测模型 0.论文摘要 摘要:由于对无人飞行器(UAV)相关的安全和隐私问题的日益关注,反无人机检测系统…