GDFN模块(restormer)

news/2024/11/13 3:39:33/

为了对特征进行变换,常规的前馈神经网络独立地在每个像素位置进行相同的操作。它使用两个1x1卷积层,一个用来扩展特征通道(通常4倍),第二个用来将特征通道减少到原来的输入维度。在隐藏层中加入非线性。

GDFN做了两个基本的改进:

  • 门机制
  • 深度可分离卷积

架构如下图:

门机制通过经过线性变换的两个平行通道的逐元素点积实现,其中一个通道用GELU激活,可以参考Gaussian Error Linear Units (GELUs)

深度可分离卷积用来编码空间上邻域像素的信息,有助于学习局部图像结构。

Given an input tensor X ∈ R H ^ × W ^ × C ^ \mathbf{X} \in \mathbb{R}^{\hat{H} \times \hat{W} \times \hat{C}} XRH^×W^×C^, GDFN is formulated as:
X ^ = W p 0 Gating  ( X ) + X Gating ⁡ ( X ) = ϕ ( W d 1 W p 1 ( LN ⁡ ( X ) ) ) ⊙ W d 2 W p 2 ( LN ⁡ ( X ) ) \begin{aligned} \hat{\mathbf{X}} & =W_p^0 \text { Gating }(\mathbf{X})+\mathbf{X} \\ \operatorname{Gating}(\mathbf{X}) & =\phi\left(W_d^1 W_p^1(\operatorname{LN}(\mathbf{X}))\right) \odot W_d^2 W_p^2(\operatorname{LN}(\mathbf{X})) \end{aligned} X^Gating(X)=Wp0 Gating (X)+X=ϕ(Wd1Wp1(LN(X)))Wd2Wp2(LN(X))where ⊙ \odot denotes element-wise multiplication, ϕ \phi ϕ represents the GELU non-linearity, and LN is the layer normalization

实现代码

## Gated-Dconv Feed-Forward Network (GDFN)
class FeedForward(nn.Module):def __init__(self, dim, ffn_expansion_factor, bias):super(FeedForward, self).__init__()hidden_features = int(dim*ffn_expansion_factor)self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2,\kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias)self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)def forward(self, x):x = self.project_in(x)x1, x2 = self.dwconv(x).chunk(2, dim=1)x = F.gelu(x1) * x2x = self.project_out(x)return x

上面这段代码实现的是(b)图从 Norm模块后面开始到残差连接之前的线性过程。

在这里插入图片描述
在封装成Transformer Block的时候用了层标准化以及残差连接,如图(b)所示


http://www.ppmy.cn/news/1059790.html

相关文章

React中的props和state的理解

props: props是一个从外部传递进组件的参数。由于React具有单向数据流的特性,所以它的主要作用是从父组件向子组件中传递数据,它是不可改变的,如果想要改变它,只能通过外部组件传入新的props来从新渲染子组件&#xff…

铜矿人员定位安全方案

针对铜矿中的人员定位安全需求,可以采用以下方案: 1.实时人员定位系统:建立一个实时人员定位系统,通过在矿工的工作服或安全帽上安装UWB或RFID定位设备,以及相应的接收器和基站,实时跟踪和定位矿工的位置。…

Golang Gorm 一对多关系 关系表创建

一对多关系 我们先从一对多开始多表关系的学习因为一对多的关系生活中到处都是,例如: 老板与员工女神和添狗老师和学生班级与学生用户与文章 在创建的时候先将没有依赖的创建。表名称ID就是外键。外键要和关联的外键的数据类型要保持一致。 package ma…

Go 语言的实战案例 SOCKS5 代理 | 青训营

Powered by:NEFU AB-IN 文章目录 Go 语言的实战案例 SOCKS5 代理 | 青训营 引入TCP echo serverauth 认证请求阶段relay阶段 Go 语言的实战案例 SOCKS5 代理 | 青训营 GO语言工程实践课后作业:实现思路、代码以及路径记录 引入 代理是指在计算机网络中&#xff…

【随笔】如何使用阿里云的OSS保存基础的服务器环境

使用阿里云OSS创建一个存储仓库:bucket 在Linux上下载并安装阿里云的ossutil工具 // 命令行,是linux环境 3. 安装ossutil。sudo -v ; curl https://gosspublic.alicdn.com/ossutil/install.sh | sudo bash 说明:安装过程中,需要使用解压工具…

iptables的使用规则

环境中为了安全要限制swagger的访问,最简单的方式是通过iptables防火墙设置规则限制。 在测试服务器中设置访问swagger-ui.html显示如下,区分大小写: iptables设置限制访问9783端口的swagger字段的请求: iptables -A INPUT -p t…

【图像分割】实现snake模型的活动轮廓模型以进行图像分割研究(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

opencv 案例实战02-停车场车牌识别SVM模型训练及验证

1. 整个识别的流程图: 2. 车牌定位中分割流程图: 三、车牌识别中字符分割流程图: 1.准备数据集 下载车牌相关字符样本用于训练和测试,本文使用14个汉字样本和34个数字跟字母样本,每个字符样本数为40,样本尺…