探索 PyTorch 中的 ConvTranspose2d 及其转置卷积家族

embedded/2025/3/14 10:50:44/

探索 PyTorch 中的 ConvTranspose2d 及其转置卷积家族

在深度学习领域,尤其是图像处理任务中,卷积神经网络(CNN)扮演着重要角色。而当我们需要在网络中进行上采样(Upsampling)时,转置卷积(Transpose Convolution)就成为了不可或缺的工具。今天,我们以 PyTorch 中的 ConvTranspose2d 为核心,深入探讨它的功能、使用方式,并介绍它的“家族成员”——其他转置卷积相关函数。

什么是 ConvTranspose2d?

ConvTranspose2d 是 PyTorch 中 torch.nn 模块提供的一个二维转置卷积层,也常被称为“反卷积”(Deconvolution),尽管这个名称在学术上并不完全准确。它的本质是通过卷积操作将输入特征图的空间尺寸(宽和高)放大,通常用于上采样任务。

与普通的卷积层(Conv2d)将输入特征图尺寸缩小的功能相反,ConvTranspose2d 的主要作用是:

  1. 上采样:增大特征图的空间分辨率。
  2. 特征恢复:在解码器(如 U-Net 的扩展路径)中恢复细节信息。
  3. 生成任务:在生成对抗网络(GAN)等模型中生成高分辨率输出。

在 U-Net 等分割网络中,ConvTranspose2d 常用于“扩展路径”,通过放大特征图并结合跳跃连接(Skip Connection)逐步重建输入图像的细节。

定义与参数

ConvTranspose2d 的基本定义如下:

python">torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1)
  • in_channels:输入特征图的通道数。
  • out_channels:输出特征图的通道数。
  • kernel_size:卷积核的大小,例如 2(2, 2)
  • stride:卷积核滑动的步幅,默认值为 1,增大步幅会显著放大输出尺寸。
  • padding:输入边缘填充的像素数,默认值为 0。
  • output_padding:调整输出尺寸的额外填充,用于精确控制输出大小。
  • groups:分组卷积的组数,默认值为 1。
  • bias:是否添加偏置项,默认值为 True
  • dilation:卷积核元素之间的间距,默认值为 1。

工作原理

转置卷积的核心思想是将普通卷积的“前向过程”反转。普通卷积通过卷积核滑动和加权求和缩小特征图,而转置卷积则通过在输入特征之间插入零(即“稀疏化”),再应用卷积核,生成更大的输出特征图。这种操作可以看作是对输入特征的“放大重建”。

例如,输入一个 2x2 的特征图,使用 stride=2 的转置卷积,输出尺寸会变为 4x4(具体尺寸还与 kernel_sizepadding 有关)。

使用示例

以下是一个简单的例子:

python">import torch
import torch.nn as nn# 定义一个转置卷积层
upconv = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=2, stride=2)# 输入张量:1个样本,1个通道,2x2 的特征图
x = torch.tensor([[[[1., 2.],[3., 4.]]]])# 应用转置卷积
y = upconv(x)
print(y.shape)  # 输出:torch.Size([1, 1, 4, 4])
print(y)

在这个例子中,输入 2x2 的特征图被放大为 4x4,具体输出值取决于卷积核的权重。

ConvTranspose2d 的家族成员

转置卷积并非孤立存在,PyTorch 提供了一系列相关函数,统称为“转置卷积家族”。它们针对不同维度和需求设计,以下是几个常见成员:

1. ConvTranspose1d - 一维转置卷积

  • 功能:对一维序列数据进行上采样。
  • 使用场景:适用于时间序列、音频信号等任务。
  • 示例
python">upconv1d = nn.ConvTranspose1d(1, 1, kernel_size=2, stride=2)
x = torch.tensor([[[1., 2., 3.]]])  # 1x3 输入
y = upconv1d(x)
print(y.shape)  # 输出:torch.Size([1, 1, 6])

2. ConvTranspose3d - 三维转置卷积

  • 功能:对三维数据(如体视数据)进行上采样。
  • 使用场景:常用于医学影像(如 CT/MRI)或视频处理。
  • 示例
python">upconv3d = nn.ConvTranspose3d(1, 1, kernel_size=2, stride=2)
x = torch.randn(1, 1, 4, 4, 4)  # 4x4x4 输入
y = upconv3d(x)
print(y.shape)  # 输出:torch.Size([1, 1, 8, 8, 8])

3. 与普通卷积的关系

虽然 ConvTranspose2dConv2d 是“互逆”的概念,但它们并非严格的数学逆操作。转置卷积的权重是可学习的,因此它更像是一种参数化的上采样方法,而非简单的“反卷积”。

在 U-Net 中的应用

在U-Net 代码中(详情见笔者的另一篇博客:深入了解 PyTorch 中的 MaxPool2d 及其池化家族函数),ConvTranspose2d 是扩展路径的核心组件:

python">self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
d3 = self.up3(b)  # 上采样
d3 = torch.cat([d3, e3], dim=1)  # 跳跃连接

补充全部代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass UNet(nn.Module):def __init__(self, in_channels=1, out_channels=2):super(UNet, self).__init__()# 收缩路径self.enc1 = self.conv_block(in_channels, 64)self.enc2 = self.conv_block(64, 128)self.enc3 = self.conv_block(128, 256)self.pool = nn.MaxPool2d(2)# 底部self.bottom = self.conv_block(256, 512)# 扩展路径self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)self.dec3 = self.conv_block(512, 256)  # 拼接后通道数为 256+256=512self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)self.dec2 = self.conv_block(256, 128)self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)self.dec1 = self.conv_block(128, 64)# 输出层self.out_conv = nn.Conv2d(64, out_channels, kernel_size=1)def conv_block(self, in_channels, out_channels):return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=0),nn.ReLU(inplace=True))def forward(self, x):# 收缩路径e1 = self.enc1(x)e2 = self.enc2(self.pool(e1))e3 = self.enc3(self.pool(e2))b = self.bottom(self.pool(e3))# 扩展路径d3 = self.up3(b)d3 = torch.cat([d3, e3], dim=1)  # 跳跃连接d3 = self.dec3(d3)d2 = self.up2(d3)d2 = torch.cat([d2, e2], dim=1)d2 = self.dec2(d2)d1 = self.up1(d2)d1 = torch.cat([d1, e1], dim=1)d1 = self.dec1(d1)# 输出out = self.out_conv(d1)return out# 测试代码
if __name__ == "__main__":model = UNet(in_channels=1, out_channels=2)x = torch.randn(1, 1, 572, 572)  # 输入示例:单通道 572x572 图像y = model(x)print(y.shape)  # 输出:torch.Size([1, 2, 388, 388])

在这里,up3 将底部特征图从 512 通道上采样到 256 通道,同时将空间尺寸放大一倍(例如从 56x56 到 112x112)。随后通过跳跃连接与编码路径的特征图 e3 拼接,进一步恢复细节。

U-Net 的这种设计充分利用了转置卷积的上采样能力,结合跳跃连接保留了低层次特征,使模型在图像分割任务中表现出色。

与其他上采样方法的对比

除了转置卷积,PyTorch 还提供了其他上采样方法,如:

  • nn.Upsample:基于插值(如双线性插值)的上采样,计算简单但缺乏学习能力。
  • nn.MaxUnpool2d:基于池化索引的上采样,需与 MaxPool2d 配合使用。
    相比之下,ConvTranspose2d 的优势在于其卷积核是可训练的,可以根据任务需求学习最佳的上采样方式。

总结

ConvTranspose2d 是深度学习中实现上采样的强大工具,广泛应用于图像分割、生成模型等任务。它通过转置卷积操作放大特征图,并结合可学习的参数提供灵活性。它的家族成员(如 ConvTranspose1dConvTranspose3d)进一步扩展了应用场景,覆盖一维到三维数据。

在实际使用中,选择 ConvTranspose2d 还是其他上采样方法,取决于任务需求:如果需要可学习的特征重建,转置卷积是首选;如果只追求简单放大,则插值方法可能更高效。希望这篇博客能让你对 ConvTranspose2d 及其家族有更清晰的认识!

分析ConvTranspose2dtorch.cat 的作用

代码背景

这是 U-Net 模型中“扩展路径”的一部分,具体代码如下:

python">self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
d3 = self.up3(b)  # 上采样
d3 = torch.cat([d3, e3], dim=1)  # 跳跃连接

我们需要搞清楚:

  1. ConvTranspose2d(512, 256, kernel_size=2, stride=2) 如何影响通道数和分辨率。
  2. torch.cat([d3, e3], dim=1) 后的通道数和分辨率变化。

输入假设

假设输入 b 是底部特征图(即 self.bottom 的输出),其形状为 [batch_size, 512, H, W],其中:

  • batch_size 是批量大小(通常为 1 或更多)。
  • 512 是通道数。
  • HW 是特征图的空间分辨率(宽和高,例如 56x56)。

在 U-Net 中,底部特征图通常是经过多次池化(MaxPool2d(2))后的结果,因此分辨率较小。


第一步:ConvTranspose2d 的作用

定义
python">self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
d3 = self.up3(b)
  • in_channels=512:输入通道数为 512。
  • out_channels=256:输出通道数为 256。
  • kernel_size=2:转置卷积核大小为 2x2。
  • stride=2:步幅为 2,表示输出分辨率会放大两倍。
通道数变化
  • 输入 b 的通道数是 512。
  • ConvTranspose2d 通过卷积操作将通道数从 512 减少到 256。
  • 因此,d3 的通道数为 256
分辨率变化

转置卷积的输出尺寸可以通过以下公式计算:

H_out = (H_in - 1) * stride - 2 * padding + kernel_size + output_padding
W_out = (W_in - 1) * stride - 2 * padding + kernel_size + output_padding

默认情况下,padding=0output_padding=0,代入参数:

  • H_in = Hstride = 2kernel_size = 2padding = 0output_padding = 0
  • H_out = (H - 1) * 2 - 2 * 0 + 2 + 0 = 2H - 2 + 2 = 2H
  • 同理,W_out = 2W

所以:

  • 如果输入 b 的分辨率是 H x W(例如 56x56),则 d3 的分辨率变为 2H x 2W(例如 112x112)。
中间结果

经过 d3 = self.up3(b) 后:

  • 通道数:256。
  • 分辨率:2H x 2W(例如 112x112)。
  • d3 的形状为 [batch_size, 256, 2H, 2W]

第二步:torch.cat 的作用

定义
python">d3 = torch.cat([d3, e3], dim=1)  # 跳跃连接
  • torch.cat 是沿着指定维度(这里是 dim=1,即通道维度)拼接两个张量。
  • d3 是上采样后的特征图,形状为 [batch_size, 256, 2H, 2W]
  • e3 是收缩路径中的特征图(来自 self.enc3),形状需要与 d3 的空间分辨率匹配。
e3 的形状

在 U-Net 中,e3 是编码路径中第三层的输出,经过 self.enc3 = self.conv_block(128, 256) 处理:

  • 通道数为 256
  • 分辨率取决于前面的池化操作。假设输入图像是 572x572:
    • e1:经过卷积后为 568x568(因无填充,572 - 3 + 1 = 568)。
    • e2:池化后为 284x284(568 / 2)。
    • e3:池化后为 142x142(284 / 2)。
    • b:池化后为 71x71(142 / 2,假设向下取整)。
    • d3:上采样后为 142x142(71 * 2)。
  • 所以,e3 的分辨率是 142x142,形状为 [batch_size, 256, 142, 142]

由于 d3 的分辨率(2H x 2W,例如 142x142)与 e3 的分辨率匹配,它们可以在通道维度上拼接。

通道数变化
  • d3 的通道数为 256。
  • e3 的通道数为 256。
  • torch.cat([d3, e3], dim=1) 将两个张量的通道数相加:256 + 256 = 512
分辨率变化
  • torch.cat 只在通道维度上操作,不改变空间分辨率。
  • 因此,分辨率保持为 2H x 2W(例如 142x142)。
最终结果

经过 d3 = torch.cat([d3, e3], dim=1) 后:

  • 通道数:512。
  • 分辨率:2H x 2W(例如 142x142)。
  • d3 的形状为 [batch_size, 512, 2H, 2W]

回答疑问

  1. “将通道变小,但是分辨率加倍”

    • 是的,ConvTranspose2d(512, 256, kernel_size=2, stride=2) 将通道数从 512 减小到 256,同时分辨率从 H x W 加倍到 2H x 2W。这是转置卷积的典型行为:通过减少通道数换取更大的空间尺寸。
  2. “然后 cat 一下呢?通道数不变?分辨率怎么变化”

    • 错了,torch.cat([d3, e3], dim=1) 后通道数会变化,从 256 增加到 512(因为拼接了 e3 的 256 个通道)。
    • 分辨率不变,仍然是 2H x 2W,因为 cat 只影响通道维度,不改变宽和高。

总结

  • 初始 b[batch_size, 512, H, W](例如 [1, 512, 71, 71])。
  • 经过 up3d3 变为 [batch_size, 256, 2H, 2W](例如 [1, 256, 142, 142])。
  • 经过 catd3 变为 [batch_size, 512, 2H, 2W](例如 [1, 512, 142, 142])。

通道数先减半(512 -> 256),分辨率加倍(H x W -> 2H x 2W),然后通过跳跃连接拼接后通道数又加倍(256 -> 512),分辨率保持不变。这正是 U-Net 的设计精髓:通过上采样和跳跃连接逐步恢复空间信息,同时融合多尺度特征。

后记

2025年3月13日15点45分于上海,在Grok 3大模型辅助下完成。


http://www.ppmy.cn/embedded/172471.html

相关文章

【CSS3】元婴篇

目录 定位相对定位绝对定位定位居中固定定位堆叠层级 z-index CSS 精灵字体图标下载字体使用字体 垂直对齐方式过渡修饰属性透明度光标类型 定位 作用:灵活的改变盒子在网页中的位置 实现: 定位模式:position边偏移:设置盒子的…

全链条自研可控|江波龙汽车存储“双轮驱动”体系亮相MemoryS 2025

3月12日,MemoryS 2025在深圳盛大开幕,汇聚了存储行业的顶尖专家、企业领袖以及技术先锋,共同探讨存储技术的未来发展方向及其在商业领域的创新应用。江波龙董事长、总经理蔡华波先生受邀出席,并发表了题为《存储商业综合创新》的主…

地理信息系统(ArcGIS)在水文水资源及水环境中的应用:空间数据管理‌、空间分析功能‌、‌可视化表达‌

随着全球工业化和经济的快速发展,水资源短缺、水污染等问题日益严峻,成为制约可持续发展的重大瓶颈。地理信息系统(GIS)以其强大的空间数据管理和分析能力,在水文水资源及水环境的研究和管理中展现出独特优势。本文将深…

仓库管理系统(WMS)系统的基本流程

WMS(仓库管理系统,Warehouse Management System)是用于管理和优化仓库操作的系统,旨在提高库存准确性、提高仓库效率和降低成本。WMS的基本流程包括以下几个主要步骤: 收货(Receiving)&#xff…

OpenCV连续数字识别—可运行验证

前言 ​ 文章开始,瞎说一点其他的东西,真的是很离谱,找了至少两三个小时,就一个简单的需求: 1、利用OpenCV 在Windows进行抓图 2、利用OpenCV 进行连续数字的检测。 3、使用C,Qt 3、将检测的结果显示出来 …

编程考古-VCL跨平台革命:CrossVCL如何让Delphi开发者梦想成真(上)

在软件开发的世界里,有一句老话:“技术的发展总是出乎意料”。对于使用Delphi的开发者而言,这句话从未如此真实。今天,我们将探索一项名为CrossVCL的技术,它不仅重新定义了我们对Visual Component Library(…

OpenAI智能体初探:使用 OpenAI Responses API 在 PDF 中实现检索增强生成(RAG)

大家好,我是大 F,深耕AI算法十余年,互联网大厂技术岗。 知行合一,不写水文,喜欢可关注,分享AI算法干货、技术心得。 欢迎关注《大模型理论和实战》、《DeepSeek技术解析和实战》,一起探索技术的无限可能! 引子 在信息爆炸的时代,从大量 PDF 文档中快速准确地检索信息…

高效图像处理工具:从需求分析到落地实现

高效图像处理工具:从需求分析到落地实现 在现代应用开发中,图像处理是一个不可或缺的功能模块。无论是社交媒体、电子商务还是企业级应用,都需要对图片进行各种处理操作,例如尺寸调整、背景替换、格式转换等。然而,实…