SCTNet: 单分支 CNN 与 Transformer 语义信息用于实时分割

ops/2025/3/4 13:31:43/

📑 SCTNet: 单分支 CNN 与 Transformer 语义信息用于实时分割

在这里插入图片描述
在这里插入图片描述


1. 摘要翻译

近年来,许多实时语义分割方法采用额外的语义分支来获取丰富的长距离上下文信息。然而,这种额外分支带来了额外的计算开销,降低了推理速度。为了解决这一问题,本文提出了 SCTNet,一种融合 Transformer 语义信息的单分支 CNN,用于实时分割。SCTNet 结合了无推理负担的 Transformer 语义分支轻量级 CNN 单分支的高效性

该方法在训练时利用 Transformer 作为仅训练时可用的语义分支,借助CFBlock(ConvFormer Block)语义信息对齐模块(SIAM),实现从 Transformer 分支向 CNN 语义信息的高效迁移。在推理阶段,仅需部署单一 CNN 分支,从而保持高效推理能力。

实验表明,SCTNet 在 Cityscapes、ADE20K 和 COCO-Stuff10K 数据集上达到了新的最先进(state-of-the-art)性能,既提升了分割精度,又实现了高推理速度。📊 SCTNet 代码与模型地址:GitHub - SCTNet

[论文英文原名称]: SCTNet: Single-Branch CNN with Transformer Semantic Information for Real-Time Segmentation
[论文中文名称]: SCTNet: 结合 Transformer 语义信息的单分支 CNN 进行实时分割
[论文链接]: 2312.17071v2.pdf


2. 问题背景

2.1 语义分割的重要性

语义分割是一项计算机视觉的基础任务,目标是为图像中的每个像素分配语义类别标签。它广泛应用于自动驾驶🚗、医学影像分析🩺、移动应用📱等领域。

2.2 现有方法的局限性

目前的语义分割方法倾向于增加上下文信息来提升精度,常见的方法包括:

  • 大感受野(DeepLab系列)
  • 多尺度特征融合(PSPNet、U-Net)
  • 自注意力机制(Self-Attention)(Transformer-based 方法)

尽管这些方法显著提升了语义分割的性能,但它们通常导致 高计算成本,尤其是基于 Transformer 的方法计算复杂度往往随图像分辨率呈平方增长,严重影响推理速度❌。


3. 核心概念

3.1 SCTNet 方法概述

SCTNet 提出了新颖的单分支网络结构,结合了 Transformer 长距离上下文感知能力与 CNN 高效推理特性,从而实现两者的优势互补。

主要创新点:

  1. 训练阶段: 采用 Transformer 语义分支 来增强语义信息提取能力。
  2. 推理阶段: 仅保留单分支 CNN,实现高效实时推理。
  3. 核心组件:
    • CFBlock(ConvFormer Block): 用卷积模拟 Transformer 的长程建模能力。
    • SIAM(语义信息对齐模块): 解决 CNN 和 Transformer 语义信息的不匹配问题。

📌 示意图:
在这里插入图片描述

  • SCTNet 速度-精度对比图(见 图1 📊)。

在这里插入图片描述

  • 不同网络架构的对比(见 图2 🏗️)。

4. 核心模块的操作步骤

在这里插入图片描述

4.1 训练阶段

1. 采用 Transformer 语义分支

  • 目标:学习长距离上下文信息,提高全局语义理解能力。
  • 训练时,将 Transformer 分支 作为语义信息提取器。

2. 语义信息对齐

  • 使用 CFBlock 使 CNN 具有 Transformer 类似的上下文感知能力。
  • 利用 SIAM 进行特征对齐,减小 CNN 和 Transformer 之间的语义鸿沟。

3. 共享解码头

  • 使 Transformer 语义特征在训练期间能够更好地迁移到 CNN。

4.2 推理阶段

1. 仅使用 CNN 进行推理

  • SCTNet 在推理时仅保留单分支 CNN,无需 Transformer 语义分支,保证了 推理速度最快
  • 在 Cityscapes 数据集上,SCTNet 以 更低的计算量 达到了最先进水平(见 图1 📈)。

5. 文章贡献

本文提出的 SCTNet 主要贡献如下:

  1. 提出 SCTNet 架构
    • 兼具 Transformer 语义提取能力 和 CNN 高效推理能力
  2. 创新性 CFBlock 设计
    • 仅使用 卷积运算 模拟 Transformer 长距离建模能力
  3. 语义信息对齐模块 SIAM
    • 对齐 CNN 和 Transformer 语义特征,确保 CNN 在推理时仍能保持高语义表达能力
  4. 提升实时分割性能
    • Cityscapes、ADE20K 和 COCO-Stuff-10K 数据集上实现 新的最先进水平(见 图1 📊)。

6. 实验结果与应用

6.1 主要实验

  • 在 Cityscapes 数据集上的实验结果
    • SCTNet 以 更快的推理速度(>140 FPS)达到 最优精度(见 图1 📉)。
  • 在 ADE20K 和 COCO-Stuff10K 上的实验
    • 进一步验证了 SCTNet 优越的泛化能力

6.2 实际应用

  1. 自动驾驶
    • 需要高精度、低延迟的分割算法,SCTNet 适用于此任务。
  2. 移动端应用
    • SCTNet 由于 单分支 CNN 结构,适用于 轻量化推理场景

7. 对未来工作的启示

7.1 未来优化方向

  • 提升语义信息提取
    • 研究更高效的 CFBlock 设计,增强 CNN 语义感知能力。
  • 低计算量 Transformer 设计
    • 未来可以设计 更轻量级的 Transformer 结构,进一步减少计算量(见 图2 🏗️)。

7.2 可能的研究方向

  • 扩展 SCTNet 到 3D 语义分割
    • 点云数据(如激光雷达) 任务上测试 SCTNet 的适用性。
  • 结合自适应神经架构搜索(NAS)
    • 通过 NAS 自动优化 SCTNet 的结构,寻找更优的速度-精度平衡点。

8. 核心模块代码

import torch
from torch import nn
import torch.nn.functional as F
from mmengine.model import constant_init, kaiming_init, trunc_normal_init, normal_init
from timm.models.layers import DropPathclass MLP(nn.Module):def __init__(self, in_channels, hidden_channels=None, out_channels=None, drop_rate=0.0):super(MLP, self).__init__()hidden_channels = hidden_channels or in_channelsout_channels = out_channels or in_channelsself.norm = nn.BatchNorm2d(in_channels, eps=1e-06)self.conv1 = nn.Conv2d(in_channels, hidden_channels, 3, 1, 1)self.act = nn.GELU()self.conv2 = nn.Conv2d(hidden_channels, out_channels, 3, 1, 1)self.drop = nn.Dropout(drop_rate)self.apply(self._init_weights)def _init_weights(self, m):if isinstance(m, nn.Linear):trunc_normal_init(m.weight, std=0.02)if m.bias is not None:constant_init(m.bias, val=0)elif isinstance(m, (nn.SyncBatchNorm, nn.BatchNorm2d)):constant_init(m.weight, val=1.0)constant_init(m.bias, val=0)elif isinstance(m, nn.Conv2d):kaiming_init(m.weight)if m.bias is not None:constant_init(m.bias, val=0)def forward(self, x):x = self.norm(x)x = self.conv1(x)x = self.act(x)x = self.drop(x)x = self.conv2(x)x = self.drop(x)return xclass ConvolutionalAttention(nn.Module):"""The ConvolutionalAttention implementationArgs:in_channels (int, optional): The input channels.inter_channels (int, optional): The channels of intermediate feature.out_channels (int, optional): The output channels.num_heads (int, optional): The num of heads in attention. Default: 8"""def __init__(self, in_channels, out_channels, inter_channels, num_heads=8):super(ConvolutionalAttention, self).__init__()assert (out_channels % num_heads == 0), "out_channels ({}) should be be a multiple of num_heads ({})".format(out_channels, num_heads)self.in_channels = in_channelsself.out_channels = out_channelsself.inter_channels = inter_channelsself.num_heads = num_headsself.norm = nn.BatchNorm2d(in_channels, eps=1e-06)self.kv = nn.Parameter(torch.zeros(inter_channels, in_channels, 7, 1))self.kv3 = nn.Parameter(torch.zeros(inter_channels, in_channels, 1, 7))trunc_normal_init(self.kv, std=0.001)trunc_normal_init(self.kv3, std=0.001)self.apply(self._init_weights)def _init_weights(self, m):if isinstance(m, nn.Linear):trunc_normal_init(m.weight, std=0.001)if m.bias is not None:constant_init(m.bias, val=0.0)elif isinstance(m, (nn.SyncBatchNorm, nn.BatchNorm2d)):constant_init(m.weight, val=1.0)constant_init(m.bias, val=0.0)elif isinstance(m, nn.Conv2d):trunc_normal_init(m.weight, std=0.001)if m.bias is not None:constant_init(m.bias, val=0.0)def _act_dn(self, x):x_shape = x.shape  # n,c_inter,h,wh, w = x_shape[2], x_shape[3]x = x.reshape([x_shape[0], self.num_heads, self.inter_channels // self.num_heads, -1])  # n,c_inter,h,w -> n,heads,c_inner//heads,hwx = F.softmax(x, dim=3)x = x / (torch.sum(x, dim=2, keepdim=True) + 1e-06)x = x.reshape([x_shape[0], self.inter_channels, h, w])return xdef forward(self, x):"""Args:x (Tensor): The input tensor. (n,c,h,w)cross_k (Tensor, optional): The dims is (n*144, c_in, 1, 1)cross_v (Tensor, optional): The dims is (n*c_in, 144, 1, 1)"""x = self.norm(x)x1 = F.conv2d(x, self.kv, bias=None, stride=1, padding=(3, 0))x1 = self._act_dn(x1)x1 = F.conv2d(x1, self.kv.transpose(1, 0), bias=None, stride=1, padding=(3, 0))x3 = F.conv2d(x, self.kv3, bias=None, stride=1, padding=(0, 3))x3 = self._act_dn(x3)x3 = F.conv2d(x3, self.kv3.transpose(1, 0), bias=None, stride=1, padding=(0, 3))x = x1 + x3return xclass CFBlock(nn.Module):"""The CFBlock implementation based on PaddlePaddle.Args:in_channels (int, optional): The input channels.out_channels (int, optional): The output channels.num_heads (int, optional): The num of heads in attention. Default: 8drop_rate (float, optional): The drop rate in MLP. Default:0.drop_path_rate (float, optional): The drop path rate in CFBlock. Default: 0.2"""def __init__(self, in_channels, out_channels, num_heads=8, drop_rate=0.0, drop_path_rate=0.0):super(CFBlock, self).__init__()in_channels_l = in_channelsout_channels_l = out_channelsself.attn_l = ConvolutionalAttention(in_channels_l, out_channels_l, inter_channels=64, num_heads=num_heads)self.mlp_l = MLP(out_channels_l, drop_rate=drop_rate)self.drop_path = (DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity())def _init_weights_kaiming(self, m):if isinstance(m, nn.Linear):trunc_normal_init(m.weight, std=0.02)if m.bias is not None:constant_init(m.bias, val=0)elif isinstance(m, (nn.SyncBatchNorm, nn.BatchNorm2d)):constant_init(m.weight, val=1.0)constant_init(m.bias, val=0)elif isinstance(m, nn.Conv2d):kaiming_init(m.weight)if m.bias is not None:constant_init(m.bias, val=0)def forward(self, x):x_res = xx = x_res + self.drop_path(self.attn_l(x))x = x + self.drop_path(self.mlp_l(x))return xif __name__ == "__main__":device = torch.device("cuda" if torch.cuda.is_available() else "cpu")input = torch.randn(1, 32, 256, 256).to(device)print(input.shape)cfb = CFBlock(32, 32).to(device)output = cfb(input)print(output.shape)

总结

本文详细介绍了 SCTNet 的创新方法、实验结果及未来研究方向,希望能帮助研究人员进一步理解 SCTNet 在实时语义分割中的应用潜力。如果你对该方法感兴趣,可以查看论文详情:2312.17071v2.pdf 🚀


http://www.ppmy.cn/ops/163054.html

相关文章

linux常见操作命令

在Linux系统(如CentOS)中,有很多常见且实用的操作命令,以下为你分类介绍: 1. 文件和目录操作: - ls :列出目录内容,如 ls -l (长格式显示)、 ls -a &#x…

基于 Ingress-Nginx 实现 mTLS 双向认证

目录 背景描述: TLS 和 MTLS 之间的差异 通过自签名证书启用双向 TLS 1. 生成证书 (1) 生成 CA(根证书颁发机构) (2) 生成 CA(根证书颁发机构) (3) 生成客户端证书 2. 在 Kubernetes 中配置 mTLS &#x…

校园快递平台系统(小程序论文源码调试讲解)

第4章 系统设计 用户对着浏览器操作,肯定会出现某些不可预料的问题,但是不代表着系统对于用户在浏览器上的操作不进行处理,所以说,要提前考虑可能会出现的问题。 4.1 系统设计思想 系统设计,肯定要把设计的思想进行统…

【开源-常用C/C++命令行解析库对比】

以下是几种常用的C/C命令行解析库的对比表格,以及它们的GitHub开源库地址: 库名称语言特点是否支持子命令是否支持配置文件是否支持自动生成帮助信息GitHub地址ClaraC11及以上单一头文件,轻量级,非异常错误处理,自动类…

HTML元素,标签到底指的哪块部分?单双标签何时使用?

1. 标签&#xff08;Tag&#xff09; vs 元素&#xff08;Element&#xff09; 标签&#xff08;Tag&#xff09; 标签是 HTML 中用于定义元素的符号&#xff0c;用尖括号 < > 包裹。例如 <img> 是标签。元素&#xff08;Element&#xff09; 元素是由 标签 内容…

芯麦GC1277:电脑散热风扇驱动芯片的优质之选 并可替代传统的0CH477/灿瑞芯片。

在电脑散热风扇、小型电机驱动等场景中&#xff0c;驱动芯片的选型直接影响系统效率、噪音控制及长期可靠性。灿瑞的0CH477曾是市场主流方案&#xff0c;但随着国产芯片技术的成熟&#xff0c;芯麦半导体推出的GC1277凭借更优的驱动性能、智能化保护机制及成本优势&#xff0c;…

【uniapp原生】实时记录接口请求延迟,并生成写入文件到安卓设备

在开发实时数据监控应用时&#xff0c;记录接口请求的延迟对于性能分析和用户体验优化至关重要。本文将基于 UniApp 框架&#xff0c;介绍如何实现一个实时记录接口请求延迟的功能&#xff0c;并深入解析相关代码的实现细节。 前期准备&必要的理解 1. 功能概述 该功能的…

git的恢复命令

右键查看 找到版本的 提交ID git reset --soft c097b534188163194fa0e00a20d9e0f07ad82549