原文链接:
https://arxiv.org/abs/1905.09646
源码链接:
https://github.com/implus/PytorchInsight
摘要简介:
在图像识别领域,卷积神经网络(CNN)通过收集和整合复杂对象的层次化和不同部分的语义子特征来生成特征表示。这些子特征通常以分组的形式分布在每一层的特征向量中,代表不同的语义实体。然而,这些子特征的激活常常受到相似模式和噪声背景的空间影响,从而可能导致定位和识别的错误。
为了解决这一问题,研究者们提出了一种名为空间分组增强(SGE)的模块。SGE模块可以为每个语义组中的每个空间位置生成一个注意力因子,以调整每个子特征的重要性。通过这种方式,每个单独的组都能够自主地增强其学习到的表达,并抑制可能的噪声。这些注意力因子仅由组内全局和局部特征描述符之间的相似性来指导,因此SGE模块的设计非常轻量级,几乎不需要额外的参数和计算。
尽管SGE组件仅通过类别监督进行训练,但它在突出显示具有各种高阶语义的多个活跃区域方面表现出色,如狗的眼睛、鼻子等。当与流行的CNN骨干网络结合使用时,SGE能够显著提高图像识别任务的性能。具体来说,基于ResNet50骨干网络,SGE在ImageNet基准测试中实现了1.2%的Top-1准确率提升,并在广泛的检测器(Faster/Mask/Cascade RCNN和RetinaNet)上,于COCO基准测试中获得了1.0∼2.0%的AP增益。相关的代码和预训练模型已经公开可用。
模型结构:
Pytorch版源码:
# Spatial Group-wise Enhance主要是用在语义分割上,所以在检测上的效果一般,没有带来多少提升
import torch
from torch import nn
from torch.nn import initclass SpatialGroupEnhance(nn.Module):def __init__(self, groups):super().__init__()self.groups = groupsself.avg_pool = nn.AdaptiveAvgPool2d(1)self.weight = nn.Parameter(torch.zeros(1, groups, 1, 1))self.bias = nn.Parameter(torch.zeros(1, groups, 1, 1))self.sig = nn.Sigmoid()self.init_weights()def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)def forward(self, x):b, c, h, w = x.shapex = x.view(b * self.groups, -1, h, w) # bs*g,dim//g,h,wxn = x * self.avg_pool(x) # bs*g,dim//g,h,wxn = xn.sum(dim=1, keepdim=True) # bs*g,1,h,wt = xn.view(b * self.groups, -1) # bs*g,h*wt = t - t.mean(dim=1, keepdim=True) # bs*g,h*wstd = t.std(dim=1, keepdim=True) + 1e-5t = t / std # bs*g,h*wt = t.view(b, self.groups, h, w) # bs,g,h*wt = t * self.weight + self.bias # bs,g,h*wt = t.view(b * self.groups, 1, h, w) # bs*g,1,h*wx = x * self.sig(t)x = x.view(b, c, h, w)return xif __name__ == '__main__':input = torch.randn(2, 32, 512, 512)SGE = SpatialGroupEnhance(groups=input.size(1))output = SGE(input)print(output.shape)