每日Attention学习27——Patch-based Graph Reasoning

news/2025/3/20 23:22:14/
模块出处

[NC 25] [link] Graph-based context learning network for infrared small target detection


模块名称

Patch-based Graph Reasoning (PGR)


模块结构

在这里插入图片描述


模块特点
  • 使用图结构更好的捕捉特征的全局上下文
  • 将图结构与特征切片(Patching)相结合,从而促进全局/局部特征互补

模块代码
import torch
import torch.nn as nn
import torch.nn.functional as Fclass graph(nn.Module):def __init__(self, p2=4, nIn=64, N=16):super(graph, self).__init__()self.p2 = p2self.N = Nself.conv30 = nn.Sequential(nn.Conv2d(nIn, self.N, kernel_size=3, stride=1, padding=1, groups=1),nn.ReLU(inplace=True))self.conv10 = nn.Sequential(nn.Conv1d(nIn, nIn, kernel_size=1, stride=1, padding=0),nn.ReLU(inplace=True))self.conv11 = nn.Sequential(nn.Conv1d(self.N, self.N, kernel_size=1, stride=1, padding=0),nn.ReLU(inplace=True))self.adaptivemax = nn.AdaptiveAvgPool2d((8, 8))self.conv12 = nn.Sequential(nn.Conv1d(p2 ** 2, p2, kernel_size=1, stride=1, padding=0),nn.ReLU(inplace=True),nn.Conv1d(p2, p2, kernel_size=1, stride=1, padding=0),nn.ReLU(inplace=True),nn.Conv1d(p2, p2 ** 2, kernel_size=1, stride=1, padding=0),nn.Sigmoid())def ADP_weight(self, x):b, C, H, W = x.shapefg = self.adaptivemax(x)  fg1 = fg.view(b, C, self.p2 ** 2)  fg1 = torch.transpose(fg1, 1, 2)  fg2 = self.conv12(fg1)  fg3 = fg2.unsqueeze(-1).unsqueeze(-1)return fg3def graph_convolution(self, fs, x):b, C, H, W = x.shapeh, w = H // self.p2, W // self.p2L = h * wB = self.conv30(fs)  B1 = B.view(-1, self.N, L)  fs1 = fs.view(-1, C, L)  fs1 = torch.transpose(fs1, 1, 2) fs2 = torch.bmm(B1, fs1)  fs3 = self.conv11(fs2)  fs5 = self.conv10(torch.transpose(fs3, 1, 2))  fs6 = torch.bmm(torch.transpose(B1, 1, 2), torch.transpose(fs5, 1, 2))fs6 = torch.transpose(fs6, 1, 2) fs6 = fs6.view(b, self.p2 ** 2, C, h, w) return fs6def forward(self, fs, x):fs6 = self.graph_convolution(fs, x)weight = self.ADP_weight(x)out = weight * fs6return outclass PGR(nn.Module):def __init__(self, p2=4, nIn=32, nOut=32, add=True):super(PGR, self).__init__()self.p2 = p2self.N = nIn // 4self.add = addself.graph0 = graph(p2, nIn, self.N)self.conv31 = nn.Sequential(nn.Conv2d(nOut, nOut, kernel_size=1, stride=1),nn.BatchNorm2d(nOut),nn.ReLU(inplace=True))def forward(self, x):b, C, H, W = x.shapeh, w = H // self.p2, W // self.p2L = h * wfs = torch.zeros((b, self.p2 ** 2, C, h, w)).cuda()for i in range(1, self.p2 + 1):for j in range(1, self.p2 + 1):fs[:, i * j - 1, :, :, :] = x[:, :, (i - 1) * h: i * h, (j - 1) * w: j * w]fs = fs.view(b * self.p2 ** 2, C, h, w)fs6 = self.graph0(fs, x)out = torch.zeros_like(x)for i in range(1, self.p2 + 1):for j in range(1, self.p2 + 1):out[:, :, (i - 1) * h: i * h, (j - 1) * w: j * w] = fs6[:, i * j - 1, :, :, :]out = self.conv31(out)if self.add:out = out + xreturn outif __name__ == '__main__':x = torch.randn([1, 64, 44, 44]).cuda()pgr = PGR(p2=8, nIn=64, nOut=64).cuda()out = pgr(x)print(out.shape) # [1, 64, 44, 44]


http://www.ppmy.cn/news/1580711.html

相关文章

ArcGIS10. 8简介与安装,附下载地址

目录 ArcGIS10.8 1. 概述 2. 组成与功能 3. 10.8 特性 下载链接 安装步骤 1. 安装准备 2. 具体步骤 3.补丁 其他版本安装 ArcGIS10.8 1. 概述 ArcGIS 10.8 是由美国 Esri 公司精心研发的一款功能强大的地理信息系统(GIS)平台。其核心功能在于…

Java基础关键_025_IO流(三)

目 录 一、数据输入输出流 1.DataOutputStream 2.DataInputStream 二、序列化和反序列化 1.ObjectOutputStream 2.ObjectInputStream 3.Serializable 接口 (1)说明 (2)实例 4.序列化版本号 (1)…

蓝桥杯练习day1:拆分数位-四位数字的最小和

前言 给你一个四位 正 整数 num 。请你使用 num 中的 数位 ,将 num 拆成两个新的整数 new1 和 new2 。new1 和 new2 中可以有 前导 0 ,且 num 中 所有 数位都必须使用。 比方说,给你 num 2932 ,你拥有的数位包括:两…

Matlab 液位系统根据输入和输出信号拟合一阶传递函数

1、内容简介 略 Matlab165-液位系统根据输入和输出信号拟合一阶传递函数 可以交流、咨询、答疑 2、内容说明 略 3、仿真分析 略 4、参考论文 略

【C++】二叉树和堆的链式结构(上)

本篇博客给大家带来的是用C语言来实现堆链式结构和二叉树的实现! 🐟🐟文章专栏:数据结构 🚀🚀若有问题评论区下讨论,我会及时回答 ❤❤欢迎大家点赞、收藏、分享! 今日思想&#xff…

golang快速上手基础语法

变量 第一种,指定变量类型,声明后若不赋值,使用默认值0 package mainimport "fmt"func main() {var a int //第一种,指定变量类型,声明后若不赋值,使用默认值0。fmt.Printf(" a %d\n"…

RHCE(RHCSA复习:npm、dnf、源码安装实验)

七、软件管理 7.1 rpm 安装 7.1.1 挂载 [rootlocalhost ~]# ll /mnt total 0 drwxr-xr-x. 2 root root 6 Oct 27 21:32 hgfs[rootlocalhost ~]# mount /dev/sr0 /mnt #挂载 mount: /mnt: WARNING: source write-protected, mounted read-only. [rootlocalhost ~]# [rootlo…

Pytest项目_day01(HTTP接口)

HTTP HTTP是一个协议(服务器传输超文本到浏览器的传送协议),是基于TCP/IP通信协议来传输数据(HTML文件,图片文件,查询结果等)。 访问域名 例如www.baidu.com就是百度的域名,我们想…