ResNext-50模型进行图像识别

embedded/2024/12/21 18:37:38/

本文为为🔗365天深度学习训练营内部文章

原作者:K同学啊

python">import numpy as np
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import to_categorical
from keras.models import Sequential
from keras.layers import Input,Dense,Dropout,Conv2D,MaxPool2D,Flatten,GlobalAvgPool2D,concatenate \,BatchNormalization,Activation,Add,ZeroPadding2D,Lambda
from keras.layers import ReLU
from keras.optimizers import Adam
import matplotlib.pyplot as plt
from keras.callbacks import LearningRateScheduler
from keras.models import Model'''
分组卷积模块
'''
# 定义分组卷积
def grouped_convolution_block(init_x,strides,groups,g_channels):group_list = []# 分组进行卷积for c in range(groups):# 分组取出数据x = Lambda(lambda x:x[:, :, :, c*g_channels:(c+1)*g_channels])(init_x)# 分组进行卷积x = Conv2D(filters=g_channels,kernel_size=(3,3),strides=strides,padding='same',use_bias=False)(x)# 存入listgroup_list.append(x)# 合并list中的数据group_merge = concatenate(group_list,axis=3)x = BatchNormalization(epsilon=1.001e-5)(group_merge)x = ReLU()(x)return x'''
定义残差单元
'''
def block(x,filters,strides=1,groups=32,conv_shortcut=True):if conv_shortcut:shortcut = Conv2D(filters*2,kernel_size=(1,1),strides=strides,padding='same',use_bias=False)(x)# epsilon为BN公式中防止分母为0的值shortcut = BatchNormalization(epsilon=1.001e-5)(shortcut)else:shortcut = x# 三层卷积层x = Conv2D(filters=filters,kernel_size=(1,1),strides=1,padding='same',use_bias=False)(x)x = BatchNormalization(epsilon=1.001e-5)(x)x = ReLU()(x)# 计算每组的通道数g_channels = int(filters / groups)# 分组进行卷积x = grouped_convolution_block(x,strides,groups,g_channels)x = Conv2D(filters=filters*2,kernel_size=(1,1),strides=1,padding='same',use_bias=False)(x)x = BatchNormalization(epsilon=1.001e-5)(x)x = Add()([x,shortcut])x = ReLU()(x)return x'''
堆叠残差单元
'''
def stack(x,filters,blocks,strides,groups=32):# 每个stack的第一个block的残差连接都需要使用1*1卷积升维x = block(x,filters,strides=strides,groups=groups)for i in range(blocks):x = block(x,filters,groups=groups,conv_shortcut=False)return x'''
搭建ResNext-50网络
'''
def ResNext50(input_shape,num_classes):inputs = Input(shape=input_shape)# 填充3圈0,[224,224,3]  -> [230,230,3]x = ZeroPadding2D((3,3))(inputs)x = Conv2D(filters=64,kernel_size=(7,7),strides=2,padding='valid')(x)x = BatchNormalization(epsilon=1.001e-5)(x)x = ReLU()(x)# 填充1圈0x = ZeroPadding2D((1,1))(x)x = MaxPool2D(pool_size=(3,3),strides=2,padding='valid')(x)# 堆叠残差结构x = stack(x,filters=128,blocks=2,strides=1)x = stack(x,filters=256,blocks=3,strides=2)x = stack(x,filters=512,blocks=5,strides=2)x = stack(x,filters=1024,blocks=2,strides=2)# 根据特征图大小进行全局平均池化x = GlobalAvgPool2D()(x)x = Dense(num_classes,activation='softmax')(x)# 定义模型model = Model(inputs=inputs,outputs=x)return modelmodel = ResNext50(input_shape=(224,224,3),num_classes=4)
model.summary()

ResNeXt-50 相比于传统的深层网络(如 ResNet 和 VGG)有明显的优势,特别是在计算效率和模型性能之间找到了较好的平衡。通过引入 卡尔迪纳利性 的概念,ResNeXt-50 能够在网络深度不增加的情况下显著提升模型的能力,同时保持训练的高效性和泛化能力。它适用于各种计算机视觉任务,尤其是在需要高效和准确的图像分类任务中表现出色。


http://www.ppmy.cn/embedded/147600.html

相关文章

linux-----数据库

Linux下数据库概述 数据库类型: 关系型数据库(RDBMS):如MySQL、PostgreSQL、Oracle等。这些数据库以表格的形式存储数据,表格之间通过关系(如主键 - 外键关系)相互关联。关系型数据库支持复杂的…

MCU驱动使用

一、时钟的配置: AG32 通常使用 HSE 外部晶体(范围:4M~16M)。 AG32 中不需要手动设置 PLL 时钟(时钟树由系统自动配置,无须用户关注)。用户只需在配置文件中给出外部晶振频率和系统主频即可。 …

如何确保品牌色在VR虚拟展厅中保持一致性?

确保品牌色在VR虚拟展厅中的一致性对于品牌视觉传达至关重要。品牌色不仅是企业视觉识别系统的重要组成部分,而且在虚拟环境中,它们对于塑造品牌形象和提升用户体验具有决定性作用。 接下来,由专业从事VR虚拟展厅制作的圆桌3D云展厅平台为大家…

【ETCD】【实操篇(四)】etcd常见问题快问快答FAQ

原文:https://etcd.io/docs/v3.5/faq/ 目录 etcd, 一般问题配置相关部署相关操作相关性能相关其他问题 etcd, 一般问题 什么是 etcd? etcd 是一个一致性的分布式键值存储。它主要作为分布式系统中的独立协调服务,设计用于存储可以完全放入内…

如何用发链框架,快速构建一条区块链?

构建一条公链是一个庞大且系统性的工程,涉及技术、生态、市场等多个层面的挑战。特别是在技术层面,必须解决共识机制、可扩展性、安全性以及智能合约的适用性等问题。同时,公链的长期运营和去中心化治理也是不可忽视的难题,令许多…

Qt安装下载太慢解决办法

使用镜像 cmd到安装程序,然后执行命令: qt-online-installer-windows-x64-4.8.1.exe --mirror https://mirrors.ustc.edu.cn/qtproject

sentinel学习笔记1-为什么需要服务降级

本文属于sentinel学习笔记系列。网上看到吴就业老师的专栏,作为官网的有力补充,原文链接如下,讲得好,不要钱,值得推荐,我整理的有所删减,推荐看原文: 深入理解Sentinel 1 为什么需…

spring事件机制笔记、发布和监听

文章目录 为什么要用事件 使用案例可以实现一对多吗? spring事件机制笔记、发布和监听 为什么要用事件 使用案例 可以实现一对多吗?