【论文阅读笔记】FcaNet: Frequency Channel Attention Networks(2021/7/23)

news/2025/2/28 7:23:20/

目录

1 引言

2 方法

1 DCT和通道注意力

2 多光谱通道注意力(Multi-Spectral Channel Attention)

3 频率成分的选择标准

3 实验

4 结论



论文题目:FcaNet: Frequency Channel Attention Networks(频率通道注意力网络)

论文地址:https://arxiv.org/pdf/2012.11879

论文代码:https://github.com/cfzd/FcaNet

【摘要】

注意力机制,尤其是通道注意力,在计算机视觉领域取得了巨大的成功。许多工作专注于如何设计高效的通道注意力机制,而忽略了一个基本问题,即通道注意力机制使用标量来表示通道,这由于大量信息丢失而变得困难。在这项工作中,我们从不同的角度出发,将信道表示问题视为使用频率分析的压缩过程。基于频域分析,我们从数学上证明了传统的全局平均池化是频域特征分解的特例。通过证明,我们自然地将通道注意力机制的压缩推广到频域,并提出了我们的多光谱通道注意力方法,称为FcaNetFcaNet简单但有效。在现有的通道注意力方法中,我们可以在计算中改变几行代码来实现我们的方法。此外,与其他通道注意力方法相比,该方法在图像分类、目标检测和实例分割任务上取得了先进的结果。我们的方法可以在相同的参数数量和相同的计算成本的情况下,始终如一地优于基线SENet。

1 引言

每朵玫瑰都有刺。

通常,由于计算开销有限,通道注意力方法的核心步骤是为每个通道使用一个标量来进行计算全局平均池(GAP)由于其简单高效而成为深度学习事实上的标准选择。然而,GAP的简单性使得很难很好地捕获各种输入的复杂信息

一个新的视角:频域

与以往的工作不同,我们将通道的标量表示视为一个压缩问题。在尽可能保持整个通道的表示能力的同时,通道的信息应采用标量进行紧凑编码。

如何有效地压缩具有标量的

在信道注意力机制中使用离散余弦变换DCT来压缩信道,原因如下:

DCT是信号处理中广泛使用的数据压缩方法,尤其是在数字图像和视频中。
DCT可以用元素乘法实现,并且是可微的。通过这种方式,它可以很容易地集成到CNN中。

DCT可以看作是GAP的推广。

本文在上述讨论的基础上,进一步提出了一种简单、新颖、有效的多光谱通道注意力(multi-spectral channel attentionMSCA框架。

总的来说,本文的主要贡献可以概括如下:

 我们将通道注意力视为一个压缩问题,并在通道注意力中引入DCT。然后,我们证明了传统GAP是DCT的特例。基于这个证明,我们在频域中推广了通道注意力,并提出了基于多光谱通道注意力框架的方法,称为FcaNet

 我们提出了三种频率成分选择标准(Lf基于低频的选择、ts两步选择法、NAS神经架构搜索选择),并提出了基于多光谱通道注意力框架的方法来实现FcaNet

 大量实验表明,本文提出的方法在ImageNet和COCO数据集上都取得了当前最好的结果,并且具有与SENet相同的计算开销

在ImageNet上的结果如图1所示。

2 方法

本节首先回顾DCT和通道注意力的公式。然后,基于这些工作,我们详细阐述了多光谱通道注意力框架的推导。同时,结合多光谱信道注意力框架,提出了三种频率分量选择方法。

1 DCT和通道注意力

 DCT

离散余弦变换(DCT)通常二维DCT的基函数见公式1:

DCT变换就是将原图像和DCT基函数进行内积运算。2D DCT见公式2:

逆2D DCT见公式3:

 通道注意力

通道注意力机制在CNN中得到广泛应用。它使用标量来表示和评估每个通道的重要性。假设X ∈ 是网络中的图像特征张量,C是通道数,H是特征的高度,W是特征的宽度。我们将通道注意力中的标量表示视为一个压缩问题,因为它必须表示整个通道,而只能使用一个标量。这样,注意力机制可以写为如下,见公式4:

式中:att为注意力向量,sigmoid为Sigmoid函数,fc表示全连接层或一维卷积等映射函数,压缩为压缩方法。在得到所有C通道的注意力向量后,输入X的每个通道都被相应的注意力值缩放,见公式5:

通常,全局平均池化是事实上的压缩方法,因为它的简单性和有效性。还有全局最大池化和全局标准差池化等压缩方法。

2 多光谱通道注意力(Multi-Spectral Channel Attention)

 证明:GAP是2D-DCT的特例

DCT可以看作是输入的加权和。我们进一步证明GAP实际上是2D DCT的一个特例。

定理1。GAP是二维DCT的特例,其结果正比于二维DCT的最低频率成分。

将2维DCT公式中的h和w都置为0,得到公式6:

结论:在通道注意机制中使用GAP意味着只保留最低频率的信息。

推广:除了低频信息以外,其他频率的所有分量也表示信道的有用信息不应被遗弃,所以作者提出在注意力机制中将GAP推广到2D DCT,并2D DCT的多个频率分量压缩更多信息包括最低频率分量,即GAP。

 多光谱通道注意力MSCA模块

①分割

将输入X沿通道维度分割成许多部分[ X0 , X1 , · · · , Xn - 1]。

其中,,C应被n整除。

②压缩

对于每个部分,分配一个对应的2D DCT频率分量,2D DCT结果可以作为通道注意力的压缩结果。见公式7:

③拼接

整个压缩向量可以通过拼接得到,见公式8:

式中:为获得的多光谱向量。

MSCA

整个多光谱通道注意力框架可以写为公式9:

我们方法的总体说明见图2。见图2。现有通道注意力和多光谱通道注意力的说明。为了简单起见,2D DCT指标以一维格式表示。我们可以看到,我们的方法使用多个频率成分与所选择的DCT基,而SENet只使用GAP在通道注意力。最好在彩图观看。

代码如下:

python">import math
import torch
import torch.nn as nndef get_freq_indices(method):assert method in ['top1','top2','top4','top8','top16','top32','bot1','bot2','bot4','bot8','bot16','bot32','low1','low2','low4','low8','low16','low32']num_freq = int(method[3:])if 'top' in method:all_top_indices_x = [0,0,6,0,0,1,1,4,5,1,3,0,0,0,3,2,4,6,3,5,5,2,6,5,5,3,3,4,2,2,6,1]all_top_indices_y = [0,1,0,5,2,0,2,0,0,6,0,4,6,3,5,2,6,3,3,3,5,1,1,2,4,2,1,1,3,0,5,3]mapper_x = all_top_indices_x[:num_freq]mapper_y = all_top_indices_y[:num_freq]elif 'low' in method:all_low_indices_x = [0,0,1,1,0,2,2,1,2,0,3,4,0,1,3,0,1,2,3,4,5,0,1,2,3,4,5,6,1,2,3,4]all_low_indices_y = [0,1,0,1,2,0,1,2,2,3,0,0,4,3,1,5,4,3,2,1,0,6,5,4,3,2,1,0,6,5,4,3]mapper_x = all_low_indices_x[:num_freq]mapper_y = all_low_indices_y[:num_freq]elif 'bot' in method:all_bot_indices_x = [6,1,3,3,2,4,1,2,4,4,5,1,4,6,2,5,6,1,6,2,2,4,3,3,5,5,6,2,5,5,3,6]all_bot_indices_y = [6,4,4,6,6,3,1,4,4,5,6,5,2,2,5,1,4,3,5,0,3,1,1,2,4,2,1,1,5,3,3,3]mapper_x = all_bot_indices_x[:num_freq]mapper_y = all_bot_indices_y[:num_freq]else:raise NotImplementedErrorreturn mapper_x, mapper_yclass MultiSpectralAttentionLayer(torch.nn.Module):def __init__(self, channel, dct_h, dct_w, reduction = 16, freq_sel_method = 'top16'):super(MultiSpectralAttentionLayer, self).__init__()self.reduction = reductionself.dct_h = dct_hself.dct_w = dct_wmapper_x, mapper_y = get_freq_indices(freq_sel_method)self.num_split = len(mapper_x)mapper_x = [temp_x * (dct_h // 7) for temp_x in mapper_x] mapper_y = [temp_y * (dct_w // 7) for temp_y in mapper_y]# make the frequencies in different sizes are identical to a 7x7 frequency space# eg, (2,2) in 14x14 is identical to (1,1) in 7x7self.dct_layer = MultiSpectralDCTLayer(dct_h, dct_w, mapper_x, mapper_y, channel)self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(channel // reduction, channel, bias=False),nn.Sigmoid())def forward(self, x):n,c,h,w = x.shapex_pooled = xif h != self.dct_h or w != self.dct_w:x_pooled = torch.nn.functional.adaptive_avg_pool2d(x, (self.dct_h, self.dct_w))# If you have concerns about one-line-change, don't worry.   :)# In the ImageNet models, this line will never be triggered. # This is for compatibility in instance segmentation and object detection.y = self.dct_layer(x_pooled)y = self.fc(y).view(n, c, 1, 1)return x * y.expand_as(x)class MultiSpectralDCTLayer(nn.Module):"""Generate dct filters"""def __init__(self, height, width, mapper_x, mapper_y, channel):super(MultiSpectralDCTLayer, self).__init__()assert len(mapper_x) == len(mapper_y)assert channel % len(mapper_x) == 0self.num_freq = len(mapper_x)# fixed DCT initself.register_buffer('weight', self.get_dct_filter(height, width, mapper_x, mapper_y, channel))# fixed random init# self.register_buffer('weight', torch.rand(channel, height, width))# learnable DCT init# self.register_parameter('weight', self.get_dct_filter(height, width, mapper_x, mapper_y, channel))# learnable random init# self.register_parameter('weight', torch.rand(channel, height, width))# num_freq, h, wdef forward(self, x):assert len(x.shape) == 4, 'x must been 4 dimensions, but got ' + str(len(x.shape))# n, c, h, w = x.shapex = x * self.weightresult = torch.sum(x, dim=[2,3])return resultdef build_filter(self, pos, freq, POS):result = math.cos(math.pi * freq * (pos + 0.5) / POS) / math.sqrt(POS) if freq == 0:return resultelse:return result * math.sqrt(2)def get_dct_filter(self, tile_size_x, tile_size_y, mapper_x, mapper_y, channel):dct_filter = torch.zeros(channel, tile_size_x, tile_size_y)c_part = channel // len(mapper_x)for i, (u_x, v_y) in enumerate(zip(mapper_x, mapper_y)):for t_x in range(tile_size_x):for t_y in range(tile_size_y):dct_filter[i * c_part: (i+1)*c_part, t_x, t_y] = self.build_filter(t_x, u_x, tile_size_x) * self.build_filter(t_y, v_y, tile_size_y)return dct_filter

3 频率成分的选择标准

 FcaNet-LF是指含有低频成分的FcaNet。这样,选择频率成分的第一个标准就是只选择低频成分

 FcaNet-TS是指FcaNet两步选择方案中选择组件。

主要思想是首先确定每个频率分量的重要性,然后研究使用不同数量的频率分量的效果。也就是说,我们分别评估通道注意中每个频率成分的结果。最后,我们根据评估结果选择Top-k最高性能频率组件。

 FcaNet-NAS是指具有搜索组件FcaNet

对于这个标准,我们使用神经架构搜索来搜索通道的最佳频率成分。对于每个部件Xi,一组连续变量被分配给搜索组件。该部分的频率成分可写为公式10:

3 实验

图3。ImageNet上的Top-1准确率分别使用通道注意力中的不同频率成分。

图4。不同部件数量下的Top1准确率。由于FcaNet-NAS自动搜索并确定频率成分,故不纳入本实验。

图5。与完全可学习通道注意的比较。FR表示随机初始化的固定张量,LR表示随机初始化的学习张量,LD表示DCT初始化的学习张量,FD表示DCT初始化的固定张量。对于随机初始化的设置,显示了误差条。

表1。ImageNet上不同注意力方法的比较。除AANet没有官方代码外,所有结果均采用相同的训练设置进行再现和训练。

表2 .不同方法在COCO val 2017上的目标检测结果。

表3 . COCO val 2017上使用Mask R-CNN的不同方法的实例分割结果。

4 结论

【结论】在本文中,我们研究了通道注意力的一个基本问题,即如何表示通道,并将这个问题看作一个压缩过程。我们证明了GAP是DCT的一个特例,并提出了带有多光谱注意力模块的FcaNet,该网络在频域上推广了现有的通道注意力机制。同时,我们在多光谱框架中探索了不同频率成分的组合,并提出了3个频率成分选择标准。在相同的参数数量和计算成本下,我们的方法能够一致地优于SENet。与其他通道注意力方法相比,我们在图像分类、目标检测和实例分割上也取得了最先进的性能。此外,FcaNet简单有效。我们的方法可以在现有的通道注意力方法的基础上只进行几行代码更改就可以实现

至此,本文的内容就结束啦。


http://www.ppmy.cn/news/1575439.html

相关文章

Docker数据卷操作实战

什么是数据卷 数据卷 是一个可供一个或多个容器使用的特殊目录,它绕过 UFS,可以提供很多有用的特性: 数据卷 可以在容器之间共享和享用对 数据卷 的修改立马生效对 数据卷 的更新,不会影响镜像数据卷 默认会一直存在,即时容器被…

Pytorch实现之浑浊水下图像增强

简介 简介:这也是一篇非常适合GAN小白们上手的架构文章!提出了一种基于GAN的水下图像增强网络。这种网络与其他架构类似,生成器是卷积+激活函数+归一化+残差结构的组成,鉴别器是卷积+激活函数+归一化以及全连接层。损失函数是常用的均方误差、感知损失和对抗损失三部分。 …

用PyTorch从零构建 DeepSeek R1:模型架构和分步训练详解

DeepSeek R1 的完整训练流程核心在于,在其基础模型 DeepSeek V3 之上,运用了多种强化学习策略。 本文将从一个可本地运行的基础模型起步,并参照其技术报告,完全从零开始构建 DeepSeek R1,理论结合实践,逐步…

字段对比清洗

import pandas as pd import psycopg2 from psycopg2 import sql# 数据库连接配置 DB_CONFIG {"host": "","user": "","password": "","dbname": "","port": , }def get_excel_fi…

Spring Boot项目集成Redisson 原始依赖与 Spring Boot Starter 的流程

Redisson 是一个高性能的 Java Redis 客户端,提供了丰富的分布式工具集,如分布式锁、Map、Queue 等,帮助开发者简化 Redis 的操作。在集成 Redisson 到项目时,开发者通常有两种选择: 使用 Redisson 原始依赖。使用 Re…

STM32 微控制器库RCC_OscInitTypeDef结构参数介绍

目录 1. 结构体定义2. 结构体成员说明(1) OscillatorType(2) HSEState(3) LSEState(4) HSIState(5) HSICalibrationValue(6) LSIState(7) PLL 3. 使用步骤(1) 定义结构体(2) 配置结构体成员(3) 调用 HAL 初始化函数 4. 示例代码5. 注意事项(1) 时钟源的选择(2) 校准值(3) 时钟配…

RuntimeWarning: invalid value encountered in scalar power在进行标量的幂运算时遇到了无效值

year_profit ((profit / initial_cash) ** (1 / yy) - 1) * 100 RuntimeWarning: invalid value encountered in scalar power 这个警告表示在执行标量幂运算 ((profit / initial_cash) ** (1 / yy) - 1) * 100 时遇到了无效值。常见的引发原因及解决办法如下: ###…

WPF12-MVVM

目录 1. 什么是MVVM2. 实现简单MVVM2.1. Part 12.2. Part 2 1. 什么是MVVM MVVM 是 Model-View-ViewModel 的缩写,是一种用于构建用户界面的设计模式,是一种简化用户界面的事件驱动编程方式。 MVVM 的目标是实现用户界面和业务逻辑之间的彻底分离&…