📑 SCTNet: 单分支 CNN 与 Transformer 语义信息用于实时分割
1. 摘要翻译
近年来,许多实时语义分割方法采用额外的语义分支来获取丰富的长距离上下文信息。然而,这种额外分支带来了额外的计算开销,降低了推理速度。为了解决这一问题,本文提出了 SCTNet,一种融合 Transformer 语义信息的单分支 CNN,用于实时分割。SCTNet 结合了无推理负担的 Transformer 语义分支和轻量级 CNN 单分支的高效性。
该方法在训练时利用 Transformer 作为仅训练时可用的语义分支,借助CFBlock(ConvFormer Block)和语义信息对齐模块(SIAM),实现从 Transformer 分支向 CNN 语义信息的高效迁移。在推理阶段,仅需部署单一 CNN 分支,从而保持高效推理能力。
实验表明,SCTNet 在 Cityscapes、ADE20K 和 COCO-Stuff10K 数据集上达到了新的最先进(state-of-the-art)性能,既提升了分割精度,又实现了高推理速度。📊 SCTNet 代码与模型地址:GitHub - SCTNet
[论文英文原名称]: SCTNet: Single-Branch CNN with Transformer Semantic Information for Real-Time Segmentation
[论文中文名称]: SCTNet: 结合 Transformer 语义信息的单分支 CNN 进行实时分割
[论文链接]: 2312.17071v2.pdf
2. 问题背景
2.1 语义分割的重要性
语义分割是一项计算机视觉的基础任务,目标是为图像中的每个像素分配语义类别标签。它广泛应用于自动驾驶🚗、医学影像分析🩺、移动应用📱等领域。
2.2 现有方法的局限性
目前的语义分割方法倾向于增加上下文信息来提升精度,常见的方法包括:
- 大感受野(DeepLab系列)
- 多尺度特征融合(PSPNet、U-Net)
- 自注意力机制(Self-Attention)(Transformer-based 方法)
尽管这些方法显著提升了语义分割的性能,但它们通常导致 高计算成本,尤其是基于 Transformer 的方法计算复杂度往往随图像分辨率呈平方增长,严重影响推理速度❌。
3. 核心概念
3.1 SCTNet 方法概述
SCTNet 提出了新颖的单分支网络结构,结合了 Transformer 长距离上下文感知能力与 CNN 高效推理特性,从而实现两者的优势互补。
主要创新点:
- 训练阶段: 采用 Transformer 语义分支 来增强语义信息提取能力。
- 推理阶段: 仅保留单分支 CNN,实现高效实时推理。
- 核心组件:
- CFBlock(ConvFormer Block): 用卷积模拟 Transformer 的长程建模能力。
- SIAM(语义信息对齐模块): 解决 CNN 和 Transformer 语义信息的不匹配问题。
📌 示意图:
- SCTNet 速度-精度对比图(见 图1 📊)。
- 不同网络架构的对比(见 图2 🏗️)。
4. 核心模块的操作步骤
4.1 训练阶段
1. 采用 Transformer 语义分支
- 目标:学习长距离上下文信息,提高全局语义理解能力。
- 训练时,将 Transformer 分支 作为语义信息提取器。
2. 语义信息对齐
- 使用 CFBlock 使 CNN 具有 Transformer 类似的上下文感知能力。
- 利用 SIAM 进行特征对齐,减小 CNN 和 Transformer 之间的语义鸿沟。
3. 共享解码头
- 使 Transformer 语义特征在训练期间能够更好地迁移到 CNN。
4.2 推理阶段
1. 仅使用 CNN 进行推理
- SCTNet 在推理时仅保留单分支 CNN,无需 Transformer 语义分支,保证了 推理速度最快。
- 在 Cityscapes 数据集上,SCTNet 以 更低的计算量 达到了最先进水平(见 图1 📈)。
5. 文章贡献
本文提出的 SCTNet 主要贡献如下:
- 提出 SCTNet 架构
- 兼具 Transformer 语义提取能力 和 CNN 高效推理能力。
- 创新性 CFBlock 设计
- 仅使用 卷积运算 模拟 Transformer 长距离建模能力。
- 语义信息对齐模块 SIAM
- 对齐 CNN 和 Transformer 语义特征,确保 CNN 在推理时仍能保持高语义表达能力。
- 提升实时分割性能
- 在 Cityscapes、ADE20K 和 COCO-Stuff-10K 数据集上实现 新的最先进水平(见 图1 📊)。
6. 实验结果与应用
6.1 主要实验
- 在 Cityscapes 数据集上的实验结果
- SCTNet 以 更快的推理速度(>140 FPS)达到 最优精度(见 图1 📉)。
- 在 ADE20K 和 COCO-Stuff10K 上的实验
- 进一步验证了 SCTNet 优越的泛化能力。
6.2 实际应用
- 自动驾驶
- 需要高精度、低延迟的分割算法,SCTNet 适用于此任务。
- 移动端应用
- SCTNet 由于 单分支 CNN 结构,适用于 轻量化推理场景。
7. 对未来工作的启示
7.1 未来优化方向
- 提升语义信息提取
- 研究更高效的 CFBlock 设计,增强 CNN 语义感知能力。
- 低计算量 Transformer 设计
- 未来可以设计 更轻量级的 Transformer 结构,进一步减少计算量(见 图2 🏗️)。
7.2 可能的研究方向
- 扩展 SCTNet 到 3D 语义分割
- 在 点云数据(如激光雷达) 任务上测试 SCTNet 的适用性。
- 结合自适应神经架构搜索(NAS)
- 通过 NAS 自动优化 SCTNet 的结构,寻找更优的速度-精度平衡点。
8. 核心模块代码
import torch
from torch import nn
import torch.nn.functional as F
from mmengine.model import constant_init, kaiming_init, trunc_normal_init, normal_init
from timm.models.layers import DropPathclass MLP(nn.Module):def __init__(self, in_channels, hidden_channels=None, out_channels=None, drop_rate=0.0):super(MLP, self).__init__()hidden_channels = hidden_channels or in_channelsout_channels = out_channels or in_channelsself.norm = nn.BatchNorm2d(in_channels, eps=1e-06)self.conv1 = nn.Conv2d(in_channels, hidden_channels, 3, 1, 1)self.act = nn.GELU()self.conv2 = nn.Conv2d(hidden_channels, out_channels, 3, 1, 1)self.drop = nn.Dropout(drop_rate)self.apply(self._init_weights)def _init_weights(self, m):if isinstance(m, nn.Linear):trunc_normal_init(m.weight, std=0.02)if m.bias is not None:constant_init(m.bias, val=0)elif isinstance(m, (nn.SyncBatchNorm, nn.BatchNorm2d)):constant_init(m.weight, val=1.0)constant_init(m.bias, val=0)elif isinstance(m, nn.Conv2d):kaiming_init(m.weight)if m.bias is not None:constant_init(m.bias, val=0)def forward(self, x):x = self.norm(x)x = self.conv1(x)x = self.act(x)x = self.drop(x)x = self.conv2(x)x = self.drop(x)return xclass ConvolutionalAttention(nn.Module):"""The ConvolutionalAttention implementationArgs:in_channels (int, optional): The input channels.inter_channels (int, optional): The channels of intermediate feature.out_channels (int, optional): The output channels.num_heads (int, optional): The num of heads in attention. Default: 8"""def __init__(self, in_channels, out_channels, inter_channels, num_heads=8):super(ConvolutionalAttention, self).__init__()assert (out_channels % num_heads == 0), "out_channels ({}) should be be a multiple of num_heads ({})".format(out_channels, num_heads)self.in_channels = in_channelsself.out_channels = out_channelsself.inter_channels = inter_channelsself.num_heads = num_headsself.norm = nn.BatchNorm2d(in_channels, eps=1e-06)self.kv = nn.Parameter(torch.zeros(inter_channels, in_channels, 7, 1))self.kv3 = nn.Parameter(torch.zeros(inter_channels, in_channels, 1, 7))trunc_normal_init(self.kv, std=0.001)trunc_normal_init(self.kv3, std=0.001)self.apply(self._init_weights)def _init_weights(self, m):if isinstance(m, nn.Linear):trunc_normal_init(m.weight, std=0.001)if m.bias is not None:constant_init(m.bias, val=0.0)elif isinstance(m, (nn.SyncBatchNorm, nn.BatchNorm2d)):constant_init(m.weight, val=1.0)constant_init(m.bias, val=0.0)elif isinstance(m, nn.Conv2d):trunc_normal_init(m.weight, std=0.001)if m.bias is not None:constant_init(m.bias, val=0.0)def _act_dn(self, x):x_shape = x.shape # n,c_inter,h,wh, w = x_shape[2], x_shape[3]x = x.reshape([x_shape[0], self.num_heads, self.inter_channels // self.num_heads, -1]) # n,c_inter,h,w -> n,heads,c_inner//heads,hwx = F.softmax(x, dim=3)x = x / (torch.sum(x, dim=2, keepdim=True) + 1e-06)x = x.reshape([x_shape[0], self.inter_channels, h, w])return xdef forward(self, x):"""Args:x (Tensor): The input tensor. (n,c,h,w)cross_k (Tensor, optional): The dims is (n*144, c_in, 1, 1)cross_v (Tensor, optional): The dims is (n*c_in, 144, 1, 1)"""x = self.norm(x)x1 = F.conv2d(x, self.kv, bias=None, stride=1, padding=(3, 0))x1 = self._act_dn(x1)x1 = F.conv2d(x1, self.kv.transpose(1, 0), bias=None, stride=1, padding=(3, 0))x3 = F.conv2d(x, self.kv3, bias=None, stride=1, padding=(0, 3))x3 = self._act_dn(x3)x3 = F.conv2d(x3, self.kv3.transpose(1, 0), bias=None, stride=1, padding=(0, 3))x = x1 + x3return xclass CFBlock(nn.Module):"""The CFBlock implementation based on PaddlePaddle.Args:in_channels (int, optional): The input channels.out_channels (int, optional): The output channels.num_heads (int, optional): The num of heads in attention. Default: 8drop_rate (float, optional): The drop rate in MLP. Default:0.drop_path_rate (float, optional): The drop path rate in CFBlock. Default: 0.2"""def __init__(self, in_channels, out_channels, num_heads=8, drop_rate=0.0, drop_path_rate=0.0):super(CFBlock, self).__init__()in_channels_l = in_channelsout_channels_l = out_channelsself.attn_l = ConvolutionalAttention(in_channels_l, out_channels_l, inter_channels=64, num_heads=num_heads)self.mlp_l = MLP(out_channels_l, drop_rate=drop_rate)self.drop_path = (DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity())def _init_weights_kaiming(self, m):if isinstance(m, nn.Linear):trunc_normal_init(m.weight, std=0.02)if m.bias is not None:constant_init(m.bias, val=0)elif isinstance(m, (nn.SyncBatchNorm, nn.BatchNorm2d)):constant_init(m.weight, val=1.0)constant_init(m.bias, val=0)elif isinstance(m, nn.Conv2d):kaiming_init(m.weight)if m.bias is not None:constant_init(m.bias, val=0)def forward(self, x):x_res = xx = x_res + self.drop_path(self.attn_l(x))x = x + self.drop_path(self.mlp_l(x))return xif __name__ == "__main__":device = torch.device("cuda" if torch.cuda.is_available() else "cpu")input = torch.randn(1, 32, 256, 256).to(device)print(input.shape)cfb = CFBlock(32, 32).to(device)output = cfb(input)print(output.shape)
总结
本文详细介绍了 SCTNet 的创新方法、实验结果及未来研究方向,希望能帮助研究人员进一步理解 SCTNet 在实时语义分割中的应用潜力。如果你对该方法感兴趣,可以查看论文详情:2312.17071v2.pdf 🚀