Pytorch - 使用pytorch自带的Resnet作为网络的backbone

news/2024/11/24 9:40:19/

在使用Pytorch搭建自己的神经网络框架时,经常需要使用Pytorch中内置的torchvision.models中的模型作为特征提取的Backbone,然后再在这个基础上进行更加复杂的网络搭建。

在这里以使用Pytorch中内置的Resnet18为例,如何作为Backbone层进行使用,看以下示例代码

# -*- coding: utf-8 -*-
import torch.nn as nn
import torchvisionclass Resnet18Backbone(nn.Module):def __init__(self):super(Resnet18Backbone, self).__init__()self.model = torchvision.models.resnet18(pretrained=True)self.model.fc = nn.Sequential()def forward(self, x):x = self.model.conv1(x)x = self.model.bn1(x)x = self.model.relu(x)x = self.model.maxpool(x)x = self.model.layer1(x)x = self.model.layer2(x)x = self.model.layer3(x)x = self.model.layer4(x)x = self.model.avgpool(x)return x

使用上述代码,如果输入Tensor的维度为[1,3,244,244],fowward输出的Tensor的维度为[1,512,1,1],如果我们需要输出的Tensor维度为[1,512],需要squeeze相应的维度,修改后的代码如下

# -*- coding: utf-8 -*-
import torch.nn as nn
import torchvisionclass Resnet18Backbone(nn.Module):def __init__(self):super(Resnet18Backbone, self).__init__()self.model = torchvision.models.resnet18(pretrained=True)self.model.fc = nn.Sequential()def forward(self, x):x = self.model.conv1(x)x = self.model.bn1(x)x = self.model.relu(x)x = self.model.maxpool(x)x = self.model.layer1(x)x = self.model.layer2(x)x = self.model.layer3(x)x = self.model.layer4(x)x = self.model.avgpool(x)x = x.squeeze(2).squeeze(2)return x

好了,上述代码的Resnet18Backbone可以作为网络中的一层进行使用,这里都是以ResNet的Adaptive Average Pooling层作为backbone的输出层,如果我们仅仅需要前面的卷积层作为输出层,可以参考以下代码。

比如,如果我们要使用ResNet18的Adaptive Average Pooling作为backbone的输出层,我们可以这样写,

# backboneif backbone_name == 'resnet_18':resnet_net = torchvision.models.resnet18(pretrained=True)modules = list(resnet_net.children())[:-1]backbone = nn.Sequential(*modules)backbone.out_channels = 512elif backbone_name == 'resnet_34':resnet_net = torchvision.models.resnet34(pretrained=True)modules = list(resnet_net.children())[:-1]backbone = nn.Sequential(*modules)backbone.out_channels = 512elif backbone_name == 'resnet_50':resnet_net = torchvision.models.resnet50(pretrained=True)modules = list(resnet_net.children())[:-1]backbone = nn.Sequential(*modules)backbone.out_channels = 2048elif backbone_name == 'resnet_101':resnet_net = torchvision.models.resnet101(pretrained=True)modules = list(resnet_net.children())[:-1]backbone = nn.Sequential(*modules)backbone.out_channels = 2048elif backbone_name == 'resnet_152':resnet_net = torchvision.models.resnet152(pretrained=True)modules = list(resnet_net.children())[:-1]backbone = nn.Sequential(*modules)backbone.out_channels = 2048elif backbone_name == 'resnet_50_modified_stride_1':resnet_net = resnet50(pretrained=True)modules = list(resnet_net.children())[:-1]backbone = nn.Sequential(*modules)backbone.out_channels = 2048elif backbone_name == 'resnext101_32x8d':resnet_net = torchvision.models.resnext101_32x8d(pretrained=True)modules = list(resnet_net.children())[:-1]backbone = nn.Sequential(*modules)backbone.out_channels = 2048

如果我们仅仅只是需要前面的卷积层作为backbone,我们可以这样写

# backboneif backbone_name == 'resnet_18':resnet_net = torchvision.models.resnet18(pretrained=True)modules = list(resnet_net.children())[:-2]backbone = nn.Sequential(*modules)elif backbone_name == 'resnet_34':resnet_net = torchvision.models.resnet34(pretrained=True)modules = list(resnet_net.children())[:-2]backbone = nn.Sequential(*modules)elif backbone_name == 'resnet_50':resnet_net = torchvision.models.resnet50(pretrained=True)modules = list(resnet_net.children())[:-2]backbone = nn.Sequential(*modules)elif backbone_name == 'resnet_101':resnet_net = torchvision.models.resnet101(pretrained=True)modules = list(resnet_net.children())[:-2]backbone = nn.Sequential(*modules)elif backbone_name == 'resnet_152':resnet_net = torchvision.models.resnet152(pretrained=True)modules = list(resnet_net.children())[:-2]backbone = nn.Sequential(*modules)elif backbone_name == 'resnet_50_modified_stride_1':resnet_net = resnet50(pretrained=True)modules = list(resnet_net.children())[:-2]backbone = nn.Sequential(*modules)elif backbone_name == 'resnext101_32x8d':resnet_net = torchvision.models.resnext101_32x8d(pretrained=True)modules = list(resnet_net.children())[:-2]backbone = nn.Sequential(*modules)

参考链接

  • https://stackoverflow.com/questions/58362892/resnet-18-as-backbone-in-faster-r-cnn

有兴趣可以访问我的个人站:https://www.stubbornhuang.com/


http://www.ppmy.cn/news/10867.html

相关文章

环形缓冲区

文章目录一. 什么是环形缓冲区?二、实现环形缓冲区:三、环形缓冲区示例代码:总结一. 什么是环形缓冲区? 环形缓冲区 是一段 先进先出 的循环缓冲区,有一定的大小,我们可以把它抽象理解为一块环形的内存。 …

2023年底,我要通过这5点,实现博客访问量500W

说实话,这真的是一个非常高远的flag,因为我目前只有35W,但根据我2个月前还是12W的访问量,我觉得我还是可以拼一把的,在这里我想向大家分享一下我的计划,如何达成2023年底,博客访问量达到500W的K…

java 微服务之docker基础入门 docker部署 镜像相关命令 容器命令 数据卷 DockerCompose Docker镜像仓库

初识Docker 项目部署的问题 什么是Docker 不同环境的操作系统不同,Docker如何解决?我们先来了解下操作系统结构 Docker与虚拟机 虚拟机是在一个系统内,运行另外一个系统 镜像和容器 镜像(Image):Docker将…

通讯电平转换电路中的经典设计

今天给大家分享几个通讯电平转换电路。 有初学者问:什么是电平转换?举个例子,比如下面这个电路: 单片机的工作电压是5V,蓝牙模块的工作电压是3.3V,两者之间要进行通讯,TXD和RXD引脚就要进行连接…

ClickHouse 挺快,esProc SPL 更快

开源分析数据库ClickHouse以快著称,真的如此吗?我们通过对比测试来验证一下。 ClickHouse vs Oracle 先用ClickHouse(简称CH)、Oracle数据库(简称ORA)一起在相同的软硬件环境下做对比测试。测试基准使用国…

Spring Boot 3.0横空出世,快来看看是不是该升级了

文章目录简介对JAVA17和JAVA19的支持recordText BlocksSwitch Expressionsinstanceof模式匹配Sealed Classes and Interfaces迁移到Jakarta EEGraalVM Native Image Support对Micrometer的支持其他的一些改动简介 Spring boot 3.0于2022年11月正式发布了,这次的发布…

设计模式之职责链模式

设计模式之职责链模式 1)职责链模式(Chain Of Responsibility Pattern),又叫责任链模式,为请求创建了一个接受者对象的链。这种模式对请求的发送者和接收者进行解耦。 2)职责链模式通常每个接收者都包含另…

php宝塔搭建部署实战响应式园林景观设计公司网站系统源码

大家好啊,我是测评君,欢迎来到web测评。 本期给大家带来一套php开发的响应式园林景观设计公司网站系统源码,感兴趣的朋友可以自行下载学习。 技术架构 PHP7.2 nginx mysql5.7 JS CSS HTMLcnetos7以上 宝塔面板 文字搭建教程 下载源码…