(即插即用模块-Attention部分) 二十、(2021) GAA 门控轴向注意力

server/2024/11/30 6:15:43/

在这里插入图片描述

文章目录

  • 1、Gated Axial-Attention
  • 2、代码实现

paper:Medical Transformer: Gated Axial-Attention for Medical Image Segmentation

Code:https://github.com/jeya-maria-jose/Medical-Transformer


1、Gated Axial-Attention

论文首先分析了 ViTs 在训练小规模数据集时的弊端以及指出了 ViTs 的计算复杂度偏高。为此,论文提出了一种门控轴向注意力(Gated Axial-Attention),其通过在自注意力模块中引入额外的门控机制来扩展现有的体系结构。在分析了位置偏差难以学习、相对位置编码不够准确等问题后,通过将可控制的影响位置偏差施加在编码的非本地上下文来实现改进。Gated Axial-Attention的 核心思想是Gate门控机制,通过引入 Gate 控制机制来控制位置编码对 Self-Attention 的影响程度。

对于一个输入特征 X,Gated Axial-Attention的实现过程:

  1. 输入特征图: 将输入图像提取特征图,并进行通道维度上的线性变换,得到 Query、Key 和 Value 向量。

  2. Axial-Attention

    在高度方向上进行 1D Self-Attention,计算像素之间的依赖关系。

    在宽度方向上进行 1D Self-Attention,计算像素之间的依赖关系。

  3. Positional Encoding:计算相对位置编码,将像素位置信息融入到 Query、Key 和 Value 向量中。

  4. Gate 控制机制:通过可学习的 Gate 参数,控制相对位置编码对 Self-Attention 的影响程度。

  5. 输出特征图: 将经过 Self-Attention 和 Gate 控制的特征图进行线性变换,得到最终输出特征图。


Gated Axial-Attention 结构图:
在这里插入图片描述


2、代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import mathdef conv1x1(in_planes, out_planes, stride=1):"""1x1 卷积"""return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)class qkv_transform(nn.Conv1d):"""Conv1d for qkv_transform"""class AxialAttention(nn.Module):def __init__(self, in_planes, out_planes, groups=8, kernel_size=56,stride=1, bias=False, width=False):assert (in_planes % groups == 0) and (out_planes % groups == 0)super(AxialAttention, self).__init__()self.in_planes = in_planesself.out_planes = out_planesself.groups = groupsself.group_planes = out_planes // groupsself.kernel_size = kernel_sizeself.stride = strideself.bias = biasself.width = width# Multi-head self attentionself.qkv_transform = qkv_transform(in_planes, out_planes * 2, kernel_size=1, stride=1,padding=0, bias=False)self.bn_qkv = nn.BatchNorm1d(out_planes * 2)self.bn_similarity = nn.BatchNorm2d(groups * 3)self.bn_output = nn.BatchNorm1d(out_planes * 2)# Position embeddingself.relative = nn.Parameter(torch.randn(self.group_planes * 2, kernel_size * 2 - 1), requires_grad=True)query_index = torch.arange(kernel_size).unsqueeze(0)key_index = torch.arange(kernel_size).unsqueeze(1)relative_index = key_index - query_index + kernel_size - 1self.register_buffer('flatten_index', relative_index.view(-1))if stride > 1:self.pooling = nn.AvgPool2d(stride, stride=stride)self.reset_parameters()def forward(self, x):# pdb.set_trace()if self.width:x = x.permute(0, 2, 1, 3)else:x = x.permute(0, 3, 1, 2)  # N, W, C, HN, W, C, H = x.shapex = x.contiguous().view(N * W, C, H)# Transformationsqkv = self.bn_qkv(self.qkv_transform(x))q, k, v = torch.split(qkv.reshape(N * W, self.groups, self.group_planes * 2, H),[self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=2)# Calculate position embeddingall_embeddings = torch.index_select(self.relative, 1, self.flatten_index).view(self.group_planes * 2,self.kernel_size,self.kernel_size)q_embedding, k_embedding, v_embedding = torch.split(all_embeddings,[self.group_planes // 2, self.group_planes // 2,self.group_planes], dim=0)qr = torch.einsum('bgci,cij->bgij', q, q_embedding)kr = torch.einsum('bgci,cij->bgij', k, k_embedding).transpose(2, 3)qk = torch.einsum('bgci, bgcj->bgij', q, k)stacked_similarity = torch.cat([qk, qr, kr], dim=1)stacked_similarity = self.bn_similarity(stacked_similarity).view(N * W, 3, self.groups, H, H).sum(dim=1)# stacked_similarity = self.bn_qr(qr) + self.bn_kr(kr) + self.bn_qk(qk)# (N, groups, H, H, W)similarity = F.softmax(stacked_similarity, dim=3)sv = torch.einsum('bgij,bgcj->bgci', similarity, v)sve = torch.einsum('bgij,cij->bgci', similarity, v_embedding)stacked_output = torch.cat([sv, sve], dim=-1).view(N * W, self.out_planes * 2, H)output = self.bn_output(stacked_output).view(N, W, self.out_planes, 2, H).sum(dim=-2)if self.width:output = output.permute(0, 2, 1, 3)else:output = output.permute(0, 2, 3, 1)if self.stride > 1:output = self.pooling(output)return outputdef reset_parameters(self):self.qkv_transform.weight.data.normal_(0, math.sqrt(1. / self.in_planes))# nn.init.uniform_(self.relative, -0.1, 0.1)nn.init.normal_(self.relative, 0., math.sqrt(1. / self.group_planes))if __name__ == '__main__':x = torch.randn(4, 512, 7, 7).cuda()# kernel_size 要跟 h,w 相同model = AxialAttention(512, 512, kernel_size=7).cuda()out = model(x)print(out.shape)

本文只是对论文中的即插即用模块做了整合,对论文中的一些地方难免有遗漏之处,如果想对这些模块有更详细的了解,还是要去读一下原论文,肯定会有更多收获。


http://www.ppmy.cn/server/146095.html

相关文章

mysql之基本常用的语法

mysql之基本常用的语法 1.增加数据2.删除数据3.更新/修改数据4.查询数据4.1.where子句4.2.order by4.3.limit与offset4.4.分组与having4.5.连接 5.创建表 1.增加数据 insert into 1.指定列插入 语法:insert into table_name(列名1,列名2,....,列名n) values (值1,值…

svn 崩溃、 cleanup失败 怎么办

在使用svn的过程中,可能出现整个svn崩溃, 例如cleanup 失败的情况,类似于 这时可以下载本贴资源文件并解压。 或者直接访问网站 SQLite Download Page 进行下载 解压后得到 sqlite3.exe 放到发生问题的svn根目录的.svn路径下 右键呼出pow…

【拥抱AI】Milvus 如何处理 TB 级别的大规模向量数据?

处理 TB 级别的大规模向量数据是 Milvus 的核心优势之一。Milvus 通过分布式架构、高效的索引算法和优化的数据管理策略来实现这一目标。下面将详细介绍 Milvus 如何处理 TB 级别向量数据的流程,包括插入代码示例、指令以及流程图。 1. 分布式架构 Milvus 使用分…

centos新建磁盘

1,fdisk -l 2,fdisk /dev/sdb 在fdisk交互界面中: 输入 n - 创建新分区 输入 p - 创建主分区 分区号按回车使用默认值1 起始扇区按回车使用默认值 结束扇区按回车使用默认值(这样会使用所有可用空间) 输入 w - 保存并退…

计算机网络安全 —— 非对称加密算法 RSA 和数字签名

一、非对称加密算法基本概念# ​ 在对称密钥系统中,两个参与者要共享同一个秘密密钥。但怎样才能做到这一点呢?一种是事先约定,另一种是用信使来传送。在高度自动化的大型计算机网络中,用信使来传送密钥显然是不合适的。如果事先…

【大数据测试之:RabbitMQ消息列队测试-发送、接收、持久化、确认、重试、死信队列并处理消息的并发消费、负载均衡、监控等】详细教程---保姆级

RabbitMQ消息列队测试教程 一、环境准备1. 安装 RabbitMQ2. 安装 Python 依赖 二、基本消息队列中间件实现1. 消息发送模块2. 消息接收模块 三、扩展功能1. 消息持久化和队列持久化2. 消息优先级3. 死信队列(DLQ) 四、并发处理和负载均衡1. 使用 Python …

Qt—QLineEdit 使用总结

文章参考:Qt—QLineEdit 使用总结 一、简述 QLineEdit是一个单行文本编辑控件。 使用者可以通过很多函数,输入和编辑单行文本,比如撤销、恢复、剪切、粘贴以及拖放等。 通过改变 QLineEdit 的 echoMode() ,可以设置其属性,比如以密码的形式输入。 文本的长度可以由 m…

.npmrc文件的用途

.npmrc 文件是 npm(Node.js 的包管理工具)用于配置项目或用户的设置文件。它可以存储与 npm 相关的配置信息,如注册表地址、认证信息、代理设置、安装路径等。.npmrc 文件可以出现在不同的地方,具有不同的作用范围,通常…