PyTorch学习笔记(十三)——现有网络模型的使用及修改

news/2024/11/24 4:27:22/

 以分类模型的VGG为例

 

vgg16_false = torchvision.models.vgg16(weights=False)
vgg16_true = torchvision.models.vgg16(weights=True)
  • 设置为 False 的情况,相当于网络模型中的参数都是初始化的、默认的
  • 设置为 True 时,网络模型中的参数在数据集上是训练好的,能达到比较好的效果
print(vgg16_true)
VGG((features): Sequential(
# 输入图片先经过卷积,输入是3通道的、输出是64通道的,卷积核大小是3×3的(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# 非线性(1): ReLU(inplace=True)
# 卷积、非线性、池化...(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace=True)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace=True)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace=True)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace=True)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace=True)(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(18): ReLU(inplace=True)(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace=True)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace=True)(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(25): ReLU(inplace=True)(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(27): ReLU(inplace=True)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace=True)(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace=True)(2): Dropout(p=0.5, inplace=False)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace=True)(5): Dropout(p=0.5, inplace=False)
# 最后线性层输出为1000(vgg16也是一个分类模型,能分出1000个类别)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

CIFAR10 把数据分成了10类,而 vgg16 模型把数据分成了 1000 类,如何应用这个网络模型呢?

  • 方法1:把最后线性层的 out_features 从1000改为10
  • 方法2:在最后的线性层下面再加一层,in_features为1000,out_features为10

利用现有网络去改动它的结构,避免写 vgg16。很多框架会把 vgg16 当做前置的网络结构,提取一些特殊的特征,再在后面加一些网络结构,实现功能。

方法2:添加

vgg16_true.classifier.add_module("add_linear",nn.Linear(1000,10))
print(vgg16_true)

 方法1:修改

vgg16_false.classifier[6] = nn.Linear(4096,10)
print(vgg16_false)


http://www.ppmy.cn/news/1043186.html

相关文章

高品质音乐下载命令行工具Musicn

又到了小苏同学的生日🎂,宝贝,生日快乐!祝永远健康、快乐、心想事成! 什么是 Musicn ? Musicn 是一个可播放及下载高品质🎵音乐🎵的命令行工具。支持咪咕、酷我、酷狗和网易云的服务…

C的进阶C++学习方向

(꒪ꇴ꒪ ),Hello我是祐言QAQ我的博客主页:C/C语言,Linux基础,ARM开发板,软件配置等领域博主🌍快上🚘,一起学习,让我们成为一个强大的攻城狮!送给自己和读者的…

最小化安装移动云大云操作系统--BCLinux-R8-U8-Server-x86_64-230802版

CentOS 结束技术支持,转为RHEL的前置stream版本后,国内开源Linux服务器OS生态转向了开源龙蜥和开源欧拉两大开源社区,对应衍生出了一系列商用Linux服务器系统。BC-Linux V8.8是中国移动基于龙蜥社区Anolis OS 8.8版本深度定制的企业级X86服务…

管理类联考——逻辑——真题篇——按知识分类——汇总篇——二、论证逻辑——削弱—— 第二节 因果判断

文章目录 第二节 削弱-题型2-因果判断-论证为事实-对于两个既定事实A和B,判断A是B的原因的过程称为因果判断题-削弱质疑-事实-因果判断-分类1-因果倒置真题(2010-34)-削弱质疑-事实-因果判断-分类1-因果倒置真题(2020-27)-削弱质疑-事实-因果判断-分类1-因果倒置-完全推翻…

判断平面中两射线是否相交的高效方法

1. 简介 最近在工作中遇到判断平面内两射线是否相交的问题。 对于这个问题的解决,常规的方法是将两条射线拓展为直线,计算直线的交点,而后判断交点是否在射线上。 这种方法,在思路上较为直观,也易于理解。然后,该方法在计算量上相对较大。对于少量射线间的交点计算尚可…

stm32控制蜂鸣器源代码(附带proteus线路图)

说明: 1 PB0输出0时,蜂鸣器发生; 2 蜂鸣器电阻值如果太大会导致电流太小,发不出声音; 3蜂鸣器额定电压需要设置得低一点,可以是2V,但不能高于3V,这更右上角的电阻值有关系&#x…

【C++深入浅出】初识C++中篇(引用、内联函数)

目录 一. 前言 二. 引用 2.1 引用的概念 2.2 引用的使用 2.3 引用的特性 2.4 常引用 2.5 引用的使用场景 2.6 传值、传引用效率比较 2.7 引用和指针的区别 三. 内联函数 3.1 内联函数的概念 3.2 内联函数的特性 一. 前言 上期说道,C是在C的基础之上&…

ELKstack-日志收集案例

由于实验环境限制,将 filebeat 和 logstash 部署在 tomcat-server-nodeX,将 redis 和 写 ES 集群的 logstash 部署在 redis-server,将 HAproxy 和 Keepalived 部署在 tomcat-server-nodeX。将 Kibana 部署在 ES 集群主机。 环境:…