在使用PyTorch搭建自己的神经网络时,经常需要把torchvision.models中内置的模型作为特征提取的Backbone,然后在此基础上搭建更复杂的网络。
这里以torchvision中内置的ResNet18为例,说明如何将其作为Backbone使用,示例代码如下:
# -*- coding: utf-8 -*-
import torch.nn as nn
import torchvision


class Resnet18Backbone(nn.Module):
    """ResNet-18 feature extractor whose output is the global average pool.

    The input is run through torchvision's ResNet-18 from ``conv1`` up to and
    including ``avgpool``. The classification head ``fc`` is never applied,
    so ``forward`` returns a tensor of shape ``[N, 512, 1, 1]``.
    """

    def __init__(self, pretrained=True):
        """Build the backbone.

        Args:
            pretrained: Forwarded to ``torchvision.models.resnet18``; when True
                the pretrained (ImageNet) weights are loaded. Defaults to True,
                matching the previous hard-coded behaviour.
        """
        super(Resnet18Backbone, self).__init__()
        # NOTE: newer torchvision releases deprecate ``pretrained=`` in favour
        # of ``weights=``; it is kept here for compatibility with older versions.
        self.model = torchvision.models.resnet18(pretrained=pretrained)
        # ``fc`` is not used in forward(); replace it with an empty container so
        # the model no longer carries a classifier head.
        self.model.fc = nn.Sequential()

    def forward(self, x):
        """Return the average-pooled ResNet-18 features of ``x``.

        Args:
            x: Image batch, presumably ``[N, 3, H, W]`` — TODO confirm.

        Returns:
            Tensor of shape ``[N, 512, 1, 1]``.
        """
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        x = self.model.avgpool(x)
        return x
使用上述代码时,如果输入Tensor的维度为[1,3,244,244],forward输出的Tensor维度为[1,512,1,1]。这是因为avgpool是自适应平均池化,输出的空间尺寸固定为1×1。如果需要输出维度为[1,512]的Tensor,就要去掉(squeeze)最后两个大小为1的维度,修改后的代码如下:
# -*- coding: utf-8 -*-
import torch.nn as nn
import torchvision


class Resnet18Backbone(nn.Module):
    """ResNet-18 feature extractor that returns a flat feature vector.

    The input is run through torchvision's ResNet-18 from ``conv1`` up to and
    including ``avgpool``. The two trailing 1x1 spatial dimensions are then
    removed, so ``forward`` returns a tensor of shape ``[N, 512]``.
    """

    def __init__(self, pretrained=True):
        """Build the backbone.

        Args:
            pretrained: Forwarded to ``torchvision.models.resnet18``; when True
                the pretrained (ImageNet) weights are loaded. Defaults to True,
                matching the previous hard-coded behaviour.
        """
        super(Resnet18Backbone, self).__init__()
        # NOTE: newer torchvision releases deprecate ``pretrained=`` in favour
        # of ``weights=``; it is kept here for compatibility with older versions.
        self.model = torchvision.models.resnet18(pretrained=pretrained)
        # ``fc`` is not used in forward(); replace it with an empty container so
        # the model no longer carries a classifier head.
        self.model.fc = nn.Sequential()

    def forward(self, x):
        """Return the flattened, average-pooled ResNet-18 features of ``x``.

        Args:
            x: Image batch, presumably ``[N, 3, H, W]`` — TODO confirm.

        Returns:
            Tensor of shape ``[N, 512]``.
        """
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        x = self.model.avgpool(x)
        # avgpool yields [N, 512, 1, 1]; flattening from dim 1 gives [N, 512]
        # (same result as x.squeeze(2).squeeze(2), and it never touches the
        # batch dimension, even when N == 1).
        x = x.flatten(1)
        return x
好了,上述代码中的Resnet18Backbone可以作为网络中的一层使用。以上示例都以ResNet的Adaptive Average Pooling层作为backbone的输出层。下面给出适用于多种ResNet的通用写法,并说明如果只需要前面的卷积层作为输出层该如何修改。
比如,如果要以ResNet系列网络的Adaptive Average Pooling层作为backbone的输出层(即只去掉最后的全连接层fc),可以这样写:
# backbone
# Build a ResNet-family feature extractor selected by ``backbone_name``
# (defined elsewhere). Every child module except the final ``fc`` classifier
# is kept, so the backbone ends with the global average pool, and
# ``backbone.out_channels`` records the channel count of its output.
#
# Each entry maps a name to (zero-argument constructor, output channels). The
# constructors are lambdas so a model is only looked up and built when chosen.
_avgpool_backbone_specs = {
    'resnet_18': (lambda: torchvision.models.resnet18(pretrained=True), 512),
    'resnet_34': (lambda: torchvision.models.resnet34(pretrained=True), 512),
    'resnet_50': (lambda: torchvision.models.resnet50(pretrained=True), 2048),
    'resnet_101': (lambda: torchvision.models.resnet101(pretrained=True), 2048),
    'resnet_152': (lambda: torchvision.models.resnet152(pretrained=True), 2048),
    # NOTE(review): unqualified ``resnet50`` is not torchvision's; presumably a
    # project-local ResNet-50 variant with modified stride imported elsewhere
    # — confirm.
    'resnet_50_modified_stride_1': (lambda: resnet50(pretrained=True), 2048),
    'resnext101_32x8d': (
        lambda: torchvision.models.resnext101_32x8d(pretrained=True), 2048),
}
# Fail loudly on an unknown name instead of silently leaving ``backbone`` unset.
if backbone_name not in _avgpool_backbone_specs:
    raise ValueError('unsupported backbone_name: %r' % (backbone_name,))
_build_resnet, _out_channels = _avgpool_backbone_specs[backbone_name]
resnet_net = _build_resnet()
modules = list(resnet_net.children())[:-1]  # drop only the final fc layer
backbone = nn.Sequential(*modules)
backbone.out_channels = _out_channels
如果只需要前面的卷积层作为backbone(即同时去掉最后的avgpool层和fc层),可以这样写:
# backbone
# Build a convolution-only ResNet-family feature extractor selected by
# ``backbone_name`` (defined elsewhere). The last two children (global
# ``avgpool`` and ``fc``) are dropped, so the backbone outputs the spatial
# feature map of the final conv stage (``layer4``).
#
# Each entry maps a name to (zero-argument constructor, output channels). The
# constructors are lambdas so a model is only looked up and built when chosen.
_conv_backbone_specs = {
    'resnet_18': (lambda: torchvision.models.resnet18(pretrained=True), 512),
    'resnet_34': (lambda: torchvision.models.resnet34(pretrained=True), 512),
    'resnet_50': (lambda: torchvision.models.resnet50(pretrained=True), 2048),
    'resnet_101': (lambda: torchvision.models.resnet101(pretrained=True), 2048),
    'resnet_152': (lambda: torchvision.models.resnet152(pretrained=True), 2048),
    # NOTE(review): unqualified ``resnet50`` is not torchvision's; presumably a
    # project-local ResNet-50 variant with modified stride imported elsewhere
    # — confirm.
    'resnet_50_modified_stride_1': (lambda: resnet50(pretrained=True), 2048),
    'resnext101_32x8d': (
        lambda: torchvision.models.resnext101_32x8d(pretrained=True), 2048),
}
# Fail loudly on an unknown name instead of silently leaving ``backbone`` unset.
if backbone_name not in _conv_backbone_specs:
    raise ValueError('unsupported backbone_name: %r' % (backbone_name,))
_build_resnet, _out_channels = _conv_backbone_specs[backbone_name]
resnet_net = _build_resnet()
modules = list(resnet_net.children())[:-2]  # drop avgpool and fc
backbone = nn.Sequential(*modules)
# torchvision detection models such as FasterRCNN read ``out_channels`` from
# the backbone; layer4's channel count equals the pooled-output count above.
backbone.out_channels = _out_channels
参考链接
- https://stackoverflow.com/questions/58362892/resnet-18-as-backbone-in-faster-r-cnn
有兴趣可以访问我的个人站:https://www.stubbornhuang.com/