如何搭建VGG网络,实现Mnist数据集的图像分类

news/2024/11/13 3:35:44/

1 问题

如何搭建VGG网络,实现Mnist数据集的图像分类?

2 方法

步骤:

  1. 首先导包

    Import torch
    from torch import nn
  2. VGG11由8个卷积,三个全连接组成,注意池化只改变特征图大小,不改变通道数

    class MyNet(nn.Module):
       def __init__(self) -> None:
           super().__init__()
           #(1)conv3-64
           self.conv1 = nn.Conv2d(
               in_channels=3,
               out_channels=64,
               kernel_size=3,
               stride=1,
               padding=1 #! 不改变特征图的大小
           )
           #! 池化只改变特征图大小,不改变通道数
           self.max_pool_1 = nn.MaxPool2d(2)
           #(2)conv3-128
           self.conv2 = nn.Conv2d(
               in_channels=64,
               out_channels=128,
               kernel_size=3,
               stride=1,
               padding=1
           )
           self.max_pool_2 = nn.MaxPool2d(2)
           #(3) conv3-256,conv3-256
           self.conv3_1 = nn.Conv2d(
               in_channels=128,
               out_channels=256,
               kernel_size=3,
               stride=1,
               padding=1)
           self.conv3_2 = nn.Conv2d(
               in_channels=256,
               out_channels=256,
               kernel_size=3,
               stride=1,
               padding=1
           )
           self.max_pool_3 = nn.MaxPool2d(2)
           #(4)conv3-512,conv3-512
           self.conv4_1 = nn.Conv2d(
               in_channels=256,
               out_channels=512,
               kernel_size=3,
               stride=1,
               padding=1
           )
           self.conv4_2 = nn.Conv2d(
               in_channels=512,
               out_channels=512,
               kernel_size=3,
               stride=1,
               padding=1
           )
           self.max_pool_4 = nn.MaxPool2d(2)
           #(5)conv3-512,conv3-512
           self.conv5_1 = nn.Conv2d(
               in_channels=512,
               out_channels=512,
               kernel_size=3,
               stride=1,
               padding=1
           )
           self.conv5_2 = nn.Conv2d(
               in_channels=512,
               out_channels=512,
               kernel_size=3,
               stride=1,
               padding=1
           )
           self.max_pool_5 = nn.MaxPool2d(2)
           #(6)
           self.fc1 = nn.Linear(25088,4096)
           self.fc2 = nn.Linear(4096,4096)
           self.fc3 = nn.Linear(4096,1000)
       def forward(self,x):
           x = self.conv1(x)
           print(x.shape)
           x = self.max_pool_1(x)
           print(x.shape)
           x = self.conv2(x)
           print(x.shape)
           x = self.max_pool_2(x)
           print(x.shape)
           x = self.conv3_1(x)
           print(x.shape)
           x = self.conv3_2(x)
           print(x.shape)
           x = self.max_pool_3(x)
           print(x.shape)
           x = self.conv4_1(x)
           print(x.shape)
           x = self.conv4_2(x)
           print(x.shape)
           x = self.max_pool_4(x)
           print(x.shape)
           x = self.conv5_1(x)
           print(x.shape)
           x = self.conv5_2(x)
           print(x.shape)
           x = self.max_pool_5(x)
           print(x.shape)
           x = torch.flatten(x,1)
           print(x.shape)
           x = self.fc1(x)
           print(x.shape)
           x = self.fc2(x)
           print(x.shape)
           out = self.fc3(x)
           return out
  3. 给定x查看最后结果

x = torch.rand(128,3,224,224)
net = MyNet()
out = net(x)
print(out.shape)
#torch.Size([128, 1000])

3 结语

   通过本周学习让我学会了VGG11网络,从实验中我遇到的容易出错的地方是卷积的in_features和out_features容易出错,尺寸不对的时候就会报错,在多个卷积的情况下尤其需要注意,第二点容易出错的地方是卷积以及池化所有结束后,一定要使用torch.flatten进行拉伸,第三点容易出错的地方是fc1的in_features,这个我通过使用断点的方法,得到fc1前一步的size值,从而得到in_features的值,从中收获颇深。


http://www.ppmy.cn/news/18406.html

相关文章

QEMU之一调试uboot(vexpress-a9)

u-boot版本:u-boot-2017.05开发板:vexpress-a9(没办法,目前看到的都是这个开发板,想QEMU调试tiny210,一直没看到怎么修改qemu)编译u-boot:make ARCHarm CROSS_COMPILEarm-linux-gnueabi- vexpre…

深度学习入门基础CNN系列——批归一化(Batch Normalization)和丢弃法(dropout)

想要入门深度学习的小伙伴们,可以了解下本博主的其它基础内容: 🏠我的个人主页 🚀深度学习入门基础CNN系列——卷积计算 🌟深度学习入门基础CNN系列——填充(padding)与步幅(stride&…

【NI Multisim 14.0虚拟仪器设计——放置虚拟仪器仪表(频率计数器)】

目录 序言 🏮放置虚拟仪器仪表🏮 🧧频率计数器🧧 🥳🥳(1)“测量”选项组:参数测量区。 🥳🥳(2)“耦合”选项组:用于选择电流耦合方…

十六进制转八进制(蓝桥杯基础练习C/C++)

我首先想到的就是十六进制转十进制&#xff0c;十进制转八进制&#xff0c;毕竟这样的方法是最常见的&#xff0c;但始终出现报错。 我想可能是int能储存的数范围太小了&#xff0c;就尝试用long long存储&#xff0c;结果还是报错。 #include <bits/stdc.h> using nam…

[ESP][驱动]GT911 ESP系列驱动

GT911ForESP GT911在ESP系列上的驱动&#xff0c;基于IDF5.0&#xff0c;ESP32S3编写 本库使用面向对象思想编写&#xff0c;可创建多设备多实例 Github&#xff0c;Gitee同步更新&#xff0c;Gitee仅作为下载仓库&#xff0c;提交Issue和Pull request请到Github Github: h…

Java 对象处理流(ObjectOutputStream\ObjectInputStream)

文章目录前言什么是对象流&#xff1f;基本介绍ObjectOutputStreamObjectInputStream对象处理流的使用细节前言 处理流&#xff1a;是对一个已存在的流进行处理和封装&#xff0c;通过所封装的流的功能调用实现对数据的操作。而处理流中也有不同的分类&#xff0c;此片介绍的是…

蓝桥杯重点(C/C++)(随时更新)

目录 1 技巧 1.1 取消同步&#xff08;节约时间&#xff0c;甚至能多骗点分&#xff0c;最好每个程序都写上&#xff09; 1.2 万能库&#xff08;可能会耽误编译时间&#xff0c;但是省脑子&#xff09; 1.3 蓝桥杯return 0千万别忘了写&#xff01;&#xff01; 1.4 …

【JavaWeb】前端开发三剑客之CSS(上)

✨哈喽&#xff0c;进来的小伙伴们&#xff0c;你们好耶&#xff01;✨ &#x1f6f0;️&#x1f6f0;️系列专栏:【JavaWeb】 ✈️✈️本篇内容:CSS从零开始学习&#xff01; &#x1f680;&#x1f680;代码托管平台github&#xff1a;JavaWeb代码存放仓库&#xff01; ⛵⛵作…