pytorch实现变分自编码器

server/2025/2/4 15:08:56/

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

变分自编码器(Variational Autoencoder, VAE)是一种生成模型,属于深度学习中的无监督学习方法。它通过学习输入数据的潜在分布(Latent Distribution),生成与输入数据相似的新样本。VAE 可以用于数据生成、降维、异常检测等任务。

VAE 的关键思想是在传统的自编码器(Autoencoder)的基础上,引入了变分推断(Variational Inference)和概率模型,使得网络能够学习到数据的潜在分布,而不仅仅是数据的映射。

VAE 的结构:

  1. 编码器(Encoder):将输入数据映射到潜在空间的分布。不同于传统的自编码器直接将数据映射到一个固定的潜在向量,VAE 通过输出潜在变量的均值和方差来描述一个概率分布,这样潜在空间中的每个点都有一个概率分布。
  2. 潜在空间(Latent Space):表示数据的潜在特征。在 VAE 中,潜在空间的表示是一个分布而不是固定的值。通常,采用正态分布来作为潜在空间的先验分布。
  3. 解码器(Decoder):从潜在空间的样本中重构输入数据。解码器通过将潜在空间的点映射回数据空间来生成样本。

VAE 的目标函数:

VAE 的目标是最大化变分下界(Variational Lower Bound,简称 ELBO),即通过优化以下两部分的加权和:

  • 重构误差(Reconstruction Loss):衡量生成的数据和输入数据之间的差异,通常使用均方误差(MSE)或交叉熵(Cross-Entropy)。
  • KL 散度(KL Divergence):衡量潜在空间的分布与先验分布(通常是标准正态分布)之间的差异。

其最终的目标是使生成的数据尽可能接近真实数据,同时使潜在空间的分布接近先验分布。

优点:

  • VAE 能够生成具有多样性的样本,尤其适用于图像、音频等数据的生成。
  • 潜在空间通常具有良好的结构,可以进行插值、样本生成等操作。

应用:

  • 生成任务:如图像生成、文本生成等。
  • 数据重构:如去噪、自编码等。
  • 半监督学习:VAE 可以结合有标签和无标签的数据进行训练,提升模型的泛化能力。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt# 生成圆形图像的函数(使用PyTorch)
def generate_circle_image(size=64):image = torch.zeros((1, size, size))  # 使用 PyTorch 创建空白图像center = size // 2radius = size // 4for y in range(size):for x in range(size):if (x - center) ** 2 + (y - center) ** 2 <= radius ** 2:image[0, y, x] = 1  # 在圆内的点设置为白色return image# 生成方形图像的函数(使用PyTorch)
def generate_square_image(size=64):image = torch.zeros((1, size, size))  # 使用 PyTorch 创建空白图像padding = size // 4image[0, padding:size - padding, padding:size - padding] = 1  # 设置方形区域为白色return image# 自定义数据集:圆形和方形图像
class ShapeDataset(Dataset):def __init__(self, num_samples=1000, size=64):self.num_samples = num_samplesself.size = sizeself.data = []# 生成数据:一半是圆形图像,一半是方形图像for i in range(num_samples // 2):self.data.append(generate_circle_image(size))self.data.append(generate_square_image(size))def __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx].float()  # 直接返回 PyTorch Tensor 格式的数据# VAE模型定义
class VAE(nn.Module):def __init__(self, latent_dim=2):super(VAE, self).__init__()self.latent_dim = latent_dim# 编码器self.fc1 = nn.Linear(64 * 64, 400)self.fc21 = nn.Linear(400, latent_dim)  # 均值self.fc22 = nn.Linear(400, latent_dim)  # 方差# 解码器self.fc3 = nn.Linear(latent_dim, 400)self.fc4 = nn.Linear(400, 64 * 64)def encode(self, x):h1 = torch.relu(self.fc1(x.view(-1, 64 * 64)))return self.fc21(h1), self.fc22(h1)  # 返回均值和方差def reparameterize(self, mu, logvar):std = torch.exp(0.5 * logvar)eps = torch.randn_like(std)return mu + eps * stddef decode(self, z):h3 = torch.relu(self.fc3(z))return torch.sigmoid(self.fc4(h3)).view(-1, 1, 64, 64)  # 重构图像def forward(self, x):mu, logvar = self.encode(x)z = self.reparameterize(mu, logvar)return self.decode(z), mu, logvar# 损失函数:重构误差 + KL 散度
def loss_function(recon_x, x, mu, logvar):BCE = nn.functional.binary_cross_entropy(recon_x.view(-1, 64 * 64), x.view(-1, 64 * 64), reduction='sum')# KL 散度return BCE + 0.5 * torch.sum(torch.exp(logvar) + mu ** 2 - 1 - logvar)# 设置超参数
batch_size = 128
epochs = 10
latent_dim = 2
learning_rate = 1e-3# 数据加载
train_loader = DataLoader(ShapeDataset(num_samples=2000), batch_size=batch_size, shuffle=True)# 创建模型和优化器
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAE(latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# 训练模型
def train(epoch):model.train()train_loss = 0for batch_idx, data in enumerate(train_loader):data = data.to(device)optimizer.zero_grad()recon_batch, mu, logvar = model(data)loss = loss_function(recon_batch, data, mu, logvar)loss.backward()train_loss += loss.item()optimizer.step()if batch_idx % 100 == 0:print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item() / len(data):.6f}')print(f'Train Epoch: {epoch} Average loss: {train_loss / len(train_loader.dataset):.4f}')# 测试并显示一些真实图像和生成的图像
def test():model.eval()with torch.no_grad():# 获取一批真实的图像(原始图像)real_images = next(iter(train_loader))[:64]  # 只取前64个图像real_images = real_images.cpu().numpy()# 从潜在空间随机生成一些样本sample = torch.randn(64, latent_dim).to(device)generated_images = model.decode(sample).cpu().numpy()# 显示真实图像和生成的图像,分别标明fig, axes = plt.subplots(8, 8, figsize=(8, 8))axes = axes.flatten()for i in range(64):if i < 32:  # 前32个显示真实图像axes[i].imshow(real_images[i].squeeze(), cmap='gray')axes[i].set_title('Real', fontsize=8)else:  # 后32个显示生成图像axes[i].imshow(generated_images[i - 32].squeeze(), cmap='gray')axes[i].set_title('Generated', fontsize=8)axes[i].axis('off')plt.tight_layout()plt.show()# 训练模型
for epoch in range(1, epochs + 1):train(epoch)# 训练完成后,显示生成的图像
test()

解释:

  1. 真实图像 (real_images):我们通过 next(iter(train_loader)) 获取一批真实图像,并将其转换为 NumPy 数组,以便 matplotlib 显示。
  2. 生成图像 (generated_images):通过模型生成的图像,使用 decode() 方法生成潜在空间的样本。
  3. 图像展示:前 32 张图像展示真实图像,后 32 张图像展示生成的图像。每个图像上方都有 RealGenerated 标注。

结果:

  • 前32个图像:显示真实图像,并标注为 Real
  • 后32个图像:显示通过训练后的 VAE 生成的图像,并标注为 Generated

http://www.ppmy.cn/server/164917.html

相关文章

亚博microros小车-原生ubuntu支持系列:19 nav2 导航

开始小车测试之前&#xff0c;先补充下背景知识 nav2 Navigation2具有下列工具&#xff1a; 加载、提供和存储地图的工具&#xff08;地图服务器Map Server&#xff09; 在地图上定位机器人的工具 (AMCL) 避开障碍物从A点移动到B点的路径规划工具&#xff08;Nav2 Planner&a…

数据结构的队列

一.队列 1.队列&#xff08;Queue&#xff09;的概念就是先进先出。 2.队列的用法&#xff0c;红色框和绿色框为两组&#xff0c;offer为插入元素&#xff0c;poll为删除元素&#xff0c;peek为查看元素红色的也是一样的。 3.LinkedList实现了Deque的接口&#xff0c;Deque又…

【优先算法】专题——前缀和

目录 一、【模版】前缀和 参考代码&#xff1a; 二、【模版】 二维前缀和 参考代码&#xff1a; 三、寻找数组的中心下标 参考代码&#xff1a; 四、除自身以外数组的乘积 参考代码&#xff1a; 五、和为K的子数组 参考代码&#xff1a; 六、和可被K整除的子数组 参…

一文讲解Java中的ArrayList和LinkedList

ArrayList和LinkedList有什么区别&#xff1f; ArrayList 是基于数组实现的&#xff0c;LinkedList 是基于链表实现的。 二者用途有什么不同&#xff1f; 多数情况下&#xff0c;ArrayList更利于查找&#xff0c;LinkedList更利于增删 由于 ArrayList 是基于数组实现的&#…

数组排序算法

数组排序算法 用C语言实现的数组排序算法。 排序算法平均时间复杂度最坏时间复杂度最好时间复杂度空间复杂度是否稳定适用场景QuickO(n log n)O(n)O(n log n)O(log n)不稳定大规模数据&#xff0c;通用排序BubbleO(n)O(n)O(n)O(1)稳定小规模数据&#xff0c;教学用途InsertO(n)…

二叉树--链式存储

1我们之前学了二叉树的顺序存储&#xff08;这种顺序存储的二叉树被称为堆&#xff09;&#xff0c;我们今天来学习一下二叉树的链式存储&#xff1a; 我们使用链表来表示一颗二叉树&#xff1a; ⽤链表来表⽰⼀棵⼆叉树&#xff0c;即⽤链来指⽰元素的逻辑关系。通常的⽅法是…

Nginx的配置文件 conf/nginx.conf /etc/nginx/nginx.conf 笔记250203

Nginx的配置文件 Nginx 的配置文件是其功能的核心&#xff0c;通过灵活的配置可以实现负载均衡、反向代理、静态资源服务、SSL 加密等功能。以下是 Nginx 配置文件的详细讲解&#xff0c;涵盖结构、核心指令及常见配置场景。 1. 配置文件位置 主配置文件&#xff1a;/etc/ngi…

9.3 GPT Action 设计模式:打造高效的 AI 驱动应用

GPT Action 设计模式:打造高效的 AI 驱动应用 引言:构建智能应用的最佳实践 随着人工智能(AI)的快速发展,开发者正在利用先进的 GPT 模型和技术,创建越来越智能的应用。为了提升开发效率和应用的可扩展性,GPT Action 设计模式应运而生。这个模式为 AI 开发提供了一种灵…