6、关于Medical-Transformer

devtools/2024/12/23 9:03:41/

6、关于Medical-Transformer

Axial-Attention原文链接:Axial-attention
Medical-Transformer原文链接:Medical-Transformer

Medical-Transformer实际上是Axial-Attention在医学领域的运行,只是在这基础上增加了门机制,实际上也就是在原来Axial-attention基础之上增加权重机制,虚弱位置信息对于数据的影响,发现虚弱之后的效果比Axial-Attention机制效果更好

Axial-Attention

Axial-Attention与传统Transformer的self-attention相比较,将2D计算转成1D计算,Axial-attention机制,对于qkv的计算,做出了简化,仅仅某个点的横竖两个方向上的特殊,同时在qkv的基础上加上了各自位置特征,这些特征都是更新学习的。

Axial-attention模型架构图

左图为传统的self-attention机制,右图为Axial-attention机制,对于qkv都加上rq,rk,rv这样的位置参数,这些参数都是可以更新的,也就是说,每个的q在和自己对应的横竖轴反向进行计算的时候,q会和自己rq先进行权重计算,同样的k和v也会进行同样的计算,随后进行q和k进行计算得到权重,计算过程和原来的self-attention机制是一样的。

在这里插入图片描述

class AxialAttention(nn.Module):def forward(self, x):# 前向传播函数# 如果设置了 width 参数,调整张量维度顺序if self.width:x = x.permute(0, 2, 1, 3)  # 调整维度顺序else:x = x.permute(0, 3, 1, 2)  # N, W, C, H  调整为 N, C, H, WN, W, C, H = x.shape  # 获取张量形状x = x.contiguous().view(N * W, C, H)  # 重新调整形状,合并 N 和 W 维度# 通过x获得对应的qkv 批归一化后计算 qkvqkv = self.bn_qkv(self.qkv_transform(x))  q, k, v = torch.split(qkv.reshape(N * W, self.groups, self.group_planes * 2, H),[self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=2)  # 将 qkv 拆分为 q, k, v# 计算位置嵌入all_embeddings = torch.index_select(self.relative, 1, self.flatten_index).view(self.group_planes * 2, self.kernel_size, self.kernel_size)q_embedding, k_embedding, v_embedding = torch.split(all_embeddings, [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=0)  # 拆分嵌入# 计算 QR, KR, QK 相似性,分别计算得出rq,rkqr = torch.einsum('bgci,cij->bgij', q, q_embedding)  # QR: q 和 q_embedding 的爱因斯坦求和kr = torch.einsum('bgci,cij->bgij', k, k_embedding).transpose(2, 3)  # KR: k 和 k_embedding 的爱因斯坦求和,并转置# q和k进行计算,得到最后的权重qk = torch.einsum('bgci, bgcj->bgij', q, k)  # QK: q 和 k 之间的点积# 将 QR, KR, QK 相似性进行堆叠,连在一起进行计算stacked_similarity = torch.cat([qk, qr, kr], dim=1)  # 将 qk, qr, kr 连接起来stacked_similarity = self.bn_similarity(stacked_similarity).view(N * W, 3, self.groups, H, H).sum(dim=1)  # 批归一化并调整形状# similarity为q和k计算得出权重关系similarity = F.softmax(stacked_similarity, dim=3)  # 在第 3 维度上计算 softmax# 将q和v计算出来权重和v加权求和sv = torch.einsum('bgij,bgcj->bgci', similarity, v)  # 将相似度与 v 进行求和# v与位置信息结合sve = torch.einsum('bgij,cij->bgci', similarity, v_embedding)  # 将similarity与 v_embedding 进行求和# 将位置加权后的v和q和k计算结果与v加权的合并,并调整形状输出stacked_output = torch.cat([sv, sve], dim=-1).view(N * W, self.out_planes * 2, H)  # 合并 sv 和 sve,并调整形状output = self.bn_output(stacked_output).view(N, W, self.out_planes, 2, H).sum(dim=-2)  # 批归一化并调整形状# 恢复维度顺序if self.width:output = output.permute(0, 2, 1, 3)  # 调整维度顺序else:output = output.permute(0, 2, 3, 1)  # 调整维度顺序# 如果步长大于 1,应用池化操作if self.stride > 1:output = self.pooling(output)  # 池化return output  # 返回输出
横竖轴计算过程

先通过卷积把特征图缩小,然后横竖轴计算时,是将横轴一起进行计算,然后再进行纵轴计算的,完成计算后,通过1x1卷积将特征图还原为原来的大小,在传入下一层进行计算。

在这里插入图片描述

Medical-Transformer
Medical-Transformer架构图

Medical-Transformer实际就是Axial-attention在医学图像分割领域的应用,medical-tranformer大模型架构采用整个图像进行Axial-attention特征提取,同时也将图像分成多个窗口,对每个窗口进行axial-attention特征提取,窗口由于计算量小,可以多进行几层Axial-attention,最终将整个图像特征和窗口特征融合,完成整个的特征提取,值得一提的是在进行窗口Axial-attention时,qkv都没有加上位置编码(也就是下面部分的图像)。

在这里插入图片描述

主体架构
class medt_net(nn.Module):def _forward_impl(self, x):xin = x.clone()  # 保存输入数据的副本x = self.conv1(x)  # 第一个卷积层x = self.bn1(x)  # 第一个批归一化层x = self.relu(x)  # ReLU 激活函数x = self.conv2(x)  # 第二个卷积层x = self.bn2(x)  # 第二个批归一化层x = self.relu(x)  # ReLU 激活函数x = self.conv3(x)  # 第三个卷积层x = self.bn3(x)  # 第三个批归一化层x = self.relu(x)  # ReLU 激活函数x1 = self.layer1(x)  # 第一个残差层 实际上就是 Gated Axial Attention Layerx2 = self.layer2(x1)  # 第二个残差层 同样是 Gated Axial Attention Layer# 对输入进行插值放大,并通过解码器处理x = F.relu(F.interpolate(self.decoder4(x2), scale_factor=(2, 2), mode='bilinear'))x = torch.add(x, x1)  # 将放大的特征图与 x1 相加x = F.relu(F.interpolate(self.decoder5(x), scale_factor=(2, 2), mode='bilinear'))# 以上完成就是图上方整个图像的卷积过程# -------------------------------------------------------------------------------------------x_loc = x.clone()  # 生成一个本地副本# 下面对图像进行切分,分别对每个窗口进行局部处理,实际上是16个窗口for i in range(0, 4):for j in range(0, 4):x_p = xin[:, :, 32 * i:32 * (i + 1), 32 * j:32 * (j + 1)]  # 提取32x32的局部patch# 逐层卷积处理patchx_p = self.conv1_p(x_p)x_p = self.bn1_p(x_p)x_p = self.relu(x_p)x_p = self.conv2_p(x_p)x_p = self.bn2_p(x_p)x_p = self.relu(x_p)x_p = self.conv3_p(x_p)x_p = self.bn3_p(x_p)x_p = self.relu(x_p)# 进行四个x1_p = self.layer1_p(x_p)  # 第一个残差层(patch-wise) 这里进行的axial-attention在进行qkv计算时,qkv都没有加入位置信息计算x2_p = self.layer2_p(x1_p)  # 第二个残差层(patch-wise)x3_p = self.layer3_p(x2_p)  # 第三个残差层(patch-wise)x4_p = self.layer4_p(x3_p)  # 第四个残差层(patch-wise)# 对patch进行插值放大并通过解码器处理x_p = F.relu(F.interpolate(self.decoder1_p(x4_p), scale_factor=(2, 2), mode='bilinear'))x_p = torch.add(x_p, x4_p)  # 将放大的特征图与 x4_p 相加x_p = F.relu(F.interpolate(self.decoder2_p(x_p), scale_factor=(2, 2), mode='bilinear'))x_p = torch.add(x_p, x3_p)  # 将放大的特征图与 x3_p 相加x_p = F.relu(F.interpolate(self.decoder3_p(x_p), scale_factor=(2, 2), mode='bilinear'))x_p = torch.add(x_p, x2_p)  # 将放大的特征图与 x2_p 相加x_p = F.relu(F.interpolate(self.decoder4_p(x_p), scale_factor=(2, 2), mode='bilinear'))x_p = torch.add(x_p, x1_p)  # 将放大的特征图与 x1_p 相加x_p = F.relu(F.interpolate(self.decoder5_p(x_p), scale_factor=(2, 2), mode='bilinear'))x_loc[:, :, 32 * i:32 * (i + 1), 32 * j:32 * (j + 1)] = x_p  # 将局部处理后的结果放回原始位置# 将整个图片的axial-attention,和每个窗口得出的结果进行结合x = torch.add(x, x_loc)  # 将全局和局部特征进行融合x = F.relu(self.decoderf(x))  # 通过最终的解码器层x = self.adjust(F.relu(x))  # 调整输出return x  # 返回最终输出
Gated Axial Attention Layer

从架构图中可以看出,就是在Axial-attention的基础上,加上了门机制,说白了,也就是在qkv和各自的rq,rk,rv计算完成后,再进行下一步计算之前,进行了一个加权计算,虚弱了位置变量对特征提取结果的影响。

在这里插入图片描述

横向或纵向Gated Axial-attention过程

注意里面qr,kr实际上就是图片中的rq,rk,而

class AxialAttention_dynamic(nn.Module):def forward(self, x):# 判断是否需要对宽度维度进行变换if self.width:x = x.permute(0, 2, 1, 3)  # 交换维度顺序,形状变为 [N, C, W, H]else:x = x.permute(0, 3, 1, 2)  # 交换维度顺序,形状变为 [N, W, C, H]N, W, C, H = x.shape  # 获取输入张量的形状x = x.contiguous().view(N * W, C, H)  # 将张量变形为 [N * W, C, H]print(x.shape)  # 输出形状: [64, 16, 64]# 变换操作qkv = self.bn_qkv(self.qkv_transform(x))  # 对qkv进行批归一化print(qkv.shape)  # 输出形状: [64, 32, 64]# 将qkv张量拆分为q、k、v,分别表示查询、键和值q, k, v = torch.split(qkv.reshape(N * W, self.groups, self.group_planes * 2, H), [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=2)print(q.shape)  # 输出q的形状: [64, 8, 1, 64]print(k.shape)  # 输出k的形状: [64, 8, 1, 64]print(v.shape)  # 输出v的形状: [64, 8, 2, 64],v有两份# 计算位置嵌入all_embeddings = torch.index_select(self.relative, 1, self.flatten_index).view(self.group_planes * 2, self.kernel_size, self.kernel_size)print(all_embeddings.shape)  # 输出嵌入的形状: [4, 64, 64],共有4份q_embedding, k_embedding, v_embedding =torch.split(all_embeddings, [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=0)print(q_embedding.shape)  # 输出q的位置嵌入形状: [1, 64, 64]print(k_embedding.shape)  # 输出k的位置嵌入形状: [1, 64, 64]print(v_embedding.shape)  # 输出v的位置嵌入形状: [2, 64, 64],v有两份位置编码# 计算q与位置嵌入的乘积qr = torch.einsum('bgci,cij->bgij', q, q_embedding)print(qr.shape)  # 输出qr的形状: [64, 8, 64, 64]# 计算k与位置嵌入的乘积,并进行转置kr = torch.einsum('bgci,cij->bgij', k, k_embedding).transpose(2, 3)print(kr.shape)  # 输出kr的形状: [64, 8, 64, 64]# 计算q和k的点积qk = torch.einsum('bgci, bgcj->bgij', q, k)print(qk.shape)  # 输出qk的形状: [64, 8, 64, 64]# 对qr和kr进行初始化,使用self.f_qr和self.f_kr作为初始化的权重qr = torch.mul(qr, self.f_qr)print(qr.shape)  # 输出qr的形状: [64, 8, 64, 64]kr = torch.mul(kr, self.f_kr)print(kr.shape)  # 输出kr的形状: [64, 8, 64, 64]# 将qk、qr和kr拼接起来stacked_similarity = torch.cat([qk, qr, kr], dim=1)print(stacked_similarity.shape)  # 输出拼接后的形状: [64, 24, 64, 64]# 进行批归一化,重新变形为[N * W, 3, groups, H, H],并对维度1求和stacked_similarity = self.bn_similarity(stacked_similarity).view(N * W, 3, self.groups, H, H).sum(dim=1)print(stacked_similarity.shape)  # 输出归一化后的形状: [64, 8, 64, 64]# 计算相似度similarity = F.softmax(stacked_similarity, dim=3)print(similarity.shape)  # 输出相似度的形状: [64, 8, 64, 64]# 使用相似度与v相乘,获得加权后的值sv = torch.einsum('bgij,bgcj->bgci', similarity, v)print(sv.shape)  # 输出加权后的形状: [64, 8, 2, 64]# 使用相似度与v的位置嵌入相乘sve = torch.einsum('bgij,cij->bgci', similarity, v_embedding)print(sve.shape)  # 输出位置嵌入加权后的形状: [64, 8, 2, 64]# 对sv和sve进行初始化sv = torch.mul(sv, self.f_sv)print(sv.shape)  # 输出sv的形状: [64, 8, 2, 64]sve = torch.mul(sve, self.f_sve)print(sve.shape)  # 输出sve的形状: [64, 8, 2, 64]# 将sv和sve拼接在一起,并重新变形为[N * W, out_planes * 2, H]stacked_output = torch.cat([sv, sve], dim=-1).view(N * W, self.out_planes * 2, H)print(stacked_output.shape)  # 输出拼接后的形状: [64, 32, 64]# 进行批归一化,并变形为[N, W, out_planes, 2, H],对维度-2求和output = self.bn_output(stacked_output).view(N, W, self.out_planes, 2, H).sum(dim=-2)print(output.shape)  # 输出归一化后的形状: [1, 64, 16, 64]# 根据宽度调整维度顺序if self.width:output = output.permute(0, 2, 1, 3)else:output = output.permute(0, 2, 3, 1)print(output.shape)  # 输出最终的形状: [1, 16, 64, 64]# 如果步幅大于1,进行池化操作if self.stride > 1:output = self.pooling(output)return output

http://www.ppmy.cn/devtools/107414.html

相关文章

Mac 安装 jdk 8详细教程

Mac 电脑上安装Jdk 8 的步骤很简单,不用想Windows那样需要配置环境变量PATH、JAVA_HOME。 具体方法如下: 首先,去JDK官网下载对应版本的JDK 8。 这里需要注册一个账号,然后,账号下载。 下载完后,得到一个…

应用商店优化(ASO)的四大误区

应用商店优化 (ASO) 是移动营销中最重要的部分之一,可以帮助开发人员吸引自然流量并在应用推广方面取得预期效果。近年来ASO优化在开发者中越来越受欢迎。虽然它已经证明了其有效性和对应用成功的影响力,但尽管如此仍然存在与ASO相关的误解,导…

vscode 远程SSH连接并配置C/C++开发环境

服务器配置 生成用户密钥 ssh-keygen -m PEM -t rsa -b 4096 执行上面的命令后会在 ~/.ssh/ 目录生成密钥,然后导入密钥到认证文件中 cd .ssh/ cat id_rsa.pub >> authorized_keys最后将 id_rsa 传输到宿主机上 宿主机配置 安装插件 配置插件 安装图示…

macos 10.15 catalina xcode 下载和安装

在macos 10.15 catalina系统中, 由于系统已经不再支持,所以我们无法通过应用商店来安装xcode, 需要手动下载指定版本的 xcode 版本才能安装, catalina 支持的最新xcode版本为 Xcode v12.4 (12D4e) , 其他的新版本是无法安装在Catalina系统中的. Xcode_12.4.xip下载地址 注意,下…

【Tomcat源码分析 】“深入探索:Tomcat 类加载机制揭秘“

前言 在探究 Tomcat 类加载机制之前,让我们重温一下 Java 默认的类加载器,加深对其的理解。 如同作者在《深入理解 Java 虚拟机》第二版中所言,类加载机制对于理解 Java 运行时环境至关重要。 什么是类加载机制 Java 虚拟机将描述类的字节…

【日常记录-Java】SpringBoot中使用无返回值的异步方法

Author:赵志乾 Date:2024-09-05 Declaration:All Right Reserved!!! 1. 简介 在SpringBoot中,使用Async注解可以很方便地标记一个方法为异步执行。好处是调用者无需等待这些方法完成便可继续执…

【conda】导出和重建 Conda 环境

目录 1. 导出 Conda 环境1.1 激活环境1.2 导出环境配置1.3 检查和编辑环境配置文件(可选)1.4 共享或重建环境 2. 常见问题及解决方案2.1 导出环境时出现 “PackagesNotFoundError”2.2 导出的 environment.yml 文件在其他系统上无法使用2.3 导出的环境文…

【软件】软件评审

目录 1. 说明2. 设计质量的评审内容3. 程序质量的评审内容3.1 软件结构3.2 功能的通用性3.3 模块的层次3.4 模块结构3.4 处理过程的结构 4. 与运行环境的接口5. 例题5.1 例题1 1. 说明 1.通常,把“质量”理解为“用户满意程度”。为了使得用户满意,有两…