每日Attention学习15——Cross-Model Grafting Module

embedded/2024/9/24 10:17:24/
模块出处

[CVPR 22] [link] [code] Pyramid Grafting Network for One-Stage High Resolution Saliency Detection


模块名称

Cross-Model Grafting Module (CMGM)


模块作用

Transformer与CNN之间的特征融合


模块结构

在这里插入图片描述


模块思想

Transformer在全局特征上更优,CNN在局部特征上更优,对这两者进行进行融合的最简单做法是直接相加或相乘。但是,相加或相乘本质上属于"局部"操作,如果某片区域两个特征的不确定性都较高,则会带来许多噪声。为此,本文提出了CMGM模块,通过交叉注意力的形式引入更为广泛的信息来增强融合效果。


模块代码
import torch.nn.functional as F
import torch.nn as nn
import torchclass CMGM(nn.Module):def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None):super().__init__()self.num_heads = num_headshead_dim = dim // num_headsself.scale = qk_scale or head_dim ** -0.5self.k = nn.Linear(dim, dim , bias=qkv_bias)self.qv = nn.Linear(dim, dim * 2, bias=qkv_bias)self.proj = nn.Linear(dim, dim)self.act = nn.ReLU(inplace=True)self.conv = nn.Conv2d(8,8,kernel_size=3, stride=1, padding=1)self.lnx = nn.LayerNorm(64)self.lny = nn.LayerNorm(64)self.bn = nn.BatchNorm2d(8)self.conv2 = nn.Sequential(nn.Conv2d(64,64,kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),nn.ReLU(inplace=True),nn.Conv2d(64,64,kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),nn.ReLU(inplace=True))def forward(self, x, y):batch_size = x.shape[0]chanel     = x.shape[1]sc = xx = x.view(batch_size, chanel, -1).permute(0, 2, 1)sc1 = xx = self.lnx(x)y = y.view(batch_size, chanel, -1).permute(0, 2, 1)y = self.lny(y)B, N, C = x.shapey_k = self.k(y).reshape(B, N, 1, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)x_qv= self.qv(x).reshape(B,N,2,self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)x_q, x_v = x_qv[0], x_qv[1] y_k = y_k[0]attn = (x_q @ y_k.transpose(-2, -1)) * self.scaleattn = attn.softmax(dim=-1)x = (attn @ x_v).transpose(1, 2).reshape(B, N, C)x = self.proj(x)x = (x+sc1)x = x.permute(0,2,1)x = x.view(batch_size,chanel,*sc.size()[2:])x = self.conv2(x)+xreturn x, self.act(self.bn(self.conv(attn+attn.transpose(-1,-2))))if __name__ == '__main__':x = torch.randn([1, 64, 11, 11])y = torch.randn([1, 64, 11, 11])cmgm = CMGM(dim=64)out1, out2 = cmgm(x, y)print(out1.shape)  # out feature 1, 64, 11, 11print(out2.shape)  # cross attention matrix 1, 8, 121, 121


http://www.ppmy.cn/embedded/100823.html

相关文章

【jvm】栈是否存在垃圾回收

目录 一、栈的特点1.1 栈内存分配1.2 栈的生命周期1.3 垃圾回收不直接涉及 二、堆与栈的区别三、总结 一、栈的特点 1.1 栈内存分配 1.栈内存分配是自动的,不需要程序员手动分配和释放。 2.每当一个方法被调用时,JVM就会在这个线程的栈上创建一个新的栈…

设计模式六大原则中的里氏替换原则

设计模式六大原则中的里氏替换原则(Liskov Substitution Principle, LSP)是面向对象设计中一个至关重要的原则,它定义了继承的基本原则和约束,确保子类能够透明地替换父类,而不会破坏系统的正确性和稳定性。以下是对里…

Python类与对象篇(七)

python 面向对象编程类与对象类的属性与方法构造函数与析构函数继承与多态封装与私有属性 面向对象编程 Python 的面向对象编程(Object-Oriented Programming, OOP)是一种编程风格,它将数据(属性)和功能(方法)封装在称为类(class)的结构中。这样做的主要目的是为了提高代码的可…

RongCallKit iOS 端本地私有 pod 方案

RongCallKit iOS 端本地私有 pod 方案 需求背景 适用于源码集成 CallKit 时,使用 pod 管理 RTC framework 以及源码。集成 CallKit 时,需要定制化修改 CallKit 的样式以及部分 UI 功能。适用于 CallKit 源码 Debug 调试便于定位相关问题。 解决方案 从…

【RabbitMQ】高级特性

本文将介绍一些RabbitMQ的重要特性。 官方文档:Protocol Extensions | RabbitMQ 本文是使用的Spring整合RabbitMQ环境。 生产者发送确认(publish confirm) 当消息发送给消息队列,如何确保消息队列一定收到消息呢,RabbitMQ通过 事务机制 和 …

spring揭秘09-aop03-aop织入器织入横切逻辑与自动织入

文章目录 【README】【1】spring aop的织入【1.1】使用ProxyFactory 作为织入器【1.2】基于接口的代理(JDK动态代理,目标类实现接口)【补充】 【1.2】基于类的代理(CGLIB动态代理,目标类没有实现接口)【1.2…

河南萌新2024第五场

A 日历游戏 题目大意: alice,bob玩游戏,给定一个2000.1.1到2024.8.1之间的任意一个日期,每次进行一次操作(保证合法日期) 天数1,例如2000.1.1 -> 2000.1.2 月份1,例如2000.1.…

极速文件预览!轻松部署 kkFileView 于 Docker 中!

大家好,这几天闲的难受,决定给自己找点事做。博主的项目中有个文件预览的小需求,原有方案是想将文件转换成 PDF 进行预览。本着能借鉴就绝对不自己写的原则。今天就让我们简单试用一下 kkFileView 文件预览服务,一起探索它的强大功…