【论文笔记】生成对抗网络 GAN

news/2025/3/26 2:21:34/

GAN

2014 年,Ian Goodfellow 等人提出生成对抗网络(Generative Adversarial Networks),GAN 的出现是划时代的,虽然目前主流的图像/视频生成模型是扩散模型(Diffusion Models)的天下,但是我们仍然有必要了解 GAN 的思想。

GAN 的核心思想是训练两个模型,分别为生成器(Generator)和辨别器(Discriminator),生成器的目标是生成虚假的数据,尽可能混淆辨别器,使其无法判别真实数据和虚假数据,而辨别器的目标则是尽可能将真实数据和虚假数据区分开来。这个过程如下图所示:

gan-example

生成器和辨别器处于一个对抗的过程,它们的能力不断地提升。GAN 的一个缺点在于它的训练过程不稳定,因此在 GAN 出来后,跟 GAN 相关的论文层出不穷,包括改进 GAN 的损失函数、训练方式,或者采用更先进的模型结构,使 GAN 的生成能力更强,同时使其训练过程更加稳定,但是 GAN 的核心思想是不变的。

模型结构

GAN 的结构如下图所示:

gan

GAN 的生成器和辨别器是两个独立的模型,在原始 GAN 中采用的生成器和辨别器都是多层感知机(Multi Layer Perceptron),后来出现了许多模型结构的改进,例如 DCGAN 将 MLP 替换为卷积神经网络。

辨别器

辨别器本质上是一个分类器,用于区分真实数据和由生成器生成的虚假数据,输出是一个 0-1 范围的标量,表示为真实数据的概率值。辨别器有两个数据来源:真实和虚假数据,训练辨别器的过程中,保持生成器的参数不变,利用二分类损失计算梯度,执行反向传播更新辨别器的参数,过程如下。

gan-disc

生成器

生成器用于生成虚假数据,尽可能混淆辨别器,生成器接受一个随机噪声(Random Noise),随机噪声的采样可以来自于均匀分布、正态分布等等,甚至可以是一张图片。生成器的作用就是将随机噪声分布转换为真实数据的分布,在生成器训练的过程中,保持辨别器的参数不变,利用辨别器的梯度来更新生成器。

在这里插入图片描述

损失函数

GAN 采用了 minimax 损失,其数学表达式如下:

min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p d a t a ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D,G)=E_{x\sim p_{data}(x)}[\log D(x)]+E_{z\sim p_z(z)}[\log(1-D(G(z)))] GminDmaxV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]

其中, V ( D , G ) V(D,G) V(D,G) 表示价值函数, x x x 为真实数据采样的样本, z z z 为生成器生成的样本。

minimax 损失本质上是一个二分类损失(Binary Cross Entropy),可以拆解为辨别器损失和生成器损失。

在训练辨别器的过程中,生成器参数保持不变,因此对于辨别器而言,\(G(z)\) 可以视为常数,其损失函数为:

L D = − E x ∼ p d a t a ( x ) [ log ⁡ D ( x ) ] − E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] L_D=-E_{x\sim p_{data}(x)}[\log D(x)]-E_{z\sim p_z(z)}[\log(1-D(G(z)))] LD=Expdata(x)[logD(x)]Ezpz(z)[log(1D(G(z)))]

在训练生成器的过程中,辨别器参数保持不变,因此对于辨别器而言,价值函数的第一项为常数,在求导时忽略不计,因此生成器的损失函数为:

L G = − E z ∼ p z ( z ) [ log ⁡ ( D ( G ( z ) ) ) ] L_G=-E_{z\sim p_z(z)}[\log(D(G(z)))] LG=Ezpz(z)[log(D(G(z)))]

对于上述两个损失函数一个直观的理解是,对于 L G L_G LG 而言,我们希望生成器生成的假数据使判别器无法区分,即希望判别器输出的概率接近于 1,取对数后即接近于 0,由于判别器的输出在于 0 - 1 之间,因此取 log 后为负数,即转变为最大化对数概率,或最小化负对数概率,由于优化的过程通常是梯度下降的过程,因此选择后者。

在 GAN 的论文中,给出了一张用于阐述 GAN 的训练过程的图。假设随机噪声 z z z 采样自一维均匀分布,真实数据分布为标准正态分布。图中的黑色点线表示真实数据分布,蓝色虚线表示辨别器输出的概率分布,绿色实线表示生成器输出的概率分布。随着 GAN 的不断训练,生成器生成的数据分布逐渐接近于真实数据分布,辨别器越来越难以区分真实数据和假数据,因此在理想情况下,生成器完全学习到了真实数据分布,辨别器再也无法进行区分,因此输出的概率都为 50%,也就是图(d) 所示的直线。

gan-process

GAN 的训练过程以及 PyTorch 实现

以下是原始 GAN 论文中的训练算法:

train

注意:这里生成器的损失函数并不是前面重写的形式,但是它们两个是等价的,在实际中,作者采用前面重写的形式,因为他们认为这样训练更加稳定。实际的情况是都不那么稳定:)。

下面是一个 GAN 的 PyTorch 实现例子,生成器和辨别器均采用 MLP,在数据集 MNIST 上进行训练的代码,具体代码可见:vanilla-gan。

import os
from argparse import Namespace, ArgumentParser
import torch
from torch import nn, Tensor
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoaderclass Discriminator(nn.Module):"""Disrcminator in GAN.Model Architecture: [affine - leaky relu - dropout] x 3 - affine - sigmoid"""def __init__(self, image_shape: tuple[int, int, int]) -> None:super(Discriminator, self).__init__()C, H, W = image_shapeimage_size = C * H * Wself.model = nn.Sequential(nn.Linear(image_size, 512), nn.LeakyReLU(0.2), nn.Dropout(0.3),nn.Linear(512, 256), nn.LeakyReLU(0.2), nn.Dropout(0.3),nn.Linear(256, 128), nn.LeakyReLU(0.2), nn.Dropout(0.3),nn.Linear(128, 1), nn.Sigmoid())def forward(self, images: Tensor) -> Tensor:images = images.view(images.size(0), -1)return self.model(images)class Generator(nn.Module):"""Generator in GAN.Model Architecture: [affine - batchnorm - relu] x 4 - affine - tanh"""def __init__(self, image_shape: tuple[int, int, int], latent_dim: int) -> None:super(Generator, self).__init__()C, H, W = image_shapeimage_size = C * H * Wself.image_shape = image_shapeself.model = nn.Sequential(nn.Linear(latent_dim, 128), nn.BatchNorm1d(128), nn.ReLU(),nn.Linear(128, 256), nn.BatchNorm1d(256), nn.ReLU(),nn.Linear(256, 512), nn.BatchNorm1d(512), nn.ReLU(),nn.Linear(512, 1024), nn.BatchNorm1d(1024), nn.ReLU(),nn.Linear(1024, image_size), nn.Tanh())def forward(self, z: Tensor) -> Tensor:images: Tensor = self.model(z)return images.view(-1, *self.image_shape)# Image processing.
transform_mnist = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.5), std=(0.5))])transform_cifar = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])# Device configuration.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')def denormalize(x: Tensor) -> Tensor:out = (x + 1) / 2return out.clamp(0, 1)def get_args() -> Namespace:"""Get commandline arguments."""parser = ArgumentParser()parser.add_argument('--lr', type=float, default=0.0002, help='learning rate for Adam optimizer')parser.add_argument('--beta1', type=float, default=0.5, help='first momentum term for Adam')parser.add_argument('--beta2', type=float, default=0.999, help='second momentum term for Adam')parser.add_argument('--batch_size', type=int, default=64, help='size of a mini-batch')parser.add_argument('--num_epochs', type=int, default=100, help='training epochs')parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')parser.add_argument('--dataset', type=str, default='MNIST', help='training dataset(MNIST | FashionMNIST | CIFAR10)')parser.add_argument('--sample_dir', type=str, default='samples', help='directory of image samples')parser.add_argument('--interval', type=int, default=1, help='epoch interval between image samples')parser.add_argument('--logdir', type=str, default='runs', help='directory of running log')parser.add_argument('--ckpt_dir', type=str, default='checkpoints', help='directory for saving model checkpoints')parser.add_argument('--seed', type=str, default=10213, help='random seed')return parser.parse_args()def setup(args: Namespace) -> None:torch.manual_seed(args.seed)# Create directory if not exists.if not os.path.exists(os.path.join(args.sample_dir, args.dataset)):os.makedirs(os.path.join(args.sample_dir, args.dataset))if not os.path.exists(os.path.join(args.ckpt_dir, args.dataset)):os.makedirs(os.path.join(args.ckpt_dir, args.dataset))def get_data_loader(args: Namespace) -> DataLoader:"""Get data loader."""if args.dataset == 'MNIST':data = datasets.MNIST(root='../data', train=True, download=True, transform=transform_mnist)elif args.dataset == 'FashionMNIST':data = datasets.FashionMNIST(root='../data', train=True, download=True, transform=transform_mnist)elif args.dataset == 'CIFAR10':data = datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_cifar)else:raise ValueError(f'Unkown dataset: {args.dataset}, support dataset: MNIST | FashionMNIST | CIFAR10')return DataLoader(dataset=data, batch_size=args.batch_size, num_workers=4, shuffle=True)def train(args: Namespace, G: Generator, D: Discriminator, data_loader: DataLoader) -> None:"""Train Generator and Discriminator.Args:args(Namespace): arguments.G(Generator): Generator in GAN.D(Discriminator): Discriminator in GAN."""writer = SummaryWriter(os.path.join(args.logdir, args.dataset))# generate fixed noise for sampling.fixed_noise = torch.rand(64, args.latent_dim).to(device)# Loss and optimizer.criterion = nn.BCELoss().to(device)optimizer_G = torch.optim.Adam(G.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))optimizer_D = torch.optim.Adam(D.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))# Start training.for epoch in range(args.num_epochs):total_d_loss = total_g_loss = 0for images, _ in data_loader:m = images.size(0)images: Tensor = images.to(device)images = images.view(m, -1)# Create real and fake labels.real_labels = torch.ones(m, 1).to(device)fake_labels = torch.zeros(m, 1).to(device)# ================================================================== ##                      Train the discriminator                       ## ================================================================== ## Forward passoutputs = D(images)d_loss_real: Tensor = criterion(outputs, real_labels)z = torch.rand(m, args.latent_dim).to(device)fake_images: Tensor = G(z).detach()outputs = D(fake_images)d_loss_fake: Tensor = criterion(outputs, fake_labels)# Backward passd_loss: Tensor = d_loss_real + d_loss_fakeoptimizer_D.zero_grad()d_loss.backward()optimizer_D.step()total_d_loss += d_loss# ================================================================== ##                        Train the generator                         ## ================================================================== ## Forward passz = torch.rand(images.size(0), args.latent_dim).to(device)fake_images: Tensor = G(z)outputs = D(fake_images)# Backward passg_loss: Tensor = criterion(outputs, real_labels)optimizer_G.zero_grad()g_loss.backward()optimizer_G.step()total_g_loss += g_lossprint(f'''
=====================================
Epoch: [{epoch + 1}/{args.num_epochs}]
Discriminator Loss: {total_d_loss / len(data_loader):.4f}
Generator Loss: {total_g_loss / len(data_loader):.4f}
=====================================''')# Log Discriminator and Generator loss.writer.add_scalar('Discriminator Loss', total_d_loss / len(data_loader), epoch + 1)writer.add_scalar('Generator Loss', total_g_loss / len(data_loader), epoch + 1)fake_images: Tensor = G(fixed_noise)img_grid = make_grid(denormalize(fake_images), nrow=8, padding=2)writer.add_image('Fake Images', img_grid, epoch + 1)if (epoch + 1) % args.interval == 0:save_image(img_grid, os.path.join(args.sample_dir, args.dataset, f'fake_images_{epoch + 1}.png'))# Save the model checkpoints.torch.save(G.state_dict(), os.path.join(args.ckpt_dir, args.dataset, 'G.ckpt'))torch.save(D.state_dict(), os.path.join(args.ckpt_dir, args.dataset, 'D.ckpt'))def main() -> None:args = get_args()setup(args)image_shape = (1, 28, 28) if args.dataset in ('MNIST', 'FashionMNIST') else (3, 32, 32)data_loader = get_data_loader(args)# Generator and Discrminator.G = Generator(image_shape=image_shape, latent_dim=args.latent_dim).to(device)D = Discriminator(image_shape=image_shape).to(device)train(args, G, D, data_loader)if __name__ == '__main__':main()

参考

[1] I. Goodfellow et al., “Generative Adversarial Nets,” in Advances in Neural Information Processing Systems, Curran Associates, Inc., 2014. Accessed: Sep. 12, 2024. [Online]. Available: https://papers.nips.cc/paper_files/paper/2014/hash/5ca3e9b122f61f8f06494c97b1afccf3-Abstract.html

[2] eriklindernoren. “PyTorch-GAN”. Github 2018. [Online]. Available: https://github.com/eriklindernoren/PyTorch-GAN

[3] 李沐. “GAN论文逐段精读【论文精读】”. Bilibili 2021. [Online]. Available: https://www.bilibili.com/video/BV1rb4y187vD/?spm_id_from=333.1387.collection.video_card.click&vd_source=c8a32a5a667964d5f1068d38d6182813
n. “PyTorch-GAN”. Github 2018. [Online]. Available: https://github.com/eriklindernoren/PyTorch-GAN


http://www.ppmy.cn/news/1583106.html

相关文章

HarmonyOS Next~鸿蒙AI功能开发:Core Speech Kit与Core Vision Kit的技术解析与实践

HarmonyOS Next~鸿蒙AI功能开发:Core Speech Kit与Core Vision Kit的技术解析与实践 一、鸿蒙AI功能开发的生态定位与核心能力 在鸿蒙操作系统(HarmonyOS)的生态布局中,AI功能开发是提升用户体验与设备智能化的核心方…

debian12运行sql server2022(docker):导入.MDF .LDF文件到容器

过程大纲 docker run在基础配置之上增加挂载信息 修改文件权限,确保所有用户有rw权限 进入docker交互命令行 登录数据库 执行数据库EXE命令导入数据库文件数据 docker run在基础配置之上增加挂载信息 docker run -d \-v /home/ying/Downloads/StuXk:/var/opt/mssql…

关于 Redis 缓存一致

为了提升系统性能,常常会引入 Redis 作为缓存。数据通常会存储在持久化的数据源(如 MySQL 数据库)中,同时在 Redis 中保存一份副本。当数据源中的数据发生变化时,如果不能及时同步到 Redis 缓存,或者缓存中…

Netty源码—3.Reactor线程模型四

大纲 5.NioEventLoop的执行总体框架 6.Reactor线程执行一次事件轮询 7.Reactor线程处理产生IO事件的Channel 8.Reactor线程处理任务队列之添加任务 9.Reactor线程处理任务队列之执行任务 10.NioEventLoop总结 8.Reactor线程处理任务队列之添加任务 (1)Reactor线程执行一…

带你了解Java无锁并发CAS

带你了解Java无锁并发CAS 在多核处理器时代,并发编程已成为提升系统性能的核心手段。传统的同步机制(如synchronized和ReentrantLock)通过互斥锁实现线程安全,但其存在以下关键问题: 性能损耗:线程阻塞/唤…

常见中间件漏洞攻略-Tomcat篇

一、 CVE-2017-12615-Tomcat put方法任意文件写入漏洞 第一步:开启靶场 第二步:在首页抓取数据包,并发送到重放器 第三步:先上传尝试一个1.txt进行测试 第四步:上传后门程序 第五步:使用哥斯拉连接 二、后…

《Python实战进阶》No26: CI/CD 流水线:GitHub Actions 与 Jenkins 集成

No26: CI/CD 流水线:GitHub Actions 与 Jenkins 集成 摘要 持续集成(CI)和持续部署(CD)是现代软件开发中不可或缺的实践,能够显著提升开发效率、减少错误并加速交付流程。本文将探讨如何利用 GitHub Actio…

1 存储过程学习: 使用DMSQL程序的优点

DMSQL程序具有以下优点: 与SQL语言的完美结合 SQL语言已成为数据库的标准语言,DMSQL程序支持所有SQL数据类型和所有SQL函数,同时支持所有DM对象类型。在DMSQL程序中可以使用SELECT、INSERT、DELETE、UPDATE数据操作语句,事务控制…