【即插即用】SGE注意力机制(附源码)

devtools/2024/9/23 14:29:57/
原文链接:

https://arxiv.org/abs/1905.09646

源码链接:

https://github.com/implus/PytorchInsight

摘要简介:

在图像识别领域,卷积神经网络(CNN)通过收集和整合复杂对象的层次化和不同部分的语义子特征来生成特征表示。这些子特征通常以分组的形式分布在每一层的特征向量中,代表不同的语义实体。然而,这些子特征的激活常常受到相似模式和噪声背景的空间影响,从而可能导致定位和识别的错误。

为了解决这一问题,研究者们提出了一种名为空间分组增强(SGE)的模块。SGE模块可以为每个语义组中的每个空间位置生成一个注意力因子,以调整每个子特征的重要性。通过这种方式,每个单独的组都能够自主地增强其学习到的表达,并抑制可能的噪声。这些注意力因子仅由组内全局和局部特征描述符之间的相似性来指导,因此SGE模块的设计非常轻量级,几乎不需要额外的参数和计算。

尽管SGE组件仅通过类别监督进行训练,但它在突出显示具有各种高阶语义的多个活跃区域方面表现出色,如狗的眼睛、鼻子等。当与流行的CNN骨干网络结合使用时,SGE能够显著提高图像识别任务的性能。具体来说,基于ResNet50骨干网络,SGE在ImageNet基准测试中实现了1.2%的Top-1准确率提升,并在广泛的检测器(Faster/Mask/Cascade RCNN和RetinaNet)上,于COCO基准测试中获得了1.0∼2.0%的AP增益。相关的代码和预训练模型已经公开可用。


模型结构:

Pytorch版源码:

# Spatial Group-wise Enhance主要是用在语义分割上,所以在检测上的效果一般,没有带来多少提升
import torch
from torch import nn
from torch.nn import initclass SpatialGroupEnhance(nn.Module):def __init__(self, groups):super().__init__()self.groups = groupsself.avg_pool = nn.AdaptiveAvgPool2d(1)self.weight = nn.Parameter(torch.zeros(1, groups, 1, 1))self.bias = nn.Parameter(torch.zeros(1, groups, 1, 1))self.sig = nn.Sigmoid()self.init_weights()def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)def forward(self, x):b, c, h, w = x.shapex = x.view(b * self.groups, -1, h, w)  # bs*g,dim//g,h,wxn = x * self.avg_pool(x)  # bs*g,dim//g,h,wxn = xn.sum(dim=1, keepdim=True)  # bs*g,1,h,wt = xn.view(b * self.groups, -1)  # bs*g,h*wt = t - t.mean(dim=1, keepdim=True)  # bs*g,h*wstd = t.std(dim=1, keepdim=True) + 1e-5t = t / std  # bs*g,h*wt = t.view(b, self.groups, h, w)  # bs,g,h*wt = t * self.weight + self.bias  # bs,g,h*wt = t.view(b * self.groups, 1, h, w)  # bs*g,1,h*wx = x * self.sig(t)x = x.view(b, c, h, w)return xif __name__ == '__main__':input = torch.randn(2, 32, 512, 512)SGE = SpatialGroupEnhance(groups=input.size(1))output = SGE(input)print(output.shape)


http://www.ppmy.cn/devtools/25589.html

相关文章

bitbake ERROR:No space left on device or exceeds fs.inotify.max_user_watches?

使用vscode remote ssh来编辑服务器上源码后,执行bitbake编译时,遇到了如下报错:ERROR:No space left on device or exceeds fs.inotify.max_user_watches? 可能的解决方法和原因: 首先可以尝试关闭vscode,然后在服务…

给rwkv_pytorch增加rag

RAG 参考地址语义模型地址选择该模型使用方法方法二安装方法下载模型到本地材料材料处理语义分割计算得分根据得分 分割文本 构建向量数据库问答匹配问答整合 参考地址 RAG简单教程 分割策略 语义模型地址 hf 选择该模型 gte 使用方法 import torch.nn.functional as F…

Linux下网络编程-基于多任务的简易并发服务器

Linux下网络编程-基于多任务的简易并发服务器 #include <stdio.h> #include <stdlib.h> #include <string.h> #include <unistd.h> #include <signal.h> #include <sys/wait.h> #include <arpa/inet.h> #include <sys/socket.h&…

科技赋能无人零售

科技赋能无人零售&#xff0c;使其具备以下独特优势&#xff1a; 1. 全天候无缝服务 &#xff1a;无人零售店依托科技&#xff0c;实现24小时不间断运营&#xff0c;不受人力限制&#xff0c;满足消费者随时购物需求&#xff0c;尤其惠及夜间工作者、夜猫子及急需购物者&…

汽车信息安全入门总结(1)

目录 1.汽车信息安全应关注什么 2.法规先行 3.小结 1.汽车信息安全应关注什么 汽车信息安全从2015年开始被引起重视发展至今已近10年时间&#xff0c;虽然有很多高屋建瓴的白皮书、指导标准可以指导我们从宏观了解汽车信息安全这个新兴行业&#xff0c;但真正实际需求落实到…

Qt窗口全屏显示方法

要在Qt中设置窗口全屏显示&#xff0c;可以采取以下方法&#xff1a; 使用showFullScreen()方法&#xff1a; 对于QWidget对象&#xff0c;可以直接调用showFullScreen()方法来实现全屏显示。 QWidget w; w.showFullScreen();使用setWindowState()方法&#xff1a; 可以通过…

【TCP:可靠数据传输,快速重传,流量控制,TCP流量控制】

文章目录 可靠数据传输TCP&#xff1a;可靠数据传输TCP发送方事件快速重传流量控制TCP流量控制 可靠数据传输 TCP&#xff1a;可靠数据传输 TCP在IP不可靠服务的基础上建立了rdt 管道化的报文段 GBN or SR 累计确认&#xff08;像GBN&#xff09;单个重传定时器&#xff08;像…

神之浩劫2测试资格100%获取教程 测试资格获取方法教程

《神之浩劫》是一款基于Unreal 3&#xff08;虚幻3&#xff09;游戏引擎开发的3D团队竞技游戏&#xff0c;由美国Hi-Rez工作室开发、腾讯全球代理。2013年10月31日&#xff0c;游戏开启国服首测&#xff0c;并于2014年3月25日在美国公测。2018年1月20日&#xff0c;国服并入全球…