[Day 44] 區塊鏈與人工智能的聯動應用:理論、技術與實踐

ops/2024/9/22 15:46:49/

生成对抗网络(Generative Adversarial Networks,GANs)是一种由Ian Goodfellow等人在2014年提出的深度学习模型,广泛用于图像生成、图像超分辨率、图像修复等领域。GAN由一个生成器(Generator)和一个判别器(Discriminator)组成,二者通过对抗训练相互提升性能。以下是关于GAN的详细介绍和代码实现示例。

一、生成对抗网络的原理

1.1 生成器(Generator)

生成器的目标是生成逼真的样本,使得判别器无法区分生成样本和真实样本。生成器接收一个随机噪声向量(通常为高斯分布或均匀分布),通过一系列的神经网络层转换成逼真的数据样本。

1.2 判别器(Discriminator)

判别器的目标是将真实样本和生成样本区分开来。判别器是一个二分类模型,输入为样本数据,输出为分类概率,表示输入样本是“真实”还是“生成”的概率。

1.3 对抗训练

生成器和判别器通过对抗训练来提升彼此的能力。生成器试图欺骗判别器,而判别器不断提升自己的判别能力。二者的目标函数如下:

  • 生成器的损失函数:使得生成样本被判别器判断为真实样本的概率最大。
  • 判别器的损失函数:最大化判别真实样本和生成样本的能力。

具体的数学表达式如下:

min_{G} max_{D}V(D,G)E_{x\sim p_{data}(x)}[logD(x)]+E_{x\sim p_{z}(z)}[log(1 - D(G((z)))]

二、代码实现

我们将以MNIST数据集为例,使用Keras实现一个简单的GAN模型。

2.1 导入必要的库

import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import mnist
from keras.models import Sequential, Model
from keras.layers import Dense, LeakyReLU, BatchNormalization, Reshape, Flatten, Input
from keras.optimizers import Adam

2.2 数据预处理

加载并预处理MNIST数据集,使其适用于GAN的输入。

# 加载MNIST数据集
(X_train, _), (_, _) = mnist.load_data()# 归一化并reshape数据
X_train = X_train / 127.5 - 1.0
X_train = np.expand_dims(X_train, axis=3)# 输入维度
img_shape = X_train.shape[1:]
z_dim = 100  # 噪声向量维度

2.3 构建生成器

生成器将噪声向量转换为逼真的图像。我们使用全连接层和转置卷积层实现这一过程。

def build_generator(z_dim):model = Sequential()model.add(Dense(256, input_dim=z_dim))model.add(LeakyReLU(alpha=0.01))model.add(BatchNormalization(momentum=0.8))model.add(Dense(512))model.add(LeakyReLU(alpha=0.01))model.add(BatchNormalization(momentum=0.8))model.add(Dense(1024))model.add(LeakyReLU(alpha=0.01))model.add(BatchNormalization(momentum=0.8))model.add(Dense(np.prod(img_shape), activation='tanh'))model.add(Reshape(img_shape))return modelgenerator = build_generator(z_dim)
generator.summary()

2.4 构建判别器

判别器将输入图像分类为真实或生成的。我们使用卷积层和全连接层实现这一过程。

def build_discriminator(img_shape):model = Sequential()model.add(Flatten(input_shape=img_shape))model.add(Dense(512))model.add(LeakyReLU(alpha=0.01))model.add(Dense(256))model.add(LeakyReLU(alpha=0.01))model.add(Dense(1, activation='sigmoid'))return modeldiscriminator = build_discriminator(img_shape)
discriminator.summary()

2.5 编译模型

我们为生成器和判别器选择优化器,并编译判别器。

# 优化器
optimizer = Adam(0.0002, 0.5)# 编译判别器
discriminator.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])

2.6 构建GAN模型

我们将生成器和判别器结合起来,构建完整的GAN模型,并编译生成器。

# 构建生成器
z = Input(shape=(z_dim,))
img = generator(z)# 将判别器设置为不可训练,仅训练生成器
discriminator.trainable = False# 判别器预测生成图像
validity = discriminator(img)# 构建GAN模型
gan = Model(z, validity)
gan.compile(loss='binary_crossentropy', optimizer=optimizer)gan.summary()

2.7 训练模型

我们定义训练过程,包括生成器和判别器的训练步骤。

def train(epochs, batch_size=128, sample_interval=100):# 加载数据(X_train, _), (_, _) = mnist.load_data()X_train = X_train / 127.5 - 1.0X_train = np.expand_dims(X_train, axis=3)# 真实标签real = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))for epoch in range(epochs):# ---------------------# 训练判别器# ---------------------# 随机选择真实图像idx = np.random.randint(0, X_train.shape[0], batch_size)imgs = X_train[idx]# 生成噪声并生成假图像z = np.random.normal(0, 1, (batch_size, z_dim))gen_imgs = generator.predict(z)# 训练判别器d_loss_real = discriminator.train_on_batch(imgs, real)d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# ---------------------# 训练生成器# ---------------------z = np.random.normal(0, 1, (batch_size, z_dim))g_loss = gan.train_on_batch(z, real)# 打印进度print(f"{epoch} [D loss: {d_loss[0]} | D accuracy: {100 * d_loss[1]}] [G loss: {g_loss}]")# 每隔sample_interval保存生成的图像样本if epoch % sample_interval == 0:sample_images(epoch)def sample_images(epoch, image_grid_rows=4, image_grid_columns=4):z = np.random.normal(0, 1, (image_grid_rows * image_grid_columns, z_dim))gen_imgs = generator.predict(z)gen_imgs = 0.5 * gen_imgs + 0.5fig, axs = plt.subplots(image_grid_rows, image_grid_columns, figsize=(4, 4), sharey=True, sharex=True)cnt = 0for i in range(image_grid_rows):for j in range(image_grid_columns):axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')axs[i, j].axis('off')cnt += 1plt.show()

2.8 开始训练

我们设置训练参数并开始训练GAN模型。

epochs = 10000
batch_size = 64
sample_interval = 1000train(epochs, batch_size, sample_interval)

2.9 详细解释代码

导入库

我们导入了Keras和其他必要的库,用于构建和训练我们的GAN模型。

import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import mnist
from keras.models import Sequential, Model
from keras.layers import Dense, LeakyReLU, BatchNormalization, Reshape, Flatten, Input
from keras.optimizers import Adam
数据预处理

我们加载MNIST数据集,并对图像进行归一化处理,将其范围调整到[-1, 1],以便于GAN的训练。

(X_train, _), (_, _) = mnist.load_data()
X_train = X_train / 127.5 - 1.0
X_train = np.expand_dims(X_train, axis=3)
img_shape = X_train.shape[1:]
z_dim = 100
构建生成器

生成器将噪声向量转换为逼真的图像。我们使用了全连接层、LeakyReLU激活函数和批归一化层来实现这一过程。

def build_generator(z_dim):model = Sequential()model.add(Dense(256, input_dim=z_dim))model.add(LeakyReLU(alpha=0.01))model.add(BatchNormalization(momentum=0.8))model.add(Dense(512))model.add(LeakyReLU(alpha=0.01))model.add(BatchNormalization(momentum=0.8))model.add(Dense(1024))model.add(LeakyReLU(alpha=0.01))model.add(BatchNormalization(momentum=0.8))model.add(Dense(np.prod(img_shape), activation='tanh'))model.add(Reshape(img_shape))return model
构建判别器

判别器将输入图像分类为真实或生成的。我们使用了卷积层、LeakyReLU激活函数和全连接层来实现这一过程。

def build_discriminator(img_shape):model = Sequential()model.add(Flatten(input_shape=img_shape))model.add(Dense(512))model.add(LeakyReLU(alpha=0.01))model.add(Dense(256))model.add(LeakyReLU(alpha=0.01))model.add(Dense(1, activation='sigmoid'))return model
编译模型

我们为生成器和判别器选择优化器,并编译判别器。

optimizer = Adam(0.0002, 0.5)
discriminator.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])
构建GAN模型

我们将生成器和判别器结合起来,构建完整的GAN模型,并编译生成器。

z = Input(shape=(z_dim,))
img = generator(z)
discriminator.trainable = False
validity = discriminator(img)
gan = Model(z, validity)
gan.compile(loss='binary_crossentropy', optimizer=optimizer)
训练模型

我们定义训练过程,包括生成器和判别器的训练步骤。

def train(epochs, batch_size=128, sample_interval=100):(X_train, _), (_, _) = mnist.load_data()X_train = X_train / 127.5 - 1.0X_train = np.expand_dims(X_train, axis=3)real = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))for epoch in range(epochs):idx = np.random.randint(0, X_train.shape[0], batch_size)imgs = X_train[idx]z = np.random.normal(0, 1, (batch_size, z_dim))gen_imgs = generator.predict(z)d_loss_real = discriminator.train_on_batch(imgs, real)d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)z = np.random.normal(0, 1, (batch_size, z_dim))g_loss = gan.train_on_batch(z, real)print(f"{epoch} [D loss: {d_loss[0]} | D accuracy: {100 * d_loss[1]}] [G loss: {g_loss}]")if epoch % sample_interval == 0:sample_images(epoch)def sample_images(epoch, image_grid_rows=4, image_grid_columns=4):z = np.random.normal(0, 1, (image_grid_rows * image_grid_columns, z_dim))gen_imgs = generator.predict(z)gen_imgs = 0.5 * gen_imgs + 0.5fig, axs = plt.subplots(image_grid_rows, image_grid_columns, figsize=(4, 4), sharey=True, sharex=True)cnt = 0for i in range(image_grid_rows):for j in range(image_grid_columns):axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')axs[i, j].axis('off')cnt += 1plt.show()
开始训练

我们设置训练参数并开始训练GAN模型。

epochs = 10000
batch_size = 64
sample_interval = 1000train(epochs, batch_size, sample_interval)

三、总结

通过以上代码和详细解释,我们实现了一个简单的生成对抗网络模型,并通过训练使生成器能够生成逼真的MNIST手写数字图像。GANs在许多领域有着广泛的应用,本文只是一个起步,读者可以进一步探索其在图像超分辨率、图像修复、文本生成等方面的应用。


http://www.ppmy.cn/ops/89680.html

相关文章

举例说明计算机视觉(CV)技术的优势和挑战

计算机视觉(CV)技术是一种利用计算机算法和技术来解析和理解图像和视频数据的领域。它的优势和挑战如下: 优势: 高速处理:计算机视觉可以快速处理大量的图像和视频数据,使得它在实时应用中非常有用。例如&…

认识VO、DTO、Entity

关于VO、DTO、Entity 概念 VO(View Object):视图对象,专门用于前端展示层,专注于表示某个具体的值或对象的对象,包含业务逻辑;VO的作用是将一组数据以适合特定用户界面(UI&#xf…

大数据应用【大数据导论】

各位大佬好 ,这里是阿川的博客,祝您变得更强 个人主页:在线OJ的阿川 大佬的支持和鼓励,将是我成长路上最大的动力 阿川水平有限,如有错误,欢迎大佬指正 目录 大数据在许多领域应用互联网领域应用生物医学…

HarmonyOS NEXT——奇妙的调用方式

注解调用一句话总结Extend抽取特定组件样式、事件,可以传递参数Style抽取公共样式、事件,不可以传递参数Builder抽取结构、样式、事件,可以传递参数BuilderParams自定义组件中传递UI组件多个BuilderParams自定义组件中传递多个UI组件 Extend…

网络安全大模型开源项目有哪些?

01 Ret2GPT 它是面向CTF二进制安全的工具,结合ChatGPT API、Retdec和Langchain进行漏洞挖掘,它能通过问答或预设Prompt对二进制文件进行分析。 https://github.com/DDizzzy79/Ret2GPT 02 OpenAI Codex 它是基于GPT-3.5-turbo模型,用于编写…

【Python系列】Python 协程:并发编程的新篇章

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

如何在 Python 中测试文件修改

在我日常编程中,如果想在Python中测试文件的修改,我这里总结出有多种方式。其中使用 os.path.getmtime() 函数可以获取文件的最后修改时间戳,然后可以定期检查文件是否有更新。这种方法适合于轮询检查文件是否修改。这种方法是我最常用的。 问…

分享5款.NET开源免费的Redis客户端组件库

前言 今天大姚给大家分享5款.NET开源、免费的Redis客户端组件库,希望可以帮助到有需要的同学。 StackExchange.Redis StackExchange.Redis是一个基于.NET的高性能Redis客户端,提供了完整的Redis数据库功能支持,并且具有多节点支持、异步编…