【深度学习】Pytorch:在 ResNet 中加入注意力机制

ops/2025/1/22 5:33:24/

在这篇教程中,我们将介绍如何在 ResNet 网络中加入注意力机制模块。我们将通过对标准 ResNet50 进行改进,向网络中添加两个自定义的注意力模块,并展示如何实现这一过程。

为什么要加入注意力机制

注意力机制可以帮助神经网络专注于图像中重要的特征区域,从而提高模型的性能。在卷积神经网络中,加入注意力机制能够有效增强特征提取能力,减少冗余信息的干扰,尤其在处理复杂图像时,能够提升网络的表现。

在本教程中,我们将使用一种通用的注意力模块,您可以根据需求自行替换或改进该模块。

代码实现

导入依赖

我们需要以下 PyTorch 库来构建网络:

import torch
import torch.nn as nn
from torchvision import models

定义注意力模块

首先,我们需要定义一个注意力模块。这里我们使用了一个简单的通道注意力机制(如 SE 模块、CBAM 模块等),你可以根据需求选择不同类型的注意力模块。

假设我们已经有一个注意力模块类(AttentionModule),它的结构可以像下面这样:

class AttentionModule(nn.Module):def __init__(self, in_channels):super(AttentionModule, self).__init__()self.conv1 = nn.Conv2d(in_channels, in_channels // 16, kernel_size=1)self.relu = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(in_channels // 16, in_channels, kernel_size=1)self.sigmoid = nn.Sigmoid()def forward(self, x):attention = self.conv1(x)attention = self.relu(attention)attention = self.conv2(attention)attention = self.sigmoid(attention)return x * attention

这段代码定义了一个简单的注意力模块。它通过两个卷积层和一个 Sigmoid 函数来生成一个通道注意力映射,并通过该映射加权输入特征图。

构建 ResNet 与注意力机制集成的模型

现在我们将创建一个新的模型类 ResNetWithAttention,该模型继承自 nn.Module,并将注意力模块插入到 ResNet 的关键位置。在这个示例中,我们将注意力模块插入到网络的卷积层输出之后,并在最后一层卷积层后再次插入。

class ResNetWithAttention(nn.Module):def __init__(self, attention_cls, pretrained=True):super(ResNetWithAttention, self).__init__()# 使用预训练的 ResNet50self.base_model = models.resnet50(pretrained=pretrained)# 创建注意力模块self.attention_layer1 = attention_cls(64)  # 第一层卷积后self.attention_layer2 = attention_cls(2048)  # 最后一层卷积后def forward(self, x):# ResNet50的前向传播过程x = self.base_model.conv1(x)  # 初始卷积层x = self.base_model.bn1(x)  # 批归一化x = self.base_model.relu(x)  # 激活函数# 第一个注意力模块:第一层卷积后x = self.attention_layer1(x)# 最大池化层x = self.base_model.maxpool(x)# ResNet的残差层x = self.base_model.layer1(x)x = self.base_model.layer2(x)x = self.base_model.layer3(x)x = self.base_model.layer4(x)# 第二个注意力模块:最后一层卷积后x = self.attention_layer2(x)# 平均池化x = self.base_model.avgpool(x)# 展平并通过全连接层x = torch.flatten(x, 1)x = self.base_model.fc(x)return x

在这个模型中,我们通过 attention_cls 参数动态地将任何类型的注意力模块传入模型。模型首先使用基础的 ResNet50 结构,之后我们将自定义的注意力模块应用到两个关键位置:一个是在第一层卷积之后,另一个是在最后的卷积层之后。

训练模型

使用该模型的训练过程与标准的 ResNet 模型相同。你可以像使用普通的 ResNet 模型一样训练和评估 ResNetWithAttention。下面是训练的一般流程:

# 示例:初始化模型并进行训练
attention_cls = AttentionModule  # 可以替换为其他类型的注意力模块
model = ResNetWithAttention(attention_cls)# 选择优化器和损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()# 假设 train_loader 是数据加载器
for epoch in range(num_epochs):model.train()for inputs, labels in train_loader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')

可加入的注意力模块

通道注意力模块

  • SE (Squeeze-and-Excitation) 模块:最经典的通道注意力模块,使用全局平均池化后生成通道级注意力,通过全连接层建模通道之间的关系。

    class SEBlock(nn.Module):def __init__(self, in_channels, reduction=16):super(SEBlock, self).__init__()# 定义第一个全连接层,将输入通道数压缩为 in_channels // reductionself.fc1 = nn.Linear(in_channels, in_channels // reduction)# 定义第二个全连接层,将通道数恢复为原始输入通道数self.fc2 = nn.Linear(in_channels // reduction, in_channels)# 定义Sigmoid激活函数,用于生成注意力权重self.sigmoid = nn.Sigmoid()def forward(self, x):# 获取输入张量的 batch_size 和 channelsbatch_size, channels, _, _ = x.size()# 对输入张量的空间维度(高度和宽度)进行全局平均池化squeeze = torch.mean(x, dim=(2, 3))# 通过第一个全连接层进行通道压缩squeeze = self.fc1(squeeze)# 通过ReLU激活函数和第二个全连接层进行通道扩展squeeze = self.fc2(F.relu(squeeze))# 使用Sigmoid生成注意力权重,并调整形状以匹配输入张量的维度attention = self.sigmoid(squeeze).view(batch_size, channels, 1, 1)# 将注意力权重应用到输入张量上,进行通道加权return x * attention
    
  • ECA (Efficient Channel Attention) 模块:通过 1D 卷积建模通道间的依赖关系,减少了计算量,提升了效率

    class ECABlock(nn.Module):def __init__(self, channels, kernel_size=3):super(ECABlock, self).__init__()# 定义1D卷积层,用于学习通道间的注意力权重self.conv = nn.Conv1d(1, 1, kernel_size, padding=kernel_size // 2, bias=False)def forward(self, x):# 获取输入张量的 batch_size 和 channelsbatch_size, channels, _, _ = x.size()# 对输入张量的空间维度(高度和宽度)进行全局平均池化,并调整形状y = F.adaptive_avg_pool2d(x, 1).view(batch_size, channels, 1)# 调整形状以适配1D卷积的输入格式y = y.view(batch_size, 1, channels)# 通过1D卷积层学习通道间的注意力权重y = self.conv(y)# 使用Sigmoid激活函数生成注意力权重y = torch.sigmoid(y)# 将注意力权重应用到输入张量上,进行通道加权return x * y.view(batch_size, channels, 1, 1).expand_as(x)
    

空间注意力模块

  • CBAM (Convolutional Block Attention Module) 模块:结合了通道注意力和空间注意力,首先进行通道注意力加权,然后通过空间卷积生成空间注意力。

    class CBAM(nn.Module):def __init__(self, in_channels, reduction=16):super(CBAM, self).__init__()# 通道注意力模块(SEBlock),用于学习通道间的注意力权重self.channel_attention = SEBlock(in_channels, reduction)# 空间注意力模块,使用1x1卷积核学习空间注意力权重self.spatial_attention = nn.Conv2d(2, 1, kernel_size=7, padding=3)def forward(self, x):# 应用通道注意力模块,对输入特征进行通道加权x = self.channel_attention(x)# 计算输入特征在通道维度上的平均值avg_out = torch.mean(x, dim=1, keepdim=True)# 计算输入特征在通道维度上的最大值max_out, _ = torch.max(x, dim=1, keepdim=True)# 将平均值和最大值拼接在一起spatial_out = torch.cat([avg_out, max_out], dim=1)# 通过空间注意力模块学习空间注意力权重spatial_out = self.spatial_attention(spatial_out)# 使用Sigmoid激活函数生成空间注意力权重spatial_attention = torch.sigmoid(spatial_out)# 将空间注意力权重应用到输入特征上,进行空间加权return x * spatial_attention
    
  • Coordinate Attention 模块:通过空间坐标信息提升特征建模能力,增强空间特征表达。

    class CoordinateAttention(nn.Module):def __init__(self, in_channels, reduction=16):super(CoordinateAttention, self).__init__()# 定义1x1卷积层,用于压缩通道数self.fc = nn.Conv2d(in_channels, in_channels // reduction, kernel_size=1)# 定义1x1卷积层,用于恢复通道数self.fc_out = nn.Conv2d(in_channels // reduction, in_channels, kernel_size=1)def forward(self, x):# 获取输入张量的 batch_size、channels、height 和 widthbatch_size, channels, height, width = x.size()# 对输入张量的空间维度(高度和宽度)进行全局平均池化avg_out = torch.mean(x, dim=[2, 3], keepdim=True)# 对输入张量的空间维度(高度和宽度)进行全局最大池化max_out = torch.amax(x, dim=[2, 3], keepdim=True)# 通过1x1卷积层压缩通道数avg_out = self.fc(avg_out)max_out = self.fc(max_out)# 通过1x1卷积层恢复通道数avg_out = self.fc_out(avg_out)max_out = self.fc_out(max_out)# 将平均池化和最大池化的结果相加,并应用到输入张量上out = x * (avg_out + max_out)return out
    

双重注意力模块

  • Dual Attention 模块:结合了通道和空间的双重注意力机制,增强了特征的表征能力。

    class DualAttentionBlock(nn.Module):def __init__(self, in_channels, reduction=16):super(DualAttentionBlock, self).__init__()# 通道注意力模块(SEBlock),用于学习通道间的注意力权重self.channel_attention = SEBlock(in_channels, reduction)# 空间注意力模块(CBAM),用于学习空间上的注意力权重self.spatial_attention = CBAM(in_channels, reduction)def forward(self, x):# 应用通道注意力模块,对输入特征进行通道加权x = self.channel_attention(x)# 应用空间注意力模块,对输入特征进行空间加权x = self.spatial_attention(x)return x
    

全局依赖建模模块

  • Non-local 模块:通过自注意力机制建模全局依赖关系,提升对长距离特征的建模能力。

    class NonLocalBlock(nn.Module):def __init__(self, in_channels):super(NonLocalBlock, self).__init__()# 输入通道数self.in_channels = in_channels# 中间通道数,通常为输入通道数的一半self.inter_channels = in_channels // 2# 定义1x1卷积层,用于生成查询(query)特征self.query_conv = nn.Conv2d(in_channels, self.inter_channels, kernel_size=1)# 定义1x1卷积层,用于生成键(key)特征self.key_conv = nn.Conv2d(in_channels, self.inter_channels, kernel_size=1)# 定义1x1卷积层,用于生成值(value)特征self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)# 定义Softmax函数,用于计算注意力权重self.softmax = nn.Softmax(dim=-1)def forward(self, x):# 获取输入张量的 batch_size、通道数、高度和宽度batch_size, C, H, W = x.size()# 通过查询卷积层生成查询特征,并调整形状query = self.query_conv(x).view(batch_size, self.inter_channels, -1)# 通过键卷积层生成键特征,并调整形状key = self.key_conv(x).view(batch_size, self.inter_channels, -1)# 通过值卷积层生成值特征,并调整形状value = self.value_conv(x).view(batch_size, C, -1)# 计算查询特征和键特征的相似度(亲和矩阵)affinity = torch.bmm(query.transpose(1, 2), key)# 使用Softmax计算注意力权重attention = self.softmax(affinity)# 将注意力权重应用到值特征上,得到加权输出out = torch.bmm(value, attention.transpose(1, 2))# 调整输出形状以匹配输入张量的维度out = out.view(batch_size, C, H, W)# 将加权输出与输入张量相加,实现残差连接return out + x
    
  • Attention U-Net 模块:在 U-Net 结构中引入注意力模块,适用于图像分割任务,能够自适应地选择重要区域进行特征增强。

    class AttentionGate(nn.Module):def __init__(self, in_channels):super(AttentionGate, self).__init__()# 定义门控通道数和中间通道数gating_channels = in_channelsinter_channels = in_channels // 2# 定义1x1卷积层,用于处理输入特征self.conv1 = nn.Conv2d(in_channels, inter_channels, kernel_size=1)# 定义1x1卷积层,用于处理门控特征self.conv2 = nn.Conv2d(gating_channels, inter_channels, kernel_size=1)# 定义1x1卷积层,用于生成注意力权重self.psi = nn.Conv2d(inter_channels, 1, kernel_size=1)# 定义Sigmoid激活函数,用于生成注意力权重self.sigmoid = nn.Sigmoid()def forward(self, x):# 门控信号与输入特征相同gating = x# 对输入特征进行1x1卷积x1 = self.conv1(x)# 对门控信号进行1x1卷积x2 = self.conv2(gating)# 将两个卷积结果相加,并通过ReLU激活函数attention = self.sigmoid(self.psi(F.relu(x1 + x2)))# 将注意力权重应用到输入特征上,进行加权return x * attention
    

总结

通过将注意力模块集成到 ResNet 中,我们能够增强模型对重要特征的关注,从而提高性能。你可以根据需要选择不同的注意力机制,并在模型中任意位置插入这些模块。


http://www.ppmy.cn/ops/152108.html

相关文章

PyBroker:利用 Python 和机器学习助力算法交易

PyBroker:利用 Python 和机器学习助力算法交易 你是否希望借助 Python 和机器学习的力量来优化你的交易策略?那么你需要了解一下 PyBroker!这个 Python 框架专为开发算法交易策略而设计,尤其关注使用机器学习的策略。借助 PyBrok…

RPA编程实践:Electron简介

文章目录 前言使用Electron构建桌面应用程序什么是Electron?为什么选择Electron?如何使用Electron实现上述想法?1. 创建基本的Electron应用2. 配置BrowserWindow3. 定制化你的应用4. 打包与分发 前言 Electron,用官网的话说&…

常用集合-数据结构-MySql

目录 java核心: 常用集合与数据结构: 单例集合: 双列集合: 线程安全的集合: ConcurrentHashMap集合: HashTable集合: CopyOnWriteArrayList集合: CopyOnWriteArraySet集合: ConcurrentLinkedQueue队列: ConcurrentSkipListMap和ConcurrentSkipListSet&…

01.04、回文排序

01.04、[简单] 回文排序 1、题目描述 给定一个字符串,编写一个函数判定其是否为某个回文串的排列之一。回文串是指正反两个方向都一样的单词或短语。排列是指字母的重新排列。回文串不一定是字典当中的单词。 2、解题思路 回文串的特点: 一个回文串在…

adb常用指令(完整版)

1、adb devices 查看是否连接到设备 2、adb install [-r] [-s] 安装app,-r强制,-s安装sd卡上 3、adb uninstall [-k] 卸载app,-k保留配置和参数 4、adb push 把本地文件上传设备 5、adb pull 下载文件到本地 6、cd D:\sdk\platform-tool…

html,css,js的粒子效果

这段代码实现了一个基于HTML5 Canvas的高级粒子效果&#xff0c;用户可以通过鼠标与粒子进行交互。下面是对代码的详细解析&#xff1a; HTML部分 使用<!DOCTYPE html>声明文档类型。<html>标签内包含了整个网页的内容。<head>部分定义了网页的标题&#x…

HTML语言的计算机基础

HTML语言的计算机基础 引言 在当今信息技术迅猛发展的时代&#xff0c;网页设计和开发已成为计算机科学中不可或缺的一部分。而HTML&#xff08;超文本标记语言&#xff09;作为构建网页的基础语言&#xff0c;承载着网页上所有内容的结构&#xff0c;帮助开发者创建和展示信…

如何解析返回的快递费用数据?

解析返回的快递费用数据是使用 API 的关键步骤之一。解析数据时&#xff0c;需要根据返回的 JSON 格式提取有用的信息&#xff0c;并进行适当的处理。以下是一个完整的示例&#xff0c;展示如何解析 1688 item_fee 接口返回的快递费用数据。 一、返回数据的结构 在调用 1688 的…