深度学习笔记11-优化器对比实验(Tensorflow)

server/2025/1/11 7:55:44/
  • 🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊

目录

一、导入数据并检查

二、配置数据集

三、数据可视化

四、构建模型

五、训练模型

六、模型对比评估

七、总结


一、导入数据并检查

import pathlib,PIL
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签data_dir    = pathlib.Path("./T6")
image_count = len(list(data_dir.glob('*/*')))
batch_size = 16
img_height = 336
img_width  = 336
"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=12,image_size=(img_height, img_width),batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=12,image_size=(img_height, img_width),batch_size=batch_size)

class_names = train_ds.class_names
print(class_names)

for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break

二、配置数据集

AUTOTUNE = tf.data.AUTOTUNE
#归一化处理
def train_preprocessing(image,label):return (image/255.0,label)train_ds = (train_ds.cache().shuffle(1000).map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)           # 在image_dataset_from_directory处已经设置了batch_size.prefetch(buffer_size=AUTOTUNE)
)val_ds = (val_ds.cache().shuffle(1000).map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)         # 在image_dataset_from_directory处已经设置了batch_size.prefetch(buffer_size=AUTOTUNE)
)

三、数据可视化

plt.figure(figsize=(10, 8))  # 图形的宽为10高为5
plt.suptitle("数据展示")for images, labels in train_ds.take(1):for i in range(15):plt.subplot(4, 5, i + 1)plt.xticks([])plt.yticks([])plt.grid(False)# 显示图片plt.imshow(images[i])# 显示标签plt.xlabel(class_names[labels[i]-1])plt.show()

四、构建模型

from tensorflow.keras.layers import Dropout,Dense,BatchNormalization
from tensorflow.keras.models import Modeldef create_model(optimizer='adam'):# 加载预训练模型vgg16_base_model = tf.keras.applications.vgg16.VGG16(weights='imagenet',include_top=False,#不包含顶层的全连接层input_shape=(img_width, img_height, 3),pooling='avg')#平均池化层替代顶层的全连接层for layer in vgg16_base_model.layers:layer.trainable = False  #将 trainable属性设置为 False 意味着在训练过程中,这些层的权重不会更新X = vgg16_base_model.outputX = Dense(170, activation='relu')(X)X = BatchNormalization()(X)X = Dropout(0.5)(X)output = Dense(len(class_names), activation='softmax')(X)#神经元数量等于类别数vgg16_model = Model(inputs=vgg16_base_model.input, outputs=output)vgg16_model.compile(optimizer=optimizer,loss='sparse_categorical_crossentropy',metrics=['accuracy'])return vgg16_modelmodel1 = create_model(optimizer=tf.keras.optimizers.Adam())
model2 = create_model(optimizer=tf.keras.optimizers.SGD())#随机梯度下降(SGD)优化器的
model2.summary()

五、训练模型

NO_EPOCHS = 20history_model1  = model1.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)
history_model2  = model2.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)

六、模型对比评估

from matplotlib.ticker import MultipleLocator
plt.rcParams['savefig.dpi'] = 300 #图片像素
plt.rcParams['figure.dpi']  = 300 #分辨率acc1     = history_model1.history['accuracy']
acc2     = history_model2.history['accuracy']
val_acc1 = history_model1.history['val_accuracy']
val_acc2 = history_model2.history['val_accuracy']loss1     = history_model1.history['loss']
loss2     = history_model2.history['loss']
val_loss1 = history_model1.history['val_loss']
val_loss2 = history_model2.history['val_loss']epochs_range = range(len(acc1))plt.figure(figsize=(16, 4))
plt.subplot(1, 2, 1)plt.plot(epochs_range, acc1, label='Training Accuracy-Adam')
plt.plot(epochs_range, acc2, label='Training Accuracy-SGD')
plt.plot(epochs_range, val_acc1, label='Validation Accuracy-Adam')
plt.plot(epochs_range, val_acc2, label='Validation Accuracy-SGD')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss1, label='Training Loss-Adam')
plt.plot(epochs_range, loss2, label='Training Loss-SGD')
plt.plot(epochs_range, val_loss1, label='Validation Loss-Adam')
plt.plot(epochs_range, val_loss2, label='Validation Loss-SGD')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))plt.show()

可以看出,在这个实例中,Adam优化器的效果优于SGD优化器

七、总结

      通过本次实验,学会了比较不同优化器(Adam和SGD)在训练过程中的性能表现,可视化训练过程的损失曲线和准确率等指标。这是一项非常重要的技能,在研究论文中,可以通过这些优化方法可以提高工作量。


http://www.ppmy.cn/server/157405.html

相关文章

uniapp-vue3 实现, 一款带有丝滑动画效果的单选框组件,支持微信小程序、H5等多端

采用 uniapp-vue3 实现, 是一款带有丝滑动画效果的单选框组件,提供点状、条状的动画过渡效果,支持多项自定义配置,适配 web、H5、微信小程序(其他平台小程序未测试过,可自行尝试) 可到插件市场下载尝试&…

《零基础Go语言算法实战》【题目 1-23】map 错误排查

《零基础Go语言算法实战》 【题目 1-23】map 错误排查 请解释以下代码在执行时为什么会报错。 type Books struct { name string } func main() { m : map[string]Books{"name": {"《零基础 Go 语言算法实战》"}} m["book"].name "《…

远程和本地文件的互相同步

文章目录 1、rsync实现类似git push pull功能1. 基础概念2. 示例操作3. 定制化和进阶用法4. 定时同步(类似自动化) 2 命令简化1. 动态传参的脚本2. Shell 函数支持动态路径3. 结合环境变量和参数(更简洁)4. Makefile 支持动态路径…

Perl语言的软件开发工具

Perl语言的软件开发工具 引言 Perl是一种功能强大且灵活的高级编程语言,自1987年由拉里沃尔(Larry Wall)创建以来,就广泛应用于文本处理、系统管理、网络编程、Web开发等多个领域。作为一种脚本语言,Perl以其简洁的语…

iOS - 消息机制

1. 基本数据结构 // 方法结构 struct method_t {SEL name; // 方法名const char *types; // 类型编码IMP imp; // 方法实现 };// 类结构 struct objc_class {Class isa;Class superclass;cache_t cache; // 方法缓存class_data_bits_t bits; // 类的方法…

ubuntu22.04 的录屏软件有哪些?

在Ubuntu 22.04上,有几款适合做视频直播和录屏的软件: 1. OBS Studio (Open Broadcaster Software) 功能:OBS Studio 是最常用的开源直播和录屏软件,支持视频录制、直播流式传输,并且有强大的插件支持,能…

视频编辑最新SOTA!港中文Adobe等发布统一视频生成传播框架——GenProp

文章链接:https://arxiv.org/pdf/2412.19761 项目链接:https://genprop.github.io 亮点直击 定义了一个新的生成视频传播问题,目标是利用 I2V 模型的生成能力,将视频第一帧的各种变化传播到整个视频中。 精心设计了模型 GenProp&…

怎么修复损坏或者语法有问题的PDF-免费PDF编辑工具分享

序言 我之前的文章也有介绍过如何使用96缔盟PDF处理器修复破损或者语法有问题的PDF文件,但是当时是使用DMPDFUtilTool1.0版本进行的,V1.0的功能尚不完善,存在一些隐藏的功能BUG,而且在用户体验方面也存在一些不足,例如还不支持拖…