《Keras 3 :当 Recurrence 遇到 Transformers 时》

news/2025/2/15 19:36:08/

《Keras 3 :当 Recurrence 遇到 Transformers 时》

作者:Aritra Roy Gosthipaty, Suvaditya Mukherjee
创建日期:2023/03/12
最后修改时间:2024/11/12
描述:具有时间潜在瓶颈网络的图像分类。

(i) 此示例使用 Keras 3

 在 Colab 中查看 

 GitHub 源


介绍

简单的递归神经网络 (RNN) 对学习时间压缩的表示表现出强烈的归纳偏差。等式 1 显示了递归公式 其中 是整个输入的压缩表示形式(单个向量) 序列。h_tx

方程 1:递归方程。(来源: Aritra 和 Suvaditya)

另一方面,变压器 (Vaswani et al.) 具有 对学习时间压缩表示的归纳偏差很小。 Transformer 在自然语言处理 (NLP) 中取得了 SoTA 结果 和 Vision 任务及其成对注意力机制。

虽然 Transformer 能够处理输入的不同部分 序列,注意力的计算本质上是二次方的。

Didolkar 等人认为,将更压缩的 序列的表示可能有利于泛化,因为它可以很容易地重用重新利用,并且不相关的细节更少。虽然压缩很好, 他们还注意到,过多的 IT 会损害表达能力。

作者提出了一种将计算分为两个流的解决方案。一个慢 stream 和参数化为 变压器。虽然这种方法具有引入不同处理的新颖性 streams 中,它与其他 作品如 Perceiver Mechanism(由 Jaegle 等人)和 Grounded Language Learning Fast and Slow(由 Hill 等人)。

以下示例探讨了如何利用新的 Temporal Latent Bottleneck 在 CIFAR-10 数据集上执行图像分类的机制。我们实现这个 model 进行自定义实现,以便进行高性能矢量化设计。RNNCell


设置导入

import osimport keras
from keras import layers, ops, mixed_precision
from keras.optimizers import AdamW
import numpy as np
import random
from matplotlib import pyplot as plt# Set seed for reproducibility.
keras.utils.set_random_seed(42)

设置所需的配置

我们在我们拥有的管道中设置了一些所需的配置参数 设计。当前参数用于 CIFAR10 数据集。

该模型还支持设置,这将量化模型以尽可能使用浮点数,同时根据需要保留一些参数 用于数值稳定性。这带来了性能优势,因为模型的占用空间 显著降低,同时在推理时带来速度提升。mixed-precision16-bit32-bit

config = {"mixed_precision": True,"dataset": "cifar10","train_slice": 40_000,"batch_size": 2048,"buffer_size": 2048 * 2,"input_shape": [32, 32, 3],"image_size": 48,"num_classes": 10,"learning_rate": 1e-4,"weight_decay": 1e-4,"epochs": 30,"patch_size": 4,"embed_dim": 64,"chunk_size": 8,"r": 2,"num_layers": 4,"ffn_drop": 0.2,"attn_drop": 0.2,"num_heads": 1,
}if config["mixed_precision"]:policy = mixed_precision.Policy("mixed_float16")mixed_precision.set_global_policy(policy)

加载 CIFAR-10 数据集

我们将使用 CIFAR10 数据集来运行我们的实验。此数据集 包含具有标准图像大小的类的训练图像集 之。50,00010(32, 32, 3)

它还具有一组具有相似特征的单独图像。更多 有关数据集的信息也可以在数据集的官方网站上找到 作为 keras.datasets.cifar10 API 参考10,000

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
(x_train, y_train), (x_val, y_val) = ((x_train[: config["train_slice"]], y_train[: config["train_slice"]]),(x_train[config["train_slice"] :], y_train[config["train_slice"] :]),
)

为训练和验证/测试管道定义数据增强

我们定义了单独的管道来对数据执行图像增强。此步骤为 重要的是使模型对更改更健壮,从而帮助它更好地泛化。 我们执行的预处理和增强步骤如下:

  • Rescaling(训练、测试):执行此步骤以标准化所有图像像素 值从 range 到 .这有助于保持数值稳定性 稍后在训练期间领先。[0,255][0,1)
  • Resizing(训练、测试):我们将图像的大小从原始大小 (32, 32) 调整为 (52, 52).这样做是为了考虑 Random Crop(随机裁剪),并遵守 论文中给出的数据的规格。
  • RandomCrop(训练):该层随机选择图像的裁剪/子区域 大小 .(48, 48)
  • RandomFlip(training):该层随机水平翻转所有图像, 保持图像大小不变。
# Build the `train` augmentation pipeline.
train_augmentation = keras.Sequential([layers.Rescaling(1 / 255.0, dtype="float32"),layers.Resizing(config["input_shape"][0] + 20,config["input_shape"][0] + 20,dtype="float32",),layers.RandomCrop(config["image_size"], config["image_size"], dtype="float32"),layers.RandomFlip("horizontal", dtype="float32"),],name="train_data_augmentation",
)# Build the `val` and `test` data pipeline.
test_augmentation = keras.Sequential([layers.Rescaling(1 / 255.0, dtype="float32"),layers.Resizing(config["image_size"], config["image_size"], dtype="float32"),],name="test_data_augmentation",
)# We define functions in place of simple lambda functions to run through the
# [`keras.Sequential`](/api/models/sequential#sequential-class)in order to solve this warning:
# (https://github.com/tensorflow/tensorflow/issues/56089)def train_map_fn(image, label):return train_augmentation(image), labeldef test_map_fn(image, label):return test_augmentation(image), label

将数据集加载到对象中PyDataset

  • 我们获取数据集的实例,并在其周围包装一个类 包装 keras.utils.PyDataset 并使用 keras 应用增强 预处理层。np.ndarray
class Dataset(keras.utils.PyDataset):def __init__(self, x_data, y_data, batch_size, preprocess_fn=None, shuffle=False, **kwargs):if shuffle:perm = np.random.permutation(len(x_data))x_data = x_data[perm]y_data = y_data[perm]self.x_data = x_dataself.y_data = y_dataself.preprocess_fn = preprocess_fnself.batch_size = batch_sizesuper().__init__(*kwargs)def __len__(self):return len(self.x_data) // self.batch_sizedef __getitem__(self, idx):batch_x, batch_y = [], []for i in range(idx * self

http://www.ppmy.cn/news/1572323.html

相关文章

深入理解DOM:22个核心知识点与代码示例

本文系统介绍DOM相关的22个核心概念,每个知识点均提供代码示例及简要说明,帮助开发者全面掌握DOM操作技巧。 一、DOM基础概念 1. DOM概念 DOM(Document Object Model)是HTML/XML的编程接口,通过JavaScript可动态修改…

【RAG落地利器】Weaviate、Milvus、Qdrant 和 Chroma 向量数据库对比

什么是向量数据库? 向量数据库是一种将数据存储为高维向量的数据库,高维向量是特征或属性的数学表示。每个向量都有一定数量的维度,根据数据的复杂性和粒度,可以从数十到数千不等。 向量通常是通过对原始数据(如文本、图像、音频、视频等)…

vue3-虚拟dom优化

Vue 3 对虚拟 DOM(Virtual DOM)进行了全面重构,通过 编译时优化 和 运行时 Diff 算法改进,大幅提升了渲染性能。以下是其核心优化策略和实现细节的详细解析: 一、虚拟 DOM 的基本原理回顾 虚拟 DOM 是一个轻量级的 J…

CSS实现与文字长度相同的下划线

可以使用伪元素和一些样式属性来实现与文字长度相同的下划线。 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0">&…

Python实现从SMS-Activate平台,自动获取手机号和验证码(进阶版2.0)

前言 本文是该专栏的第52篇,后面会持续分享python的各种干货知识,值得关注。 在本专栏之前,笔者在文章《Python实现SMS-Activate接口调用,获取手机号和验证码》中,有详细介绍基于SMS-Activate平台,通过python来实现自动获取目标国家的手机号以及对应的手机号验证码。 而…

如果网络中断,Promise.race 如何处理?

在使用 Promise.race 时&#xff0c;如果网络中断&#xff0c;通常会导致请求失败&#xff0c;并触发相应的错误处理。这可以通过 Promise.race 中的 Promise 对象来捕获。以下是如何处理网络中断的详细说明。 1. 网络中断的处理 当网络中断时&#xff0c;uni.request、uni.u…

无人机雨季应急救灾技术详解

无人机在雨季应急救灾中发挥着至关重要的作用&#xff0c;其凭借机动灵活、反应迅速、高效安全等特点&#xff0c;为救灾工作提供了强有力的技术支撑。以下是对无人机雨季应急救灾技术的详细解析&#xff1a; 一、无人机在雨季应急救灾中的应用场景 1. 灾情侦查与监测 无人机…

2025.2.14——1400

2025.2.14——1400 A 1400 B 1400 C 1400 D 1400 E 1400 F 1400 G 1400 H 1400 ------------------------------------------------ 思维排序/双指针/二分/队列匹配思维二分/位运算思维数学思维 A 一眼想到的是维护信息计数。维护两个信息同时用长的一半去找短的一半…