T10 tensorflow数据增强

devtools/2024/10/17 21:15:59/
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

T10 使用 TensorFlow 实现数据增强

在深度学习的图像分类任务中,数据增强是一种常用的技术,它通过对现有训练样本进行随机变换(例如翻转、旋转、缩放等),以生成更多的训练数据,帮助模型更好地泛化,提升模型在未知数据上的表现。Pytorch框架数据增强方式较为方便,但对于tensorflow还不熟悉,这周主要学习tensorflow框架下数据增强方法。

1. 环境设置和数据加载

将数据集分为训练集、验证集和测试集。

python"># 设置 GPU 显存按需使用
gpus = tf.config.list_physical_devices("GPU")
if gpus:tf.config.experimental.set_memory_growth(gpus[0], True)tf.config.set_visible_devices([gpus[0]], "GPU")# 数据路径和参数设定
data_dir = "./34-data/"
img_height, img_width = 224, 224
batch_size = 32# 加载训练和验证数据集
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.3,subset="training",seed=12,image_size=(img_height, img_width),batch_size=batch_size)val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.3,subset="validation",seed=12,image_size=(img_height, img_width),batch_size=batch_size)# 分割验证集为验证集和测试集
val_batches = tf.data.experimental.cardinality(val_ds)
test_ds = val_ds.take(val_batches // 5)
val_ds = val_ds.skip(val_batches // 5)print(f'Number of validation batches: {tf.data.experimental.cardinality(val_ds)}')
print(f'Number of test batches: {tf.data.experimental.cardinality(test_ds)}')
2. 数据增强

数据增强 是提高模型泛化能力的重要技术。通过随机改变图像的属性,如水平/垂直翻转、旋转等,模型可以学会更好地处理不同的图像变体。

在 Keras 中,我们可以使用 tf.keras.layers.RandomFliptf.keras.layers.RandomRotation 等层来实现数据增强。在这里,我们对图像进行随机水平和垂直翻转,并进行一定程度的旋转:

python"># 定义数据增强操作
data_augmentation = tf.keras.Sequential([tf.keras.layers.RandomFlip("horizontal_and_vertical"),tf.keras.layers.RandomRotation(0.2),
])# 数据增强效果展示
import matplotlib.pyplot as pltplt.figure(figsize=(8, 8))
for images, labels in train_ds.take(1):image = tf.expand_dims(images[0], 0)for i in range(9):augmented_image = data_augmentation(image)ax = plt.subplot(3, 3, i + 1)plt.imshow(augmented_image[0])plt.axis("off")
plt.show()

在上面的代码中,data_augmentation 是一个 Keras 序列模型,它会对输入的图像进行随机的翻转和旋转。通过上面的可视化代码,我们可以直观地看到数据增强后的图像效果。
在这里插入图片描述

3. 数据预处理和模型训练

数据增强可以集成到训练管道中。在训练集上,我们通过 map 函数应用数据增强,同时将图像归一化为 [0, 1] 区间以便模型训练。

python">AUTOTUNE = tf.data.AUTOTUNEdef preprocess_image(image, label):image = image / 255.0  # 图像归一化return image, label# 预处理和缓存数据集
train_ds = train_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
val_ds = val_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
test_ds = test_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
test_ds = test_ds.cache().prefetch(buffer_size=AUTOTUNE)# 将数据增强应用到训练集
def prepare(ds):ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y), num_parallel_calls=AUTOTUNE)return dstrain_ds = prepare(train_ds)

prepare 函数中,我们将数据增强操作应用于训练集,并通过 AUTOTUNE 进行多线程处理,加快数据读取速度。

4. 模型搭建和训练

定义了一个简单的卷积神经网络,包含三层卷积层,最后通过全连接层进行分类。模型使用 Adam 优化器,并通过交叉熵损失函数来优化。

python">model = tf.keras.Sequential([layers.Conv2D(16, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Conv2D(32, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Conv2D(64, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Flatten(),layers.Dense(128, activation='relu'),layers.Dense(len(class_names))
])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])# 训练模型
epochs = 20
history = model.fit(train_ds,validation_data=val_ds,epochs=epochs
)

结果如下:
在这里插入图片描述

5. 模型评估

在模型训练完成后,我们可以通过测试集评估模型的表现。

python"># 评估模型
loss, acc = model.evaluate(test_ds)
print(f"Test Accuracy: {acc}")

Accuracy 0.90625

6. 总结

这周学习了如何使用 TensorFlow 和 Keras 实现一个包含数据增强的图像分类任务。数据增强在提升模型泛化能力上有显著作用,尤其在训练样本有限的情况下,随机翻转、旋转等操作能够帮助模型学习到更多的图像变体,从而在测试集上取得更好的表现。


http://www.ppmy.cn/devtools/126564.html

相关文章

使用人体关键点驱动FBX格式虚拟人原理【详解】

文章目录 1、使用人体关键点数据驱动FBX格式虚拟人的总流程2、使用mediapipe检测人体关键点和插值平滑2.1 mediapipe检测人体关键点2.2 人体关键点的插值平滑 3、将2d关键点转为3d关键点4、旋转矩阵4.1 旋转矩阵4.2 旋转矩阵转为四元数 5、将旋转矩阵用于虚拟人的驱动5.1 基础旋…

Linux 操作系统——扫盲教程5

目录 更多的 Machine Related 指令 useradd passwd ps top nice pgrep ifconfig iostat iotop mpstat vmstat 更多的 Machine Related 指令 useradd 各位如果有自己装Linux发行版的经验,就会知道我们的操作系统需要注册一个用户,我们登陆上…

中国全国省市区县汇总全国省市区json省市区数据2024最新

简介 包含全国省市区县数据,共3465个。 全国总共有23个省、5个自治区、4个直辖市、2个特别行政区。 ——更新于2024年10月16日,从2017年开始,已经更新坚持7年 从刚开始1000个左右的城市json,到现在全国省市区县3465个。 本人感觉应该是目前最完善的~ 每年都在更新中,…

外包功能测试干了6个月,技术退步太明显了。。。。。

先说一下自己的情况,本科生,23年通过校招进入武汉某软件公司,干了差不多6个月的功能测试,今年中秋,感觉自己不能够在这样下去了,长时间呆在一个舒适的环境会让一个人堕落!而我就在一个外包企业干了6个月的功…

Python网络爬虫

随着互联网的迅猛发展,数据成为了新的“石油”。人们对于信息的需求日益增涨,尤其是在市场分析、学术研究和数据挖掘等领域。网络爬虫作为一种自动提取网络数据的技术,因其强大的能力而备受关注。而Python,凭借其简洁的语法和丰富…

【一个简单的JavaScript网页设计案例】

首先&#xff0c;我们需要一些HTML来构建基本的页面结构&#xff0c;接着是一些CSS来美化页面&#xff0c;最后是JavaScript来实现功能。 HTML (index.html) <!DOCTYPE html> <html lang"zh"> <head> <meta charset"UTF-8"> <…

使用API有效率地管理Dynadot域名,删除域名服务器(Name Server)

前言 Dynadot是通过ICANN认证的域名注册商&#xff0c;自2002年成立以来&#xff0c;服务于全球108个国家和地区的客户&#xff0c;为数以万计的客户提供简洁&#xff0c;优惠&#xff0c;安全的域名注册以及管理服务。 Dynadot平台操作教程索引&#xff08;包括域名邮箱&…

搜维尔科技:遥操作方案定制,视觉识别映射灵巧手

遥操作方案定制&#xff0c;视觉识别映射灵巧手 搜维尔科技&#xff1a;遥操作方案定制&#xff0c;视觉识别映射灵巧手