12.16 深度学习-混合注意力CBAM

server/2024/12/22 1:40:55/

# 混合注意力机制(Hybrid Attention Mechanism)是一种结合空间和通道注意力的策略,旨在提高神经网络的特征提取能力。

# 空间和通道都加上去

# CBAM是一种轻量级的注意力模块,它通过增加空间和通道两个维度的注意力,来提高模型的性能。

# 在某个阶段 先后加入通道和空间

import torch

import torch.nn as nn

# CBAM 混合注意力 方法 的实现

# 通道注意力构建

class ChannelAtt(nn.Module):

    def __init__(self,c,r= 16,*args, **kwargs):

        super().__init__(*args, **kwargs)

        self.max=nn.Sequential(nn.AdaptiveMaxPool2d(1),nn.ReLU())

        self.avg=nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.ReLU())

        # 感知机 两个池化完成之后 得到通道系数的过程 这两个结果是共用的

        self.perceptron=nn.Sequential(nn.Linear(c,c//r),nn.ReLU(),nn.Linear(c//r,c))

        self.activate=nn.Sigmoid()

    def forward(self,x):

        x1=self.max(x)

        x2=self.avg(x)

        x1=self.perceptron(x1.view(x1.shape[0],-1))

        x2=self.perceptron(x2.view(x2.shape[0],-1))

        att=self.activate(x1+x2)

        att=att.unsqueeze(2).unsqueeze(3)

        return x*att

# 空间注意力构建

class SpaceAtt(nn.Module):

    def __init__(self,kernel_size=7, *args, **kwargs):

        super().__init__(*args, **kwargs)

        self.fc=nn.Sequential(nn.Conv2d(in_channels=2,out_channels=1,kernel_size=kernel_size,padding=kernel_size//2),nn.Sigmoid())

       

    def forward(self,x):

        x1=torch.max(x,dim=1,keepdim=True)# 当 keepdim=True 时,计算均值后,输出张量将保持被约简的维度,但该维度的大小将为 1。也就是说,结果张量会保留原始张量的形状结构,只是被约简的维度变为 1。

        x2=torch.mean(x,dim=1,keepdim=True)

        att=self.fc(torch.cat((x1.values,x2),dim=1))

        return x*att

class CBAM(nn.Module):

    def __init__(self, c,r=16,*args, **kwargs):

        super().__init__(*args, **kwargs)

        self.channel_att=ChannelAtt(c,r)

        self.space_att=SpaceAtt()

    def forward(self,x):

        x=self.channel_att(x)

        x=self.space_att(x)

        return x

img=torch.rand(1,128,224,224)

cbam=CBAM(img.shape[1])

res=cbam(img)

print(res.shape)


http://www.ppmy.cn/server/152097.html

相关文章

word实现两栏格式公式居中,编号右对齐

1、确定分栏的宽度 选定一段文字 点击分栏:如本文的宽度为22.08字符 2、将公式设置为 两端对齐,首行无缩进。 将光标放在 公式前面 点击 格式-->段落-->制表位 在“制表位位置”输入-->11.04字符(22.08/211.04字符)&…

arcgisPro相接多个面要素转出为完整独立线要素

1、使用【面转线】工具,并取消勾选“识别和存储面邻域信息”,如下: 2、得到的线要素,如下:

DeepSeek-V2的多头潜在注意力机制及其在开源Mixture-of-Experts (MoE)语言模型中的应用

DeepSeek-V2的多头潜在注意力机制及其在开源Mixture-of-Experts (MoE)语言模型中的应用 DeepSeek-V2的架构及其优势 DeepSeek-V2的架构及其优势可以从几个关键方面进行深入探讨: 1. 架构设计 DeepSeek-V2是一个基于Mixture-of-Experts(MoE&#xff0…

Spring Boot + Dubbo 的整合 ,仅需六步

Spring Boot 与 Dubbo 的整合 1. 添加依赖 <dependencies><!-- Spring Boot Starter --><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter</artifactId></dependency><!-- Dubb…

Apache 如何监听多个端口 ?

Apache 是一个广泛使用的 web 服务器&#xff0c;可以配置为侦听多个端口。这对于托管多个网站、运行不同类型的服务或改进服务器的可访问性特别有用。在本文中&#xff0c;我们将探讨配置 Apache 以侦听多个端口的步骤。 Step 1: Access Apache Configuration File 找到并打…

免费GIS工具箱:轻松将glb文件转换成3DTiles文件

在GIS地理信息系统领域&#xff0c;GLB文件作为GLTF文件的二进制版本&#xff0c;主要用于3D模型数据的存储和展示。然而&#xff0c;GLB文件的使用频率相对较低&#xff0c;这是因为GIS系统主要处理的是地理空间数据&#xff0c;如地图、地形、地貌、植被、水系等&#xff0c;…

Day27 - 大模型微调,LLaMA搭建

指令微调 SFT&#xff1a;Supervised Fine - Tuning 自我认知 self-cognitionidentity私有知识 / 具体任务公共知识 LLaMA-Factory 搭建过程 1. 下载 LLaMA-Factory 源代码 ​git clone https://github.com/hiyouga/LLaMA-Factory.git 2. 安装 LLaMA-Factory 依赖包 cd L…

OOP面向对象编程:类与类之间的关系

OOP面向对象编程&#xff1a;类与类之间的关系 三大关系&#xff1a;复合&#xff08;适配器设计模式&#xff09;、委托&#xff08;桥接设计模式&#xff09;、继承 8、1复合Composition has-a -> 适配器模式 一个类里面含有另一个类的对象 —> 复合关系 has-a 适配器设…