分类神经网络1:VGGNet模型复现

server/2024/10/21 11:26:24/

目录

分类网络的常见形式

VGG网络架构

VGG网络部分实现代码


分类网络的常见形式

常见的分类网络通常由特征提取部分分类部分组成。

特征提取部分实质就是各种神经网络,如VGG、ResNet、DenseNet、MobileNet等。其负责捕获数据的有用信息,一般是通过堆叠多个卷积层和池化层来实现的,这些层有助于检测图像中的边缘、纹理和特征。

分类部分通常是一个全连接层,负责将特征提取部分输出的信息映射到最终的类别或标签。这些全连接层通常包括一个或多个隐藏层,以及一个输出层,其中输出层的节点数量等于任务中的类别数量。

VGG网络架构

论文原址:https://arxiv.org/pdf/1409.1556v6.pdf

VGG 网络是由牛津大学的Visual Geometry Group 开发的,其结构特点在于使用了多个 3x3 的小卷积核,并通过这些小卷积层的重复堆叠来构建网络,从而能够捕捉到更加复杂和抽象的特征表示。VGG 网络的模型结构如下:

VGG网络的核心架构可以分为以下几个部分:

  1. 输入层:VGG网络接受224x224像素的RGB图像作为输入。
  2. 卷积层:网络的前几层由多个卷积层组成,每个卷积层都使用3x3的卷积核来提取图像的特征。这些卷积层后面通常跟着一个2x2 最大池化,用于逐步减小特征图的空间尺寸,同时增加特征深度。
  3. 池化层:在卷积层之后,网络使用最大池化层来降低特征图的空间分辨率,这有助于减少计算量并提取更加抽象的特征。
  4. 全连接层:经过多个卷积和池化层之后,网络的特征图被展平并通过几个全连接层进行处理。全连接层的作用是将学习到的特征映射到最终的分类结果。
  5. 输出层:VGG网络的最后是一个softmax层,它将网络的输出转换为概率分布,以便进行多类别的分类任务。

VGG网络的一个显著特点是其深度,其相关配置信息如下:

VGG系列不同变体内容如下:

  • VGG A:这是一个基础的配置,没有特别独特的设计。
  • VGG A-LRN:在这个版本中,加入了局部响应归一化(LRN),这是一种在AlexNet中首次使用的技术。不过,LRN在当前的深度学习实践中已经较少被采用。
  • VGG B:相较于A版本,B版本增加了两个卷积层,以增强网络的学习能力。
  • VGG C:在B的基础上,C版本进一步增加了三个卷积层,但这些层使用的是1x1的卷积核。1x1卷积核可以看作是对输入特征图进行线性变换,有助于减少参数数量并增加非线性。
  • VGG D:D版本在C版本的基础上做了调整,将1x1的卷积核替换为3x3的卷积核,这个配置后来被称为VGG16,因为它总共有16层。
  • VGG E:在D版本的基础上,E版本进一步增加了三个3x3的卷积层,形成了VGG19,总共有19层。

从图中可以看出,随着网络深度的加深,模型变得更为复杂。通常来说,增加网络的深度可以增加模型的表示能力,使其能够学习到更复杂的特征和模式,从而在某些任务上取得更好的性能。然而,随着网络深度的增加,模型的参数数量也会增加,导致模型的复杂度增加,训练和推理的计算成本也会增加,同时可能会增加过拟合的风险。

VGG网络部分实现代码

废话不多说,直接上干货

import torch
import torch.nn as nn__all__ = ["VGG", "vgg11_bn", "vgg13_bn", "vgg16_bn", "vgg19_bn"]cfg = {'A': [64,     'M', 128,      'M', 256, 256,           'M', 512, 512,           'M', 512, 512,           'M'],'B': [64, 64, 'M', 128, 128, 'M', 256, 256,           'M', 512, 512,           'M', 512, 512,           'M'],'C': [64, 64, 'M', 128, 128, 'M', 256, 256, 256,      'M', 512, 512, 512,      'M', 512, 512, 512,      'M'],'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}class ConvBNReLU(nn.Module):def __init__(self, in_channels, out_channels, stride=1,  kernel_size=3, padding=1):super(ConvBNReLU, self).__init__()self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)self.bn = nn.BatchNorm2d(num_features=out_channels)self.relu = nn.ReLU(inplace=True)def forward(self, x):x = self.conv(x)x = self.bn(x)x = self.relu(x)return xclass VGG(nn.Module):def __init__(self, features, num_classes=1000, init_weights=True):super(VGG, self).__init__()self.features = featuresself.avgpool = nn.AdaptiveAvgPool2d((7, 7))self.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 4096),nn.ReLU(True),nn.Dropout(),nn.Linear(4096, 4096),nn.ReLU(True),nn.Dropout(),nn.Linear(4096, num_classes),)if init_weights:self._initialize_weights()def forward(self, x):for layer in self.features:x = layer(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.classifier(x)return xdef _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)def make_layers(cfg):layers = nn.ModuleList()in_channels = 3for i in cfg:if i == 'M':layers.append(nn.MaxPool2d(kernel_size=2, stride=2))else:layers.append(ConvBNReLU(in_channels=in_channels, out_channels=i))in_channels = ireturn layersdef vgg11_bn(num_classes):model = VGG(make_layers(cfg['A']), num_classes=num_classes)return modeldef vgg13_bn(num_classes):model = VGG(make_layers(cfg['B']), num_classes=num_classes)return modeldef vgg16_bn(num_classes):model = VGG(make_layers(cfg['C']), num_classes=num_classes)return modeldef vgg19_bn(num_classes):model = VGG(make_layers(cfg['D']), num_classes=num_classes)return modelif __name__=='__main__':import torchsummarydevice = 'cuda' if torch.cuda.is_available() else 'cpu'input = torch.ones(2, 3, 224, 224).to(device)net = vgg16_bn(num_classes=4)net = net.to(device)out = net(input)print(out)print(out.shape)torchsummary.summary(net, input_size=(3, 224, 224))# Total params: 134,285,380

这只是一个网络架构部分实现代码,其中 cfg 列表是 VGG 卷积和池化后的通道数,大家可以结合 VGG 的配置信息图一起对比理解。希望对大家有所帮助呀!


http://www.ppmy.cn/server/7677.html

相关文章

react脚手架创建项目,配置别名(alias)

React脚手架项目使用 react-scripts 封装了webpack配置,所以我们需要通过 config-overrides 或者 eject 的方式来修改webpack配置 可以的话 ,创建项目的时候可以使用vite ,我这是老项目屎山 懒得迁移 ,但还得改呀 ## 1. 安装依…

AI助力科研创新与效率双提升:ChatGPT深度科研应用、数据分析及机器学习、AI绘图与高效论文撰写

2022年11月30日,可能将成为一个改变人类历史的日子——美国人工智能开发机构OpenAI推出了聊天机器人ChatGPT3.5,将人工智能的发展推向了一个新的高度。2023年4月,更强版本的ChatGPT4.0上线,文本、语音、图像等多模态交互方式使其在…

for循环的用法

for循环的用法 for 循环是一种在 Python 中用于迭代序列(如列表、元组或字符串)或其他可迭代对象的循环结构。下面是一些常见的 for 循环用法: 遍历列表:使用 for 循环遍历列表中的元素。 fruits ["apple", "b…

CSS显示模式

目录 CSS显示模式简介 CSS显示模式的分类 块元素 行元素 行内块元素 元素显示模式的转换 使块内文字垂直居中的方法 设计简单小米侧边栏(实践) CSS显示模式简介 元素显示模式就是元素(标签)以什么方式进行显示&#xff0…

毕业设计——基于ESP32的智能家居系统(语音识别、APP控制)

ESP32嵌入式单片机实战项目 一、功能演示二、项目介绍1、功能演示2、外设介绍 三、资料获取 一、功能演示 多种控制方式 ① 语音控制 ②APP控制 ③本地按键控制 ESP32嵌入式单片机实战项目演示 二、项目介绍 1、功能演示 这一个基于esp32c3的智能家居控制系统,能实…

华为海思校园招聘-芯片-数字 IC 方向 题目分享——第六套

华为海思校园招聘-芯片-数字 IC 方向 题目分享——第六套 (共9套,有答案和解析,答案非官方,未仔细校正,仅供参考) 部分题目分享,完整版获取(WX:didadidadidida313,加我备注&#x…

实时交互新篇章:WebSocket在Flutter中的应用与实践

WebSocket 协议是一个应用层协议,是一种在单个TCP连接上进行全双工通信的协议。 与传统的HTTP请求不同,WebSocket在建立连接后,双方可以随时发送消息,无需频繁地建立和断开连接。这种特性使得WebSocket成为实时应用的理想选择,如在线游戏、聊天应用和实时数据更新等。 本…

STM32 | USART实战案例

STM32 | 通用同步/异步串行接收/发送器USART带蓝牙(第六天)随着扩展的内容越来越多,很多小伙伴已经忘记了之前的学习内容,然后后面这些都很难理解。STM32合集已在专栏创建,方面大家学习。1、通过电脑串口助手发送数据,控制开发板LED灯 从题目中可以挖掘出,本次使用led、延…