【Attention】SKAttention

embedded/2025/3/31 11:13:03/

SKAttention选择核注意力

标题:SKAttention

期刊:IEEE2019

代码: https://github.com/implus/SKNet

简介:

  • 动机:增大感受野来提升性能、多尺度信息聚合方式
  • 解决的问题:自适应调整感受野大小
  • 创新性:提出选择性内核(SK)卷积softmax来进行自适应选择

模型结构

在这里插入图片描述

模型代码

python">import numpy as np
import torch
from torch import nn
from torch.nn import init
from collections import OrderedDict# Selective Kernel Attention
class SKAttention(nn.Module):def __init__(self, channel=512, kernels=[1, 3, 5, 7], reduction=16, group=1, L=32):super().__init__()# 中间维度d的计算self.d = max(L, channel // reduction)# 多分支卷积层(使用不同尺寸的卷积核)self.convs = nn.ModuleList([])for k in kernels:self.convs.append(nn.Sequential(OrderedDict([# 分组卷积(输入输出通道数相同,保持维度)('conv', nn.Conv2d(channel, channel, kernel_size=k, padding=k // 2, groups=group)),# 批归一化(保持维度)  ('bn', nn.BatchNorm2d(channel)),# ReLU激活函数('relu', nn.ReLU())])))# # 通道压缩层(全连接层)self.fc = nn.Linear(channel, self.d)# 多分支注意力权重生成层self.fcs = nn.ModuleList([])for i in range(len(kernels)):self.fcs.append(nn.Linear(self.d, channel))# 注意力权重归一化(沿分支维度softmax)self.softmax = nn.Softmax(dim=0)def forward(self, x):# 输入x形状: [B, C, H, W]bs, c, _, _ = x.size() # 获取输入的batch_size, 通道数, 高度, 宽度conv_outs = []### Split阶段:多分支特征提取for conv in self.convs:conv_outs.append(conv(x)) # 每个分支输出: [B, C, H, W]feats = torch.stack(conv_outs, 0)  # 堆叠后形状: [K, B, C, H, W](K是kernel数量)### Fuse阶段:特征融合U = sum(conv_outs) # 逐元素相加 → [B, C, H, W]### Channel Reduction:通道压缩S = U.mean(-1).mean(-1)  # 空间全局平均池化 → [B, C,1,1]Z = self.fc(S)   # 全连接层降维 → [B, d](d=self.d)### 计算注意力权重weights = []for fc in self.fcs: #  每个kernel对应一个全连接层weight = fc(Z) # 全连接层输出 → [B, C]weights.append(weight.view(bs, c, 1, 1))  # 调整形状 → [B, C, 1, 1]attention_weughts = torch.stack(weights, 0)   # 堆叠 → [K, B, C, 1, 1]attention_weughts = self.softmax(attention_weughts)  # 沿K维度softmax归一化### fuseV = (attention_weughts * feats).sum(0) # 加权求和 → [B, C, H, W]return Vif __name__ == '__main__':input = torch.rand(1,64,256,256).cuda()model = SKAttention(channel=64, reduction=8).cuda()output = model (input)print('input_size:', input.size())print('output_size:', output.size())print("最大内存占用:", torch.cuda.max_memory_allocated() // 1024 // 1024, "MB")

http://www.ppmy.cn/embedded/176029.html

相关文章

优化 SQL 语句方向和提升性能技巧

优化 SQL 语句是提升 MySQL 性能的关键步骤之一。通过优化 SQL 语句,可以减少查询时间、降低服务器负载、提高系统吞吐量。以下是优化 SQL 语句的方法、策略和技巧: 一、优化 SQL 语句的方法 1. 使用 EXPLAIN 分析查询 作用:查看 SQL 语句的执行计划,了解查询是如何执行的…

鱼书--学习2

6. 与学习相关的技巧 6.1 参数的更新 (1) SGD的缺点:SGD低效的根本原因是,梯度的方向并没有指向最小值的方向 基于SGD的最优化的更新路径:呈“之”字形朝最小值(0, 0)移动,效率低 (2&#x…

xQueueGenericReceive中文释义及调用

xQueueGenericReceive 是 FreeRTOS 中的一个内部函数,用于从队列中接收数据。它通常不会被用户直接调用,而是通过 FreeRTOS 提供的 API 函数(如 xQueueReceive 或 xQueuePeek)间接调用。以下是对 xQueueGenericReceive 的详细说明…

@Validated 使用介绍

说明:在项目开发中,请求进入系统的第一步就是校验,在前后端分离的项目中,有前端校验、后端校验。对于后端开发程序员来说,完全依靠前端校验是不合理的,因为只需要用户知道一点计算机知识,就能使…

【LeetCode 热题100】 22. 括号生成 的算法思路及python代码

22. 括号生成 数字 n n n 代表生成括号的对数,请你设计一个函数,用于能够生成所有可能的并且有效的括号组合。 示例 1: 输入:n 3 输出:["((()))","(()())","(())()","()(())&…

题单:精挑细选

题目描述 小王是公司的仓库管理员,一天,他接到了这样一个任务:从仓库中找出一根钢管。这听起来不算什么,但是这根钢管的要求可真是让他犯难了,要求如下: 1.1. 这根钢管一定要是仓库中最长的; …

【AVRCP】深度剖析 AVRCP 中 Generic Access Profile 的要求与应用

目录 一、GAP基础架构与核心要求 1.1 GAP在蓝牙体系中的定位 1.2 核心模式定义 二、AVRCP对GAP的增强要求 2.1 模式扩展规范 2.2 空闲模式过程支持 三、安全机制实现细节 3.1 认证与加密流程 3.2 安全模式要求 四、设备发现与连接建立 4.1 发现过程状态机 4.2 连接…

dify创建第一个Agent

1、首先LLM模型必须支持 Function Calling 由于deepseek-R1本地化部署时还不支持,所以使用 qwq模型。 2、创建空白 Agent 3、为Agent添加工具 4、测试 当未添加时间工具时 询问 时间 如下 5、开启时间工具 询问如下