Pytorch从零开始实战16

news/2024/11/15 6:01:49/

Pytorch从零开始实战——ResNeXt-50算法的思考

本系列来源于365天深度学习训练营

原作者K同学

对于上次ResNeXt-50算法,我们同样有基于TensorFlow的实现。具体代码如下。

引入头文件

import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Dense, Dropout, Conv2D, MaxPool2D, Flatten, GlobalAvgPool2D, concatenate, \
BatchNormalization, Activation, Add, ZeroPadding2D, Lambda
from tensorflow.keras.layers import ReLU
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
from tensorflow.keras.callbacks import LearningRateScheduler
from tensorflow.keras.models import Model

分组卷积模块

# 定义分组卷积
def grouped_convolution_block(init_x, strides, groups, g_channels):group_list = []# 分组进行卷积for c in range(groups):# 分组取出数据x = Lambda(lambda x: x[:, :, :, c * g_channels:(c + 1) * g_channels])(init_x)# 分组进行卷积x = Conv2D(filters=g_channels, kernel_size=(3, 3),strides=strides, padding='same', use_bias=False)(x)# 存入listgroup_list.append(x)# 合并list中的数据group_merage = concatenate(group_list, axis=3)x = BatchNormalization(epsilon=1.001e-5)(group_merage)x = ReLU()(x)return x

残差单元

# 定义残差单元
def block(x, filters, strides=1, groups=32, conv_shortcut=True):if conv_shortcut:shortcut = Conv2D(filters * 2, kernel_size=(1, 1), strides=strides, padding='same', use_bias=False)(x)# epsilon为BN公式中防止分母为零的值shortcut = BatchNormalization(epsilon=1.001e-5)(shortcut)else:# identity_shortcutshortcut = x# 三层卷积层x = Conv2D(filters=filters, kernel_size=(1, 1), strides=1, padding='same', use_bias=False)(x)x = BatchNormalization(epsilon=1.001e-5)(x)x = ReLU()(x)# 计算每组的通道数g_channels = int(filters / groups)# 进行分组卷积x = grouped_convolution_block(x, strides, groups, g_channels)x = Conv2D(filters=filters * 2, kernel_size=(1, 1), strides=1, padding='same', use_bias=False)(x)x = BatchNormalization(epsilon=1.001e-5)(x)x = Add()([x, shortcut])x = ReLU()(x)return x

堆叠残差单元

# 堆叠残差单元
def stack(x, filters, blocks, strides, groups=32):# 每个stack的第一个block的残差连接都需要使用1*1卷积升维x = block(x, filters, strides=strides, groups=groups)for i in range(blocks):x = block(x, filters, groups=groups, conv_shortcut=False)return x

网络搭建

# 定义ResNext50(32*4d)网络
def ResNext50(input_shape, num_classes):inputs = Input(shape=input_shape)# 填充3圈0,[224,224,3]->[230,230,3]x = ZeroPadding2D((3, 3))(inputs)x = Conv2D(filters=64, kernel_size=(7, 7), strides=2, padding='valid')(x)x = BatchNormalization(epsilon=1.001e-5)(x)x = ReLU()(x)# 填充1圈0x = ZeroPadding2D((1, 1))(x)x = MaxPool2D(pool_size=(3, 3), strides=2, padding='valid')(x)# 堆叠残差结构x = stack(x, filters=128, blocks=2, strides=1)x = stack(x, filters=256, blocks=3, strides=2)x = stack(x, filters=512, blocks=5, strides=2)x = stack(x, filters=1024, blocks=2, strides=2)# 根据特征图大小进行全局平均池化x = GlobalAvgPool2D()(x)x = Dense(num_classes, activation='softmax')(x)# 定义模型model = Model(inputs=inputs, outputs=x)return model

对于残差单元中的代码,提出一个问题:当conv_shortcut=False的时候,在执行Add操作时,理论上通道数不一致,为什么代码不报错?
在这里插入图片描述
答:这主要是跟下面堆叠残差单元的代码有关系,每个stack第一轮总会令conv_shortcut为True,使得x通道数进行扩展,而后面循环的时候传入的filters还是这个函数的实参,没有发生变化,但由于conv_shortcut为False,此时shortcut的通道数是与上面的x一致,所以在Add的时候,代码不会报错。

def stack(x, filters, blocks, strides, groups=32):# 每个stack的第一个block的残差连接都需要使用1*1卷积升维x = block(x, filters, strides=strides, groups=groups)for i in range(blocks):x = block(x, filters, groups=groups, conv_shortcut=False)return x

本文只是对ResNeXt-50算法的部分代码进行思考,学习过程中需要积极思考与探索,以提高能力和解决问题。


http://www.ppmy.cn/news/1298924.html

相关文章

thinkphp6入门(15)-- 模型动态构建查询条件

背景 我使用thinkphp6的模型写数据库查询,有多个where条件,但是不确定是否需要添加某个where条件,怎么才能动态得生成查询 链式查询 在ThinkPHP 6中,可以使用链式查询方法来动态地构建查询条件。可以根据参数的值来决定是否添加…

JavaScript 地址信息与页面跳转

在JavaScript中,你可以使用各种方法来处理地址信息并进行页面跳转。以下是一些常见的方法: 1.使用window.location对象: window.location对象包含了当前窗口的URL信息,并且可以用来进行页面跳转。 * 获取URL的某一部分&#xf…

【python基础教程】print输出函数和range()函数的正确使用方式

嗨喽,大家好呀~这里是爱看美女的茜茜呐 print()有多个参数,参数个数不固定。 有四个关键字参数(sep end file flush),这四个关键字参数都有默认值。 print作用是将objects的内容输出到file中,objects中的…

前端面试 -- vue系列

Vue系列 1. vue理解:2. SPA(单页面应用理解)3. vue实例挂载的过程4. v-for和v-if优先级5. SPA首屏加载速度慢的原因和解决办法6. Vue中给对象添加新属性界面不刷新(直接给对象添加属性)7. vue组件之间的通信方式有哪些…

14:00面试,14:07就出来了,问的问题有点变态。。。

前言 刚从小厂出来,没想到网盘我在另一家公司又寄了。 在这家公司上班,每天都要加班,但看在钱给的比较多的份上,也就不太计较了。但万万没想到一纸通知,所有人不准加班了,不仅加班费没有了,薪…

JAVA基础学习笔记-day15-File类与IO流

JAVA基础学习笔记-day15-File类与IO流 1. java.io.File类的使用1.1 概述1.2 构造器1.3 常用方法1、获取文件和目录基本信息2、列出目录的下一级3、File类的重命名功能4、判断功能的方法5、创建、删除功能 2. IO流原理及流的分类2.1 Java IO原理2.2 流的分类2.3 流的API 3. 节点…

arthas 内存占用过大排查

使用经验分享 线上故障排查思路: 1、紧急处理,优先保障服务可用(如切换vip,主备容灾) 2、保留第一现场,通过jstack -l {pid} > jvmtmp.txt ,打印栈信息 (后续可以在gceasy官网上…

【Python】内置的type()函数详解和示例

在Python中&#xff0c;type()函数是一个内置函数&#xff0c;用于获取对象的类型。这个函数返回一个对象的类型对象&#xff0c;可以用来比较和识别对象的类型。 # 获取一个整数的类型 print(type(123)) # 输出&#xff1a;<class int># 获取一个字符串的类型 print(t…