注意力机制讲解与代码解析

news/2024/10/18 22:26:27/

一、SEBlock(通道注意力机制)

先在H*W维度进行压缩,全局平均池化将每个通道平均为一个值。
(B, C, H, W)---- (B, C, 1, 1)

利用各channel维度的相关性计算权重
(B, C, 1, 1) --- (B, C//K, 1, 1) --- (B, C, 1, 1) --- sigmoid

与原特征相乘得到加权后的。

import torch
import torch.nn as nnclass SELayer(nn.Module):def __init__(self, channel, reduction = 4):super(SELayer, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1) //自适应全局池化,只需要给出池化后特征图大小self.fc1 = nn.Sequential(nn.Conv2d(channel, channel//reduction, 1, bias = False),nn.ReLu(implace = True),nn.Conv2d(channel//reduction, channel, 1, bias = False),nn.sigmoid())def forward(self, x):y = self.avg_pool(x)y_out = self.fc1(y)return x * y

二、CBAM(通道注意力+空间注意力机制)

CBAM里面既有通道注意力机制,也有空间注意力机制。
通道注意力同SE的大致相同,但额外加入了全局最大池化与全局平均池化并行。

空间注意力机制:先在channel维度进行最大池化和均值池化,然后在channel维度合并,MLP进行特征交融。最终和原始特征相乘。 

import torch
import torch.nn as nnclass ChannelAttention(nn.Module):def __init__(self, channel, rate = 4):super(ChannelAttention, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.fc1 = nn.Sequential(nn.Conv2d(channel, channel//rate, 1, bias = False)nn.ReLu(implace = True)nn.Conv2d(channel//rate, channel, 1, bias = False)            )self.sig = nn.sigmoid()def forward(self, x):avg = sefl.avg_pool(x)avg_feature = self.fc1(avg)max = self.max_pool(x)max_feature = self.fc1(max)out = max_feature + avg_featureout = self.sig(out)return x * out

import torch
import torch.nn as nnclass SpatialAttention(nn.Module):def __init__(self):super(SpatialAttention, self).__init__()//(B,C,H,W)---(B,1,H,W)---(B,2,H,W)---(B,1,H,W)self.conv1 = nn.Conv2d(2, 1, kernel_size = 3, padding = 1, bias = False)self.sigmoid = nn.sigmoid()def forward(self, x):mean_f = torch.mean(x, dim = 1, keepdim = True)max_f = torch.max(x, dim = 1, keepdim = True)cat = torch.cat([mean_f, max_f], dim = 1)out = self.conv1(cat)return x*self.sigmod(out)

三、transformer里的注意力机制 

Scaled Dot-Product Attention

该注意力机制的输入是QKV。

1.先Q,K相乘。

2.scale

3.softmax

4.求output

 

import torch
import torch.nn as nnclass ScaledDotProductAttention(nn.Module):def __init__(self, scale):super(ScaledDotProductAttention, self)self.scale = scaleself.softmax = nn.softmax(dim = 2)def forward(self, q, k, v):u = torch.bmm(q, k.transpose(1, 2))u = u / scaleattn = self.softmax(u)output = torch.bmm(attn, v)return outputscale = np.power(d_k, 0.5)  //缩放系数为K维度的根号。
//Q  (B, n_q, d_q) , K (B, n_k, d_k)  V (B, n_v, d_v),Q与K的特征维度一定要一样。KV的个数一定要一样。

 MultiHeadAttention

将QKVchannel维度转换为n*C的形式,相当于分成n份,分别做注意力机制。

1.QKV单头变多头  channel ----- n * new_channel通过linear变换,然后把head和batch先合并

2.求单头注意力机制输出

3.维度拆分   将最终的head和channel合并。

4.linear得到最终输出维度

import torch
import torch.nn as nnclass MultiHeadAttention(nn.Module):def __init__(self, n_head, d_k, d_k_, d_v, d_v_, d_o):super(MultiHeadAttention, self)self.n_head = n_headself.d_k = d_kself.d_v = d_vself.fc_k = nn.Linear(d_k_, n_head * d_k)self.fc_v = nn.Linear(d_v_, n_head * d_v)self.fc_q = nn.Linear(d_k_, n_head * d_k)self.attention = ScaledDotProductAttention(scale=np.power(d_k, 0.5))self.fc_o = nn.Linear(n_head * d_v, d_0)def forward(self, q, k, v):batch, n_q, d_q_ = q.size()batch, n_k, d_k_ = k.size()batch, n_v, d_v_ = v.size()q = self.fc_q(q)k = self.fc_k(k)v = self.fc_v(v)q = q.view(batch, n_q, n_head, d_q).permute(2, 0, 1, 3).contiguous().view(-1, n_q, d_q)k = k.view(batch, n_k, n_head, d_k).permute(2, 0, 1, 3).contiguous().view(-1, n_k, d_k)v = v.view(batch, n_v, n_head, d_v).permute(2, 0, 1, 3).contiguous().view(-1. n_v, d_v)    output = self.attention(q, k, v)output = output.view(n_head, batch, n_q, d_v).permute(1, 2, 0, 3).contiguous().view(batch, n_q, -1)output = self.fc_0(output)return output

 


http://www.ppmy.cn/news/1103720.html

相关文章

如何写出一篇爆款产品文案,从目标受众到市场分析!

一篇爆款产品文案意味着什么?意味着更强的种草能力,更高的销售转化和更强的品牌传播力。今天来分享下如何写出一篇爆款产品文案,从目标受众到市场分析! 一、产品文案策略 一篇爆款产品文案,并不是一时兴起造就的。在撰写之前&…

uView实现全屏选项卡

// 直接复制粘贴即可使用 <template><view><view class"tabsBox"><u-tabs-swiper ref"uTabs" :list"list":current"current"change"tabsChange":is-scroll"false"></u-tabs-swiper&g…

【Ubuntu搭建MQTT Broker及面板+发布消息、订阅主题】

Ubuntu搭建MQTT Broker及面板发布消息、订阅主题 配置curl数据源 curl -s https://assets.emqx.com/scripts/install-emqx-deb.sh | sudo bash开始安装 sudo apt-get install emqx启动 sudo emqx start使用面板 根据自己的服务器是否开始了防火墙放行端口&#xff08;1808…

想在 Windows 上使用 telnet

如果你想在 Windows 上使用 telnet&#xff0c;可以按照以下步骤安装&#xff1a; 打开控制面板。点击 "程序" 或 "程序和功能"。点击 "启用或关闭 Windows 功能"。在弹出的窗口中找到 "Telnet 客户端" 并勾选它。点击 "确定&qu…

论文笔记:一分类及其在大数据中的潜在应用综述

0 概述 论文&#xff1a;A literature review on one‑class classification and its potential applications in big data 发表&#xff1a;Journal of Big Data 在严重不平衡的数据集中&#xff0c;使用传统的二分类或多分类通常会导致对具有大量实例的类的偏见。在这种情况…

软件测试下的AI之路(2)

&#x1f60f;作者简介&#xff1a;博主是一位测试管理者&#xff0c;同时也是一名对外企业兼职讲师。 &#x1f4e1;主页地址&#xff1a;【Austin_zhai】 &#x1f646;目的与景愿&#xff1a;旨在于能帮助更多的测试行业人员提升软硬技能&#xff0c;分享行业相关最新信息。…

无涯教程-JavaScript - IMSECH函数

描述 IMSECH函数以x yi或x yj文本格式返回复数的双曲正割。复数的双曲正割被定义为双曲余弦的倒数,即 六(z) 1/cosh(z) 语法 IMSECH (inumber)争论 Argument描述Required/OptionalInumberA complex number for which you want the hyperbolic secant.Required Notes Ex…

K8S1.23.6版本详细安装教程以及错误解决方案(包括前置环境,使用部署工具kubeadm来引导集群)

准备工作&#xff08;来自官方文档&#xff09; 一台兼容的 Linux 主机。Kubernetes 项目为基于 Debian 和 Red Hat 的 Linux 发行版以及一些不提供包管理器的发行版提供通用的指令。每台机器 2 GB 或更多的 RAM&#xff08;如果少于这个数字将会影响你应用的运行内存&#xf…