pytorch backbone

ops/2024/9/25 21:26:45/

1 简介

在PyTorch深度学习中,预训练backbone(骨干网络)是一个常见的做法,特别是在处理图像识别、目标检测、图像分割等任务时。预训练backbone通常是指在大型数据集(如ImageNet)上预先训练好的卷积神经网络(CNN)模型,这些模型能够提取图像中的通用特征,这些特征在多种任务中都是有用的。

1. 常见的预训练Backbone

以下是一些在PyTorch中常用的预训练backbone:

  • ResNet:由何恺明等人提出的深度残差网络,通过引入残差连接解决了深层网络训练中的梯度消失或梯度爆炸问题。ResNet系列包括ResNet-18、ResNet-34、ResNet-50、ResNet-101、ResNet-152等,数字表示网络的层数。
  • VGG:由牛津大学的Visual Geometry Group提出,特点是使用了多个小卷积核(如3x3)的卷积层和池化层来构建深层网络。VGG系列包括VGG16、VGG19等。
  • MobileNet:专为移动和嵌入式设备设计的轻量级网络,通过深度可分离卷积减少了计算量和模型大小。
  • DenseNet:通过密集连接(dense connections)提高了信息流动和梯度传播效率,进一步增强了特征重用。
  • EfficientNet:通过同时缩放网络的深度、宽度和分辨率来优化网络,实现了在保持模型效率的同时提高准确率。

2. 如何使用预训练Backbone

在PyTorch中,使用预训练backbone通常涉及以下几个步骤:

  1. 导入模型:使用PyTorch的torchvision.models模块导入所需的预训练模型。

    python">import torchvision.models as models  # 导入预训练的ResNet50模型  
    resnet50 = models.resnet50(pretrained=True)
    print(resnet50)
  2. 修改模型:根据需要修改模型的最后几层以适应特定的任务(如分类任务中的类别数)。

    python"># 假设我们有一个100类的分类任务  
    num_ftrs = resnet50.fc.in_features  
    resnet50.fc = torch.nn.Linear(num_ftrs, 100)
  3. 冻结backbone:在训练时,可以选择冻结backbone的参数,只训练新添加的层(如分类层),这有助于加快训练速度并防止过拟合。

    python">for param in resnet50.parameters():  param.requires_grad = False  # 只对新添加的层设置requires_grad=True  
    resnet50.fc.parameters().requires_grad = True
  4. 训练模型:使用适当的数据集和训练策略来训练模型。

  5. 评估模型:在测试集上评估模型的性能。

3. 注意事项

  • 使用预训练权重时,应确保输入图像的预处理(如大小调整、归一化等)与预训练时使用的预处理一致。
  • 冻结backbone时,应确保模型的其余部分(如新添加的层)有足够的容量来学习任务特定的特征。
  • 在某些情况下,解冻backbone的一部分或全部并在目标数据集上进行微调可能会获得更好的性能。

通过以上步骤,可以在PyTorch中有效地利用预训练backbone来解决各种计算机视觉任务。

2 查看模型源码

想查看models.resnet50的源码,可以点击查看pytorch中的官方注释,可以看到源码链接为

vision/torchvision/models/resnet.py at main · pytorch/vision · GitHub

这样就可以看到 class ResNet(nn.Module) 的定义

3 查看权重参数

在PyTorch中,查看深度学习预训练backbone的权重参数可以通过几种方法实现。以下是一些常用的步骤和方法:

1. 加载预训练模型

首先,你需要使用torchvision.models模块加载所需的预训练模型。例如,加载一个预训练的ResNet50模型:

python">import torchvision.models as models  # 加载预训练的ResNet50模型  
resnet50 = models.resnet50(pretrained=True)

2. 查看模型参数

方法一:使用model.parameters()

model.parameters()方法返回一个生成器,包含模型的所有参数(权重和偏置)。但是,这个方法不会直接显示参数的名称,只适合在训练循环中迭代参数。

方法二:使用model.named_parameters()

model.named_parameters()方法返回一个生成器,其中每个元素都是一个包含参数名称和参数本身的元组。这是查看模型每层权重参数及其名称的最直接方法。

python">for name, param in resnet50.named_parameters():  print(name, param.size())

这段代码会遍历模型的所有参数,并打印出每个参数的名称和尺寸。

3. 专注于特定层的参数

如果你只对backbone中的特定层感兴趣,可以进一步筛选named_parameters()的输出。例如,如果你想看ResNet50中第一个卷积层的参数:

python">for name, param in resnet50.named_parameters():  if 'conv1' in name:  print(name, param.size())

4. 注意事项

  • 当查看模型参数时,请确保你了解模型的架构,以便正确地解释参数的名称和尺寸。
  • 预训练模型的权重是在特定数据集(如ImageNet)上训练的,因此这些权重可能对你的特定任务有所帮助,但也可能需要进一步的微调。
  • 如果你的模型是基于预训练模型进行修改的(例如,更改了最后一层以匹配不同的类别数),请确保你理解这些修改如何影响模型的参数。

5. 示例输出

运行上述代码(针对ResNet50的named_parameters())将输出类似以下的信息(输出将非常长,这里只展示部分):

python">conv1.weight torch.Size([64, 3, 7, 7])  
conv1.bias torch.Size([64])  
bn1.weight torch.Size([64])  
bn1.bias torch.Size([64])  
bn1.running_mean torch.Size([64])  
bn1.running_var torch.Size([64])  
...

这表示conv1层有一个权重参数(大小为[64, 3, 7, 7])和一个偏置参数(大小为[64]),以及对应的批量归一化层的权重、偏置、运行均值和运行方差等参数。

4 常见bakcbone以及适用业务

在PyTorch中,预训练的backbone模型是深度学习领域中的重要组成部分,它们为各种任务提供了强大的特征提取能力。然而,由于PyTorch本身是一个灵活的深度学习框架,它并不直接提供所有可能的预训练backbone模型,而是由社区和研究者基于PyTorch框架实现并分享。以下是一些常见的PyTorch预训练backbone模型,以及它们的优劣和适用场景:

1. ResNet(残差网络)

优势

  • 引入了残差连接,解决了深层网络训练中的梯度消失或梯度爆炸问题。
  • 在多个计算机视觉任务中表现出色,如图像分类、目标检测等。

劣势

  • 对于某些特定任务,可能不是最优选择,需要根据任务特点进行调整。

适用场景

  • 图像分类、目标检测、语义分割等。

2. VGG

优势

  • 结构简单明了,易于理解和实现。
  • 在多个基准数据集上取得了良好的性能。

劣势

  • 参数量较大,计算成本较高。

适用场景

  • 早期深度学习研究和教学。

3. MobileNet

优势

  • 专为移动和嵌入式设备设计,具有较小的模型大小和较快的推理速度。
  • 采用了深度可分离卷积等技术,减少了计算量和参数量。

劣势

  • 相比于其他大型模型,可能在某些复杂任务上的精度稍低。

适用场景

  • 移动应用、嵌入式设备上的实时图像处理和分类。

4. DenseNet(密集连接网络)

优势

  • 每一层都直接与后面的所有层相连,增强了特征传播和复用。
  • 在多个数据集上取得了比ResNet更好的性能。

劣势

  • 参数量和计算量相对较大。

适用场景

  • 需要高精度和强特征表达能力的任务,如医学图像分析。

5. EfficientNet

优势

  • 通过复合缩放方法(compound scaling)平衡了网络的深度、宽度和分辨率,实现了在有限资源下的最佳性能。
  • 在多个计算机视觉任务中取得了SOTA(state-of-the-art)性能。

劣势

  • 需要根据具体任务进行微调以获得最佳性能。

适用场景

  • 追求极致性能的计算机视觉任务,如大规模图像分类和检测。

6. YOLOv5的Backbone(如CSPDarknet)

优势

  • 专为目标检测任务设计,具有较快的推理速度和较高的检测精度。
  • 采用了CSPNet等结构,进一步提升了网络性能。

劣势

  • 相比于专门的分类网络,可能在分类任务上的性能稍逊。

适用场景

  • 实时目标检测任务,如自动驾驶、视频监控等。

请注意,以上列举的backbone模型并不全面,PyTorch社区和研究者们不断在推出新的模型和架构。此外,每种模型都有其特定的优势和劣势,以及适用的场景。在选择模型时,需要根据具体任务的需求、计算资源等因素进行综合考虑。

对于PyTorch中预训练backbone模型的获取,可以通过PyTorch的官方模型库(如torchvision)或第三方库(如timmpretrainedmodels等)来获取。这些库提供了大量预训练的backbone模型,并支持多种加载和使用方式。

5 从backbone提取特征图(☆)

python">import torch
import torch.nn as nn
import torchvision.models as models
from collections import OrderedDictclass ResNet18(nn.Module):def __init__(self):super().__init__()self.resnet18 = models.resnet18(pretrained=True)def forward(self, x):features = OrderedDict()x = self.resnet18.conv1(x)x = self.resnet18.bn1(x)x = self.resnet18.relu(x)x = self.resnet18.maxpool(x)features['3'] = xx = self.resnet18.layer1(x)x = self.resnet18.layer2(x)features['2'] = xx = self.resnet18.layer3(x)features['1'] = xx = self.resnet18.layer4(x)features['0'] = xreturn featuresmodel = ResNet18()
input = torch.ones(1, 3, 640, 640)  # NCHW
y = model(input)
for key, value in y.items():print(key, value.shape)

打印信息

3 torch.Size([1, 64, 160, 160])
2 torch.Size([1, 128, 80, 80])
1 torch.Size([1, 256, 40, 40])
0 torch.Size([1, 512, 20, 20])


http://www.ppmy.cn/ops/85994.html

相关文章

后端开发工程师vue2初识的学习

博客主页:音符犹如代码系列专栏:JavaWeb关注博主,后期持续更新系列文章如果有错误感谢请大家批评指出,及时修改感谢大家点赞👍收藏⭐评论✍ 什么是Vue? Vue (通常指 Vue.js)是一个用…

【概率论】第一章:概率论基本概念

文章目录 一. 随机事件与空间样本二. 事件间的关系与事件的运算三. 概率、条件概率、事件独立性与五大公式1. 概率2. 条件概率3. 事件独立性4. 五大公式 四. 古典型、几何型概率、伯努利试验 确定现象:磁极同性相斥 随机现象:在单次实验结果中呈现出不确…

Java 不可变Map练习 (2024.7.28)

CollectionExercise3 package CollectionExercise20240728;import java.util.HashMap; import java.util.Map; import java.util.Set;public class CollectionExercise3 {public static void main(String[] args) {// 不可变的Map集合// Map中键是不可以重复的// Map中的of方法…

GPT-4o Mini 模型的性能与成本优势全解析

GPT-4o Mini 模型的性能与成本优势全解析 📈 🌟 GPT-4o Mini 模型的性能与成本优势全解析 📈摘要引言正文内容GPT-4o Mini 模型简介 🚀性能测试与对比 📊应用场景 🌐自然语言处理对话系统内容生成 ✍️ &am…

【React】全面解析:从基础知识到高级应用,掌握现代Web开发利器

文章目录 一、React 的基础知识1. 什么是 React?2. React 的基本概念3. 基本示例 二、React 的进阶概念1. 状态(State)和属性(Props)2. 生命周期方法(Lifecycle Methods)3. 钩子(Hoo…

【SQL 新手教程 4/20】关系模型 --索引

💗 关系数据库建立在关系模型上⭐ 关系模型本质上就是若干个存储数据的二维表 记录 (Record): 表的每一行称为记录(Record),记录是一个逻辑意义上的数据 字段 (Column):表的每一列称为字段(Colu…

spring IOC DI -- IOC详解

T04BF 👋专栏: 算法|JAVA|MySQL|C语言 🫵 今天你敲代码了吗 文章目录 4.2 Ioc 详解4.2.1 Bean的存储Controller(控制器存储)Service (服务存储)Repository(仓库存储)Component(组件存储)Configuration(配置存储) 4.2.2 为什么需要这么多类注解?4.2.3方法…

Mac安装Hoomebrew与升级Python版本

参考 mac 安装HomeBrew(100%成功)_mac安装homebrew-CSDN博客 /bin/zsh -c "$(curl -fsSL https://gitee.com/cunkai/HomebrewCN/raw/master/Homebrew.sh)" 安装了Python 3.x版本,你可以使用以下命令来设置默认的Python版本: # 首先找到新安…