- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
📌你需要解决的疑问:这个代码是否有错?对错与否都请给出你的思考
📌打卡要求:请查找相关资料、逐步推理模型、详细写下你的思考过程
代码如下
# 定义残差单元
def block(x, filters, strides=1, groups=32, conv_shortcut=True): if conv_shortcut: shortcut = Conv2D(filters * 2, kernel_size=(1, 1), strides=strides, padding='same', use_bias=False)(x) # epsilon为BN公式中防止分母为零的值 shortcut = BatchNormalization(epsilon=1.001e-5)(shortcut) else: # identity_shortcut shortcut = x # 三层卷积层 x = Conv2D(filters=filters, kernel_size=(1, 1), strides=1, padding='same', use_bias=False)(x) x = BatchNormalization(epsilon=1.001e-5)(x) x = ReLU()(x) # 计算每组的通道数 g_channels = int(filters / groups) # 进行分组卷积 x = grouped_convolution_block(x, strides, groups, g_channels) x = Conv2D(filters=filters * 2, kernel_size=(1, 1), strides=1, padding='same', use_bias=False)(x) x = BatchNormalization(epsilon=1.001e-5)(x) x = Add()([x, shortcut]) x = ReLU()(x) return x
如果conv_shortcut=False,那么执行“x=Add()…”语句时,通道数不一致的,为什么不会报错呢?
思考
关键点分析:
-
conv_shortcut=False
时的shortcut
- 当
conv_shortcut=False
时,shortcut
直接等于x
,没有经过卷积层处理。也就是说,在x = Add()([x, shortcut])
这一行之前,shortcut
和x
是相同的张量。因此,它们的通道数和空间尺寸应该是相同的,直接相加是没有问题的。 - 因此,执行
Add()
时,x
和shortcut
的维度是匹配的。
- 当
-
conv_shortcut=True
时的shortcut
- 当
conv_shortcut=True
时,shortcut
会通过一个 1x1 卷积进行通道数调整:shortcut = Conv2D(filters * 2, kernel_size=(1, 1), strides=strides, padding='same', use_bias=False)(x)
- 这里
shortcut
的通道数被修改为filters * 2
,使其与主路径输出的通道数一致。 - 如果步长
strides
是大于1的值(例如2),则shortcut
的空间尺寸(高度和宽度)也会发生变化,因此需要保证它们的空间尺寸与主路径的输出相匹配。为此,使用Conv2D
和BatchNormalization
来调整通道数,并保持padding='same'
来避免空间尺寸的变化。
- 这里
- 当
-
Add()
层的作用Add()
层要求输入张量的维度一致。如果conv_shortcut=False
,shortcut
和x
在通道数和空间尺寸上已经是相同的,因此直接相加是不会出现维度不匹配的错误的。- 但是,如果通道数或空间尺寸不一致,Keras 会抛出错误,提示“维度不匹配”或类似错误。
为什么没有报错:
- 当
conv_shortcut=False
时,shortcut
和x
的维度是一样的,因为shortcut = x
,它们在通道数、空间尺寸等维度上都相同,因此Add()
操作没有问题。 Add()
操作本身不会对维度进行任何变换,它只是简单地将两个张量逐元素相加。如果两个张量的维度不一致,Keras 会报错。
结论:
- 在
conv_shortcut=False
的情况下,shortcut
直接等于x
,因此它们的维度是相同的,Add()
操作能够成功执行。 - 只有当
conv_shortcut=True
时,shortcut
会通过卷积进行维度调整,这样就确保了通道数和空间尺寸与主路径的输出一致,避免了维度不匹配的问题。
测试:
为了验证,可以在代码中添加调试语句,打印出 x.shape
和 shortcut.shape
,看看它们在执行 Add()
之前是否匹配。例如:
print("x shape:", x.shape)
print("shortcut shape:", shortcut.shape)
这可以帮助进一步确认 x
和 shortcut
在执行加法之前是否一致。