学习基于pytorch的VGG图像分类 day2

news/2024/11/24 14:06:53/
注:本系列博客在于汇总CSDN的精华帖,类似自用笔记,不做学习交流,方便以后的复习回顾,博文中的引用都注明出处,并点赞收藏原博主.

目录

VGG网络搭建(模型文件)

        1.字典文件配置

         2.提取特征网络结构

        3. VGG类的定义

         4.VGG网络实例化


VGG网络搭建(模型文件)

        1.字典文件配置

#字典文件,对应各个配置,数字对应卷积核的个数,'M'对应最大液化(即maxpool)
cfgs = {'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

         2.提取特征网络结构

#提取特征网络结构
def make_features(cfg: list): #传入对应的列表layers = [] #定义一个空列表,存放每层的结果in_channels = 3 #输入为RGB彩色图片,输入通道为3for v in cfg: #通过for循环遍历列表if v == "M":                                                    #maxpool size = 2,stride = 2layers += [nn.MaxPool2d(kernel_size=2, stride=2)] #创建最大池化下载量程,池化核为2,布局也为2else:                                                           #conv padding = 1,stride = 1conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) #创建卷积操作(输入特征矩阵深度,输出特征矩阵深度(卷积核个数),卷积核为3,填充为1,stride默认为1(不用写))layers += [conv2d, nn.ReLU(True)] #使用ReLU激活函数in_channels = v #输出深度改变成vreturn nn.Sequential(*layers) #通过Sequential函数将列表以非关键字参数的形式传入(*代表非关键字传入)

        3. VGG类的定义

class VGG(nn.Module):def __init__(self, features, num_classes=1000, init_weights=False): #(通过make_features生成的提取特征网络结构,分类的类别个数,是否对网络权重初始化)super(VGG, self).__init__()self.features = featuresself.classifier = nn.Sequential( #生成分类网络nn.Linear(512*7*7, 4096), #全连接层上下的节点个数nn.ReLU(True),  #ReLU函数激活nn.Dropout(p=0.5), #Dropout函数减少过拟合,以50%的比例随机失活神经元nn.Linear(4096, 4096), #第一层和第二层nn.ReLU(True),nn.Dropout(p=0.5),nn.Linear(4096, num_classes) #第二层和第三层,总计3层全连接层,最后连接到输出层,输出num_classes的所需个数)if init_weights: #初始化权重函数self._initialize_weights()def forward(self, x): #正向传播 x就是输入的图像数据 # N x 3 x 224 x 224x = self.features(x) #用features提取特征网络结构# N x 512 x 7 x 7x = torch.flatten(x, start_dim=1) #对输出进行一个展平处理,(start_dim定义从哪个维度开始展平处理)# N x 512*7*7x = self.classifier(x) #输入到分类网络结构return xdef _initialize_weights(self):for m in self.modules(): #遍历网络的每一个子模块if isinstance(m, nn.Conv2d): #遍历到卷积层# nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')nn.init.xavier_uniform_(m.weight) #使用xavier函数初始化,初始化卷积核的权重if m.bias is not None: #卷积核采用偏置nn.init.constant_(m.bias, 0) #将偏执初始化为0elif isinstance(m, nn.Linear): #遍历到全连接层,下面同理nn.init.xavier_uniform_(m.weight)# nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)

         4.VGG网络实例化

#实例化VGG网络结构
def vgg(model_name="vgg16", **kwargs):assert model_name in cfgs, "Warning: model number {} not in cfgs dict!".format(model_name)cfg = cfgs[model_name]model = VGG(make_features(cfg), **kwargs) #通过VGG这个类实现实例化网络,(**可变长度的字典变量)return model

 内容参考来源:

 ​​​​​​使用pytorch搭建VGG网络_哔哩哔哩_bilibili


http://www.ppmy.cn/news/1413002.html

相关文章

【前端捉鬼记】使用nvm切换node版本后再用node -v查看仍然是原来的版本

今天遇到一个诡异的问题,使用nvm切换node版本,明明提示已经切换成功,可是再次查看node版本还是之前的! 尝试了很多办法,比如重新打开一个cmd窗口、切换前执行nvm install version都没成功,直到找到这篇文章…

基于JSP+Mysql+HTml+Css仓库出入库管理系统设计与实现

博主介绍:黄菊华老师《Vue.js入门与商城开发实战》《微信小程序商城开发》图书作者,CSDN博客专家,在线教育专家,CSDN钻石讲师;专注大学生毕业设计教育和辅导。 所有项目都配有从入门到精通的基础知识视频课程&#xff…

web蓝桥杯真题:年度明星项目

代码及注释: //全部数据 var allData [] // 每次需要加载的数量 var num 15// TODO: 请在此补充代码实现项目数据文件和翻译数据文件的请求功能 $.get({url: ./js/all-data.json}).then(res > {allData resloading(allData, num) //初始加载数据 }) $.get(…

Redis Desktop Manager可视化工具

可视化工具 Redis https://www.alipan.com/s/uHSbg14XmsL 提取码: 38cl 点击链接保存,或者复制本段内容,打开「阿里云盘」APP ,无需下载极速在线查看,视频原画倍速播放。 官网下载(不推荐):http…

【热门话题】 Fiddler:一款强大的Web调试代理工具——安装与使用详解

🌈个人主页: 鑫宝Code 🔥热门专栏: 闲话杂谈| 炫酷HTML | JavaScript基础 ​💫个人格言: "如无必要,勿增实体" 文章目录 Fiddler:一款强大的Web调试代理工具——安装与使用详解一、Fiddler的…

4.9QT

完善对话框,点击登录对话框,如果账号和密码匹配,则弹出信息对话框,给出提示”登录成功“,提供一个Ok按钮,用户点击Ok后,关闭登录界面,跳转到其他界面 如果账号和密码不匹配&#xf…

GitHub 仓库 (repository) Pulse - Contributors - Network

GitHub 仓库 [repository] Pulse - Contributors - Network 1. Pulse2. Contributors3. NetworkReferences 1. Pulse 显示该仓库最近的活动信息。该仓库中的软件是无人问津,还是在火热地开发之中,从这里可以一目了然。 2. Contributors 显示对该仓库进…

Java基础知识总结(46)

(1)构造器 构造器的定义: 需要注意的是构造器是一种特殊的方法,其方法名和类名相同,但没有方法返回值,也不用void修饰。 [修饰符] 方法名(形参列表){ •方法体 •} 修饰符:修饰符可以省略&am…