【机器学习】机器学习的基本分类-半监督学习-Ladder Networks

embedded/2025/1/3 3:28:47/

Ladder Networks 是一种半监督学习模型,通过将无监督学习与监督学习相结合,在标记数据较少的情况下实现高效的学习。它最初由 A. Rasmus 等人在 2015 年提出,特别适合深度学习任务,如图像分类或自然语言处理。


核心思想

Ladder Networks 的目标是利用标记和未标记数据来优化网络性能。其关键思想是引入噪声对网络进行训练,同时通过解码器恢复被破坏的数据结构。它主要由以下三部分组成:

  1. 编码器(Encoder):
    编码器是一个有噪声的前馈神经网络,用于从输入数据生成潜在表示。噪声会加入到各个层的激活值中。

  2. 解码器(Decoder):
    解码器尝试从有噪声的编码器的潜在表示重建无噪声的输入数据。这个过程可以视为自编码器的一部分。

  3. 损失函数(Loss Function):
    损失由两部分组成:

    • 监督损失: 使用标记数据计算的分类误差(如交叉熵)。
    • 重建损失: 解码器重建无噪声表示与原始无噪声数据之间的误差。

通过联合优化这两部分,网络能够同时进行监督学习和无监督学习


模型架构

Ladder Networks 的架构如下:

  • 输入数据经过多层网络,每一层引入噪声,生成一个有噪声的激活值
  • 解码器逐层重建这些激活值,最终输出重建的输入。
  • 使用标记数据进行分类任务,用未标记数据训练解码器,增强表示学习能力。

模型使用跳跃连接(Skip Connections)来帮助解码器更好地恢复无噪声表示。


损失函数

损失函数分为两部分:

  1. 监督损失:
    使用分类任务中的标记数据,例如交叉熵:

    L_{\text{supervised}} = -\sum y \log(\hat{y})
  2. 重建损失:
    解码器的重建误差,例如均方误差:

                                                   L_{\text{reconstruction}} = \sum_{l=1}^L \lambda_l \| z_l - \tilde{z}_l \|^2

    其中,z_l 是无噪声激活值,\tilde{z}_l 是有噪声的激活值的解码结果,\lambda_l​ 是每一层的权重。

总损失是两者的加权和:

L = L_{\text{supervised}} + \alpha L_{\text{reconstruction}}


优势

  1. 高效利用未标记数据:
    通过重建误差,未标记数据在网络训练中也能发挥作用。

  2. 鲁棒性增强:
    加入噪声训练有助于防止过拟合,提高网络的泛化能力。

  3. 层间交互建模:
    跳跃连接有助于捕获层间复杂的相互关系,从而提高表示能力。


应用

  • 图像分类
    在 MNIST、CIFAR-10 等数据集上表现优异,尤其在标记样本少的情况下。

  • 半监督学习
    在需要结合标记数据和未标记数据的任务中具有广泛应用。

  • 自然语言处理:
    用于词嵌入学习或序列生成任务。


示例代码

以下是基于 TensorFlow 的 Ladder Networks 简化实现:

import tensorflow as tf
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.models import Model, Sequential# 噪声函数
def add_noise(x, noise_std=0.3):return x + tf.random.normal(tf.shape(x), stddev=noise_std)# 编码器
def encoder(input_dim, latent_dim, noise_std=0.3):model = Sequential([Dense(128, activation='relu', input_dim=input_dim),Dropout(0.3),Dense(latent_dim, activation='relu'),tf.keras.layers.Lambda(lambda x: add_noise(x, noise_std=noise_std))])return model# 解码器
def decoder(latent_dim, output_dim):model = Sequential([Dense(128, activation='relu', input_dim=latent_dim),Dense(output_dim, activation='sigmoid')  # 重建输入])return model# 输入维度
input_dim = 784  # MNIST 数据集
latent_dim = 64
output_dim = input_dim# 构建模型
encoder_model = encoder(input_dim, latent_dim)
decoder_model = decoder(latent_dim, output_dim)# 输入数据
input_data = tf.keras.Input(shape=(input_dim,))
latent_repr = encoder_model(input_data)
reconstructed = decoder_model(latent_repr)# 定义完整模型
ladder_network = Model(inputs=input_data, outputs=reconstructed)
ladder_network.compile(optimizer='adam', loss='mse')# 示例训练
(X_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
X_train = X_train.reshape(-1, 784).astype('float32') / 255.0ladder_network.fit(X_train, X_train, epochs=10, batch_size=128)

输出结果

Epoch 1/10
469/469 [==============================] - 2s 3ms/step - loss: 0.0471
Epoch 2/10
469/469 [==============================] - 1s 3ms/step - loss: 0.0271
Epoch 3/10
469/469 [==============================] - 2s 3ms/step - loss: 0.0233
Epoch 4/10
469/469 [==============================] - 1s 3ms/step - loss: 0.0215
Epoch 5/10
469/469 [==============================] - 1s 3ms/step - loss: 0.0204
Epoch 6/10
469/469 [==============================] - 1s 3ms/step - loss: 0.0197
Epoch 7/10
469/469 [==============================] - 1s 3ms/step - loss: 0.0191
Epoch 8/10
469/469 [==============================] - 1s 3ms/step - loss: 0.0186
Epoch 9/10
469/469 [==============================] - 1s 3ms/step - loss: 0.0182
Epoch 10/10
469/469 [==============================] - 1s 3ms/step - loss: 0.0178


http://www.ppmy.cn/embedded/150259.html

相关文章

SQL 基础教程 - SQL SELECT 语句

SQL SELECT 语句 SELECT 语句用于从数据库中选取数据。 SQL SELECT 语句 SELECT 语句用于从数据库中选取数据。 结果被存储在一个结果表中,称为结果集。 SQL SELECT 语法 SELECT column1, column2, ... FROM table_name; 与 SELECT * FROM table_name; 参数说…

fgets TAILQ_INSERT_TAIL

If you’re using the macros from <sys/queue.h> to implement a circular doubly linked list (TAILQ), the inversion issue occurs because you’re using LIST_INSERT_HEAD, which inserts at the head of the list. Instead, to maintain the original order (FIFO…

如何做一份出色的PPT?

要做一份出色的PPT&#xff0c;可以遵循以下步骤和建议&#xff1a; 一、明确目标与主题 确定目标&#xff1a;明确PPT的目的&#xff0c;是为了传达什么信息、解决什么问题或达成什么目标。选定主题&#xff1a;根据目标确定一个清晰、聚焦的主题&#xff0c;这将指导整个演…

Live555、FFmpeg、GStreamer介绍

Live555、FFmpeg 和 GStreamer 都是处理流媒体和视频数据的强大开源框架和工具&#xff0c;它们广泛应用于实时视频流的推送、接收、处理和播放。每个框架有不同的设计理念、功能特性以及适用场景。下面将详细分析这三个框架的作用、解决的问题、适用场景、优缺点&#xff0c;并…

mysql日志(

mysql有以下几种日志&#xff1a; log_error 即错误日志&#xff0c;默认是开启的 log_bin 即redo日志或称二进制日志&#xff0c;运于恢复或复制&#xff0c;默认不开启。 general_log 即通用日志&#xff0c;有时也称查询日志&#xff0c;对所有执行过的语句进行记录&#xf…

使用 AI Cursor 编程实现一个小产品 Chrome 扩展插件 MVP 功能

使用 AI Cursor 编程实现一个小产品 Chrome 扩展插件 MVP 功能&#xff0c;提前编写小产品需求技术文档作为上下文&#xff0c;再使用 currsor 单个页面维度生成&#xff0c;能够有效的减少错误&#xff0c;提升开发效率&#xff0c;所以我做了一个[小产品需求技术文档] 提示词…

torch.nn.Sequential的用法

文章目录 介绍基本用法添加命名层动态添加层嵌套使用与自定义前向传播的区别 介绍 torch.nn.Sequential 是 PyTorch 中的一个容器模块&#xff0c;用于将多个神经网络层按顺序组合在一起。它可以让我们以更加简洁的方式定义前向传播的网络结构&#xff0c;适合简单的线性堆叠模…

使用CSS 和 JavaScript 实现鼠标悬停时图片放大、缩小和抖动

我们可以通过 CSS 和 JavaScript 来实现鼠标悬停时图片放大、缩小和抖动的效果。以下是一个简单的实现方式&#xff1a; 1.HTML 结构 <!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><meta name"viewp…