【机器学习】生成对抗网络 – Generative Adversarial Networks(GAN)的基本概念和代码示例

embedded/2024/10/21 9:38:30/

引言

生成对抗网络(GANs)是深度学习中的一种创新模型,自2014年由Ian Goodfellow等人首次提出以来,已成为深度学习领域中最活跃的研究方向之一

文章目录

  • 引言
  • 一、GANs的基本概念
    • 1.1 GANs的基本原理和结构
      • 1.1.1 生成器
      • 1.1.2 判别器
      • 1.1.3 大白话版本
    • 1.2 GANs的训练过程
    • 1.3 GAN的优缺点
      • 1.3.1 优点
      • 1.3.2 缺陷
    • 1.4 GANs的变体
      • 1.4.1 Conditional GAN (cGAN)
      • 1.4.2 Deep Convolutional GAN (DCGAN)
      • 1.4.3 Wasserstein GAN (WGAN)
      • 1.4.4 StyleGAN
    • 1.5 GANs的应用领域
    • 1.6 总结
    • 1.7 GAN代码实例
      • 1.7.1 代码
      • 1.7.2 代码解释

一、GANs的基本概念

1.1 GANs的基本原理和结构

在这里插入图片描述

1.1.1 生成器

生成器的任务是生成新的数据样本。它从随机噪声中学习生成与真实数据相似的数据。在训练过程中,生成器和判别器进行对抗训练,生成器不断优化生成的数据样本,以达到欺骗判别器的目的

1.1.2 判别器

判别器的任务则是判断输入数据是否真实,即尝试区分生成的数据和真实数据,判别器则努力提高区分真实与生成数据的能力。这种竞争推动了模型的不断进化,使得生成的数据逐渐接近真实数据分布

1.1.3 大白话版本

  1. 假设一个城市治安混乱,很快,这个城市里就会出现无数的小偷。在这些小偷中,有的可能是盗窃高手,有的可能毫无技术可言。假如这个城市开始整饬其治安,突然开展一场打击犯罪的“运动”,警察们开始恢复城市中的巡逻,很快,一批“学艺不精”的小偷就被捉住了。之所以捉住的是那些没有技术含量的小偷,是因为警察们的技术也不行了,在捉住一批低端小偷后,城市的治安水平变得怎样倒还不好说,但很明显,城市里小偷们的平均水平已经大大提高了

  2. 警察严打导致小偷水平提升
    警察们开始继续训练自己的破案技术,开始抓住那些越来越狡猾的小偷。随着这些职业惯犯们的落网,警察们也练就了特别的本事,他们能很快能从一群人中发现可疑人员,于是上前盘查,并最终逮捕嫌犯;小偷们的日子也不好过了,因为警察们的水平大大提高,如果还想以前那样表现得鬼鬼祟祟,那么很快就会被警察捉住

  3. 经常提升技能,更多小偷被抓
    为了避免被捕,小偷们努力表现得不那么“可疑”,而魔高一尺、道高一丈,警察也在不断提高自己的水平,争取将小偷和无辜的普通群众区分开。随着警察和小偷之间的这种“交流”与“切磋”,小偷们都变得非常谨慎,他们有着极高的偷窃技巧,表现得跟普通群众一模一样,而警察们都练就了“火眼金睛”,一旦发现可疑人员,就能马上发现并及时控制

  4. 最终,我们同时得到了最强的小偷和最强的警察

1.2 GANs的训练过程

在这里插入图片描述

  1. 初始化生成器和判别器的权重
  2. 在一个批次的数据上训练判别器,使其能够区分真实数据和生成数据
  3. 训练生成器,使其生成的假数据能够欺骗判别器,提高判别器的错误率
  4. 重复步骤2和3,直到达到预设的训练轮数或满足一定的性能指标

1.3 GAN的优缺点

1.3.1 优点

  1. 能更好建模数据分布(图像更锐利、清晰)
  2. 理论上,GANs 能训练任何一种生成器网络。其他的框架需要生成器网络有一些特定的函数形式,比如输出层是高斯的
  3. 无需利用马尔科夫链反复采样,无需在学习过程中进行推断,没有复杂的变分下界,避开近似计算棘手的概率的难题

1.3.2 缺陷

  1. 难训练,不稳定。生成器和判别器之间需要很好的同步,但是在实际训练中很容易D收敛,G发散。D/G 的训练需要精心的设计
  2. 模式缺失(Mode Collapse)问题。GANs的学习过程可能出现模式缺失,生成器开始退化,总是生成同样的样本点,无法继续学习

1.4 GANs的变体

为了改善其训练稳定性、提高生成质量、扩展应用范围等,研究人员提出了许多GAN的变体

1.4.1 Conditional GAN (cGAN)

引入条件变量,生成特定类别的样本

1.4.2 Deep Convolutional GAN (DCGAN)

使用卷积层和反卷积层,提高图像生成的质量和稳定性

1.4.3 Wasserstein GAN (WGAN)

改变损失函数,使用Wasserstein距离,改善训练稳定性和模式覆盖率

1.4.4 StyleGAN

引入风格分离的概念,控制生成图像的局部属性

1.5 GANs的应用领域

在这里插入图片描述

GANs的应用非常广泛,包括但不限于:

  • 图像生成:风格迁移、超分辨率、人脸生成等
  • 数据增强:生成额外的样本以增强训练集
  • 医学图像分析:生成医学图像以辅助诊断
  • 声音合成:生成或修改语音信号
  • 化学分子设计:设计新的分子结构,加速药物研发和材料设计过程

1.6 总结

在实践中,训练GANs需要注意权重的初始化、优化器的选择等因素,以确保训练的稳定性和效果

1.7 GAN代码实例

1.7.1 代码

python">#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2024/8/15 16:34
# @Software: PyCharm
# @Author  : xialiwei
# @Email   : xxxxlw198031@163.com
import timeimport numpy as np
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Reshape
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import LeakyReLU
import matplotlib.pyplot as plt# 加载MNIST数据集
(x_train, _), (_, _) = mnist.load_data()# 数据预处理
x_train = (x_train.astype(np.float32) - 127.5) / 127.5
x_train = np.expand_dims(x_train, axis=3)# 生成器模型
def build_generator(z_dim):model = Sequential()model.add(Dense(128, input_dim=z_dim))model.add(LeakyReLU(alpha=0.01))model.add(Dense(28 * 28 * 1, activation='tanh'))model.add(Reshape((28, 28, 1)))return model# 判别器模型
def build_discriminator(img_shape):model = Sequential()model.add(Flatten(input_shape=img_shape))model.add(Dense(128))model.add(LeakyReLU(alpha=0.01))model.add(Dense(1, activation='sigmoid'))return model# GAN模型
def build_gan(generator, discriminator):model = Sequential()model.add(generator)model.add(discriminator)return model# 设置参数
z_dim = 100
img_shape = (28, 28, 1)# 创建生成器和判别器
generator = build_generator(z_dim)
discriminator = build_discriminator(img_shape)
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(), metrics=['accuracy'])# 对生成器进行编译,但在训练GAN模型时不训练判别器部分
discriminator.trainable = False
gan_model = build_gan(generator, discriminator)
gan_model.compile(loss='binary_crossentropy', optimizer=Adam())# 训练GAN
epochs = 100
batch_size = 256
sample_interval = 50# 训练过程中真样本标签为1,假样本标签为0
real = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))for epoch in range(epochs):# 训练判别器idx = np.random.randint(0, x_train.shape[0], batch_size)real_imgs = x_train[idx]z = np.random.normal(0, 1, (batch_size, z_dim))fake_imgs = generator.predict(z)d_loss_real = discriminator.train_on_batch(real_imgs, real)d_loss_fake = discriminator.train_on_batch(fake_imgs, fake)d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# 训练生成器z = np.random.normal(0, 1, (batch_size, z_dim))g_loss = gan_model.train_on_batch(z, real)if epoch % sample_interval == 0:print(f"{epoch} [D loss: {d_loss[0]}, acc.: {100 * d_loss[1]}] [G loss: {g_loss}]")# 这里可以添加代码来保存或显示生成的图片样本# 显示生成的图片样本z = np.random.normal(0, 1, (25, z_dim))  # 生成25个样本gen_imgs = generator.predict(z)# 将生成的图片数据转换为0-1范围gen_imgs = 0.5 * gen_imgs + 0.5# 绘制生成的图片fig, axs = plt.subplots(5, 5)cnt = 0for i in range(5):for j in range(5):axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')axs[i, j].axis('off')cnt += 1plt.show()  # 显示图形plt.pause(3)  # 暂停3秒plt.close(fig)  # 关闭图形,防止阻塞
# 保存模型
generator.save('generator.h5')
discriminator.save('discriminator.h5')

输出结果:
在这里插入图片描述

1.7.2 代码解释

这段代码是一个生成对抗网络(GAN)的完整实现,用于生成类似于MNIST数据集中的手写数字图像


http://www.ppmy.cn/embedded/97518.html

相关文章

ArkTs基础语法-声明式UI-自定义组件

这里写目录标题 自定义组件创建自定义组件自定义组件的基本用法自定义组件的基本结构成员函数/变量自定义组件的参数规定build()函数 自定义组件 创建自定义组件 自定义组件的基本用法 Component export struct MyComponent {State message: string 我是自定义组件;build()…

学习文件IO,让你从操作系统内核的角度去理解输入和输出(理论篇)

本篇会加入个人的所谓鱼式疯言 ❤️❤️❤️鱼式疯言:❤️❤️❤️此疯言非彼疯言 而是理解过并总结出来通俗易懂的大白话, 小编会尽可能的在每个概念后插入鱼式疯言,帮助大家理解的. 🤭🤭🤭可能说的不是那么严谨.但小编初心是能让更多人…

qt-16可扩展对话框--隐藏和展现

可扩展对话框 知识点extension.hextension.cppmain.cpp运行图初始化隐藏展现--点击--详细按钮 知识点 MainLayout->setSizeConstraint(QLayout::SetFixedSize);//固定窗口大小 extension.h #ifndef EXTENSION_H #define EXTENSION_H#include <QDialog>class Extens…

java jvm默认用的垃圾回收器

Java JVM&#xff08;Java虚拟机&#xff09;的默认垃圾回收器取决于所使用的Java版本。以下是不同Java版本下JVM默认垃圾回收器的概述&#xff1a; JDK 8及之前版本 默认垃圾回收器&#xff1a;Parallel Scavenge&#xff08;新生代&#xff09; Serial Old&#xff08;老年…

超详细!!!electron-vite-vue开发桌面应用之引入UI组件库element-plus(四)

云风网 云风笔记 云风知识库 一、安装element-plus以及图标库依赖 npm install element-plus --save npm install element-plus/icons-vue npm i -D unplugin-icons二、vite按需引入插件 npm install -D unplugin-vue-components unplugin-auto-importunplugin-vue-componen…

仿RabbitMq实现简易消息队列正式篇(路由匹配篇)

TOC 目录 路由匹配模块 代码展示 路由匹配模块 决定了一条消息是否能够发布到指定的队列 在每个队列根交换机的绑定信息中&#xff0c;都有一个binding_key&#xff08;在虚拟机篇有说到&#xff09;这是队列发布的匹配规则 在每条要发布的消息中&#xff0c;都有一个rout…

2024.8完善版 NineAi-ChatGPT系统源码

Nine AI.ChatGPT是基于ChatGPT开发的一个人工智能技术驱动的自然语言处理工具&#xff0c;它能够通过学习和理解人类的语言来进行对话&#xff0c;还能根据聊天的上下文进行互动&#xff0c;真正像人类一样来聊天交流&#xff0c;甚至能完成撰写邮件、视频脚本、文案、翻译、代…

老师简单易用的分班查询神器

随着新学期的到来&#xff0c;学校和老师们面临着一项重要任务&#xff1a;确保每位家长都能及时准确地获取孩子的分班信息。传统的分班通知方式&#xff0c;如纸质通知、校园公告板或电话短信通知&#xff0c;不仅效率低下&#xff0c;而且容易出错。为了解决这一问题&#xf…