【Block总结】DynamicFilter,动态滤波器降低计算复杂度,替换传统的MHSA|即插即用

embedded/2025/1/31 17:45:45/

论文信息

标题: FFT-based Dynamic Token Mixer for Vision

论文链接: https://arxiv.org/pdf/2303.03932

关键词: 深度学习、计算机视觉、对象检测、分割

GitHub链接: https://github.com/okojoalg/dfformer

在这里插入图片描述

创新点

本论文提出了一种新的标记混合器(token mixer),称为动态滤波器(Dynamic Filter),旨在解决多头自注意力(MHSA)模型在处理高分辨率图像时的计算复杂度问题。传统的MHSA模型在输入特征图中像素数量的平方上具有计算复杂度,导致处理速度缓慢。通过引入基于快速傅里叶变换(FFT)的动态滤波器,论文展示了在保持性能的同时显著降低计算复杂度的可能性。

方法

论文中提出的动态滤波器结合了全局操作的优点,类似于MHSA,但在计算效率上更具优势。具体方法包括:

  • FFT-based Token Mixer: 通过FFT实现全局操作,降低计算复杂度。
  • DFFormer和CDFFormer模型: 这两种新型图像识别模型利用动态滤波器进行图像分类和其他下游任务。
    在这里插入图片描述

动态滤波器如何具体降低MHSA模型的计算复杂度?

动态滤波器通过引入基于快速傅里叶变换(FFT)的机制,显著降低了多头自注意力(MHSA)模型的计算复杂度。以下是其具体工作原理和优势:

计算复杂度问题

传统的MHSA模型在处理输入特征图时,其计算复杂度与特征图中像素数量的平方成正比。这意味着,当输入图像的分辨率增加时,计算需求会急剧上升,导致处理速度变慢,尤其是在高分辨率图像的情况下。

动态滤波器的工作原理

  1. 频域转换: 动态滤波器首先利用FFT将输入特征图转换到频域。FFT是一种高效的算法,可以将计算复杂度降低到 O ( N log ⁡ N ) O(N \log N) O(NlogN),其中 N N N是数据的长度。这一转换使得后续的操作可以在频域中进行,从而减少了计算量。

  2. 动态生成滤波器: 在频域中,动态滤波器通过一个多层感知机(MLP)动态生成每个特征通道的滤波器。这些滤波器是根据输入特征图的内容进行调整的,能够更好地捕捉到图像中的重要信息。

  3. 频域操作: 生成的滤波器在频域中应用于特征图,进行全局信息的捕捉。通过这种方式,动态滤波器能够有效地进行全局操作,同时避免了MHSA中计算复杂度的急剧增加。

  4. 逆FFT转换: 最后,经过滤波的频域特征图通过逆FFT转换回空间域,得到最终的输出结果。

优势

  • 降低计算复杂度: 通过在频域中进行操作,动态滤波器显著降低了MHSA模型的计算复杂度,使得处理高分辨率图像时的速度得以提升。

  • 提高内存效率: 动态滤波器的设计使得模型在处理时占用更少的内存,适合在资源有限的环境中运行。

  • 保持性能: 尽管计算复杂度降低,动态滤波器仍然能够保持与MHSA相似的性能,尤其是在图像分类和其他视觉任务中表现出色。

效果

实验结果表明,DFFormer和CDFFormer在高分辨率图像识别任务中表现出色,具有显著的吞吐量和内存效率。具体而言,这些模型在处理高分辨率图像时的性能优于传统的MHSA模型,显示出动态滤波器在实际应用中的潜力。

实验结果

论文通过一系列实验验证了提出模型的有效性,包括:

  • 图像分类: DFFormer和CDFFormer在标准数据集上的表现接近或超过了现有的最先进模型。
  • 下游任务分析: 通过对比实验,展示了动态滤波器在不同视觉任务中的适用性和优势。

总结

本论文的研究表明,基于FFT的动态滤波器是一种值得认真考虑的标记混合器选项,尤其是在处理高分辨率图像时。通过降低计算复杂度,动态滤波器不仅提高了模型的处理速度,还保持了良好的性能,推动了计算机视觉领域的进一步发展。研究结果为未来的视觉模型设计提供了新的思路和方向。

代码

import torch
import torch.nn as nn
from timm.models.layers import to_2tupleclass StarReLU(nn.Module):"""StarReLU: s * relu(x) ** 2 + b"""def __init__(self, scale_value=1.0, bias_value=0.0,scale_learnable=True, bias_learnable=True,mode=None, inplace=False):super().__init__()self.inplace = inplaceself.relu = nn.ReLU(inplace=inplace)self.scale = nn.Parameter(scale_value * torch.ones(1),requires_grad=scale_learnable)self.bias = nn.Parameter(bias_value * torch.ones(1),requires_grad=bias_learnable)def forward(self, x):return self.scale * self.relu(x) ** 2 + self.biasclass Mlp(nn.Module):""" MLP as used in MetaFormer models, eg Transformer, MLP-Mixer, PoolFormer, MetaFormer baslines and related networks.Mostly copied from timm."""def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU, drop=0.,bias=False, **kwargs):super().__init__()in_features = dimout_features = out_features or in_featureshidden_features = int(mlp_ratio * in_features)drop_probs = to_2tuple(drop)self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)self.act = act_layer()self.drop1 = nn.Dropout(drop_probs[0])self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)self.drop2 = nn.Dropout(drop_probs[1])def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.drop1(x)x = self.fc2(x)x = self.drop2(x)return xclass DynamicFilter(nn.Module):def __init__(self, dim, expansion_ratio=2, reweight_expansion_ratio=.25,act1_layer=StarReLU, act2_layer=nn.Identity,bias=False, num_filters=4, size=14, weight_resize=False,**kwargs):super().__init__()size = to_2tuple(size)self.size = size[0]self.filter_size = size[1] // 2 + 1self.num_filters = num_filtersself.dim = dimself.med_channels = int(expansion_ratio * dim)self.weight_resize = weight_resizeself.pwconv1 = nn.Linear(dim, self.med_channels, bias=bias)self.act1 = act1_layer()self.reweight = Mlp(dim, reweight_expansion_ratio, num_filters * self.med_channels)self.complex_weights = nn.Parameter(torch.randn(self.size, self.filter_size, num_filters, 2,dtype=torch.float32) * 0.02)self.act2 = act2_layer()self.pwconv2 = nn.Linear(self.med_channels, dim, bias=bias)def forward(self, x):B, H, W, _ = x.shaperouteing = self.reweight(x.mean(dim=(1, 2))).view(B, self.num_filters,-1).softmax(dim=1)x = self.pwconv1(x)x = self.act1(x)x = x.to(torch.float32)x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')if self.weight_resize:complex_weights = resize_complex_weight(self.complex_weights, x.shape[1],x.shape[2])complex_weights = torch.view_as_complex(complex_weights.contiguous())else:complex_weights = torch.view_as_complex(self.complex_weights)routeing = routeing.to(torch.complex64)weight = torch.einsum('bfc,hwf->bhwc', routeing, complex_weights)if self.weight_resize:weight = weight.view(-1, x.shape[1], x.shape[2], self.med_channels)else:weight = weight.view(-1, self.size, self.filter_size, self.med_channels)x = x * weightx = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')x = self.act2(x)x = self.pwconv2(x)return x
def resize_complex_weight(origin_weight, new_h, new_w):h, w, num_heads = origin_weight.shape[0:3]  # size, w, c, 2origin_weight = origin_weight.reshape(1, h, w, num_heads * 2).permute(0, 3, 1, 2)new_weight = torch.nn.functional.interpolate(origin_weight,size=(new_h, new_w),mode='bicubic',align_corners=True).permute(0, 2, 3, 1).reshape(new_h, new_w, num_heads, 2)return new_weightif __name__ == "__main__":# 如果GPU可用,将模块移动到 GPUinput_size=20device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 输入张量 (batch_size, height, width,channels)x = torch.randn(1, input_size , input_size, 32).to(device)# 初始化 pconv 模块dim = 32block = DynamicFilter(dim=dim,size=input_size)print(block)block = block.to(device)# 前向传播output = block(x)print("输入:", x.shape)print("输出:", output.shape)

http://www.ppmy.cn/embedded/158419.html

相关文章

【Go语言圣经】第五节:函数

第五章:函数 5.1 函数声明 和其它语言类似,Golang 的函数声明包括函数名、形参列表、返回值列表(可省略)以及函数体: func name(parameter-list) (result-list) {/* ... Body ... */ }需要注意的是,函数…

性能优化案例:通过合理设置spark.default.parallelism参数的值来优化PySpark程序的性能

在 PySpark 中,spark.default.parallelism 是一个关键参数,直接影响作业的并行度和资源利用率。 通过合理设置 spark.default.parallelism 并结合数据特征调整,可显著提升 PySpark 作业的并行效率和资源利用率。建议在开发和生产环境中进行多…

F. Ira and Flamenco

题目链接:Problem - F - Codeforces 题目大意:给n,m n个数让从中选m个数满足一下条件: 1.m个数互不相同 2.里面的任意两个数相减的绝对值不能超过m 求这n个数有多少组数据满足。 第一行包含一个整数 t ( 1 ≤ t ≤ 1e4 ) - 测试用例数。 …

跟李沐学AI:视频生成类论文精读(Movie Gen、HunyuanVideo)

Movie Gen:A Cast of Media Foundation Models 简介 Movie Gen是Meta公司提出的一系列内容生成模型,包含了 3.2.1 预训练数据 Movie Gen采用大约 100M 的视频-文本对和 1B 的图片-文本对进行预训练。 图片-文本对的预训练流程与Meta提出的 Emu: Enh…

青少年编程与数学 02-008 Pyhon语言编程基础 05课题、数据类型

青少年编程与数学 02-008 Pyhon语言编程基础 05课题、数据类型 一、数据类型1. 数字类型(Numeric Types)2. 序列类型(Sequence Types)3. 集合类型(Set Types)4. 映射类型(Mapping Type&#xff…

【redis进阶】redis 总结

目录 介绍一下什么是 Redis,有什么特点 Redis 支持哪些数据类型 Redis 数据类型底层的数据结构/编码方式是什么 ZSet 为什么使用跳表,而不是使用红黑树来实现 Redis 的常见应用场景有哪些 怎样测试 Redis 服务器的连通性 如何设置 key 的过期时间 Redis …

React第二十八章(css modules)

css modules 什么是 css modules 因为 React 没有Vue的Scoped,但是React又是SPA(单页面应用),所以需要一种方式来解决css的样式冲突问题,也就是把每个组件的样式做成单独的作用域,实现样式隔离,而css modules就是一种…

LeetCode题练习与总结:最长和谐子序列--594

一、题目描述 和谐数组是指一个数组里元素的最大值和最小值之间的差别 正好是 1 。 给你一个整数数组 nums ,请你在所有可能的 子序列 中找到最长的和谐子序列的长度。 数组的 子序列 是一个由数组派生出来的序列,它可以通过删除一些元素或不删除元素…