G2 基于生成对抗网络(GAN)人脸图像生成

news/2024/11/2 19:24:38/
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

基于生成对抗网络(GAN)人脸图像生成

这周将构建并训练一个生成对抗网络(GAN)来生成人脸图像。

GAN 原理概述

生成对抗网络通过两个神经网络的对抗性结构来实现目标:

  • 生成器(G):输入随机噪声,通过学习数据的分布模式生成类似真实图像的输出。
  • 判别器(D):用来判断输入的图像是真实的还是生成器生成的。

训练过程中,生成器尝试欺骗判别器,生成逼真的图像,而判别器则不断优化,以区分真实图像与生成图像。这种对抗过程最终使生成器的生成能力逐渐逼近真实图像。

环境准备

首先导入相关库并设置随机种子以确保结果的可复现性。

python">import random
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import numpy as np

超参数设置

在训练GAN之前,首先定义一些关键的超参数:

  • batch_size:每个批次的样本数。
  • image_size:图像的大小,用于调整输入数据的尺寸。
  • nz:潜在向量大小,即生成器的输入维度。
  • ngfndf:分别控制生成器和判别器中的特征图数量。
  • num_epochs:训练的总轮数。
  • lr:学习率。
python">batch_size = 128
image_size = 64
nz = 100
ngf = 64
ndf = 64
num_epochs = 50
lr = 0.0002
beta1 = 0.5

数据加载

通过torchvision.datasets.ImageFolder加载数据,并使用 torch.utils.data.DataLoader 进行批量处理。数据加载时,通过转换函数调整图像大小,并对其进行归一化处理。

python">dataroot = "data/GANdata"
dataset = dset.ImageFolder(root=dataroot,transform=transforms.Compose([transforms.Resize(image_size),transforms.CenterCrop(image_size),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

网络结构定义

1. 生成器

生成器将随机噪声(潜在向量)通过一系列转置卷积层转换为图像。每层使用ReLU激活函数,最后一层用Tanh激活函数,将输出限制在 [-1, 1]

python">class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.main = nn.Sequential(nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),nn.BatchNorm2d(ngf * 8),nn.ReLU(True),nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 4),nn.ReLU(True),nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 2),nn.ReLU(True),nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf),nn.ReLU(True),nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False),nn.Tanh())def forward(self, input):return self.main(input)

2. 判别器

判别器为卷积网络,通过一系列卷积层提取图像特征。每层使用LeakyReLU激活函数,最终输出一个值(真实为1,生成为0)。

python">class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.main = nn.Sequential(nn.Conv2d(3, ndf, 4, 2, 1, bias=False),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 2),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 4),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 8),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),nn.Sigmoid())def forward(self, input):return self.main(input)

训练过程

训练分为两个部分:判别器和生成器的更新。

1. 判别器的训练

判别器首先接收真实图像样本,计算输出与真实标签的误差。然后判别器接收生成器生成的假图像,再计算输出与假标签的误差。最终判别器的损失是两者的总和。

python">output = netD(real_cpu).view(-1)
errD_real = criterion(output, label)
errD_real.backward()fake = netG(noise)
output = netD(fake.detach()).view(-1)
errD_fake = criterion(output, label.fill_(fake_label))
errD_fake.backward()

2. 生成器的训练

生成器的目标是欺骗判别器,因此其损失函数基于判别器将生成图像误识为真实的概率值。

python">output = netD(fake).view(-1)
errG = criterion(output, label.fill_(real_label))
errG.backward()

训练监控与可视化

在这里插入图片描述

训练时,我们记录生成器和判别器的损失,并生成一些样本图像来查看生成器的效果。

python">plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G")
plt.plot(D_losses, label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.savefig('Generator and Discriminator Loss During Training.png')

在这里插入图片描述

结果可视化

训练结束后,我们将真实图像与生成图像对比,以检验生成器的效果。

python">plt.figure(figsize=(15, 15))
plt.subplot(1, 2, 1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(), (1, 2, 0)))plt.subplot(1, 2, 2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))
plt.savefig('Fake Images.png')
plt.show()

在这里插入图片描述

总结

这周学习构建了一个深度卷积生成对抗网络(DCGAN),用于生成逼真的人脸图像,通过这周学习对对抗网路的构建有了更深的了解与运用


http://www.ppmy.cn/news/1543950.html

相关文章

解决Java接口接受附件入参失败的问题

接口的入参实体类存在 MultipartFile 类型属性,接口入参注解为RequestBody,会报错。需要把入参注解改为RequestPart,或者去掉注解。 RequestBody和RequestPart的异同: 相同点: 都可以用实体类接收传参 不同点&#xff…

分布式锁(redisson,看门狗,主从一致性)

目录 分布式锁一:基本原理和实现方式二:分布式锁的实现1:分布式锁的误删问题2:解决误删问题 三:lua脚本解决多条命令原子性问题调用lua脚本 四:Redisson1:redisson入门2:redisson可重…

.bixi勒索病毒来袭:如何防止文件加密与数据丢失?

导言 在网络威胁剧烈的今天,勒索病毒已成为企业和个人面临的重大安全挑战,其中虫洞勒索病毒习得高强度的加密手段和急剧传播的特性引起关注。一旦感染,就会加密关键数据并索要赎金,导致数据无法访问并带来巨大的财务损失。更为严…

春之竞赛:Spring Boot大学生竞赛管理平台

1系统概述 1.1 研究背景 随着计算机技术的发展以及计算机网络的逐渐普及,互联网成为人们查找信息的重要场所,二十一世纪是信息的时代,所以信息的管理显得特别重要。因此,使用计算机来管理大学生竞赛管理系统的相关信息成为必然。开…

openEuler下配置openGauss环境图解

一、在openEuler中创建用户,并授予权限 # 创建用户 sudo adduser omm# 授予权限 chown omm /opt# 切换用户 su - omm 二、在openGauss官网找到openGauss极简版的软件包 openGauss软件 | openGauss下载 | openGauss软件包 | openGauss社区 右键立即下载&#xff0…

VSCode中安装RN相关的插件(一)

一、插件介绍 ES7 React / Redux / React-Native snippets v4.4.3 如图,在VSCode的插件商店里搜索该插件进行安装。 安装完插件以后,就可以通过简写指令生成一些代码片段了。具体有哪些简短指令,大家可以自己查看,这里举两个示例…

独立北斗定位智能安全帽、定位安全帽、单北斗执法记录仪

AIoT万物智联,智能安全帽生产厂家,执法记录仪生产厂家,单北斗定位智能安全帽、智能头盔、头盔记录仪、执法记录仪、智能视频分析/边缘计算AI盒子、车载视频监控/车载DVR/NVR、布控球、智能眼镜、智能手电、智能电子工牌、无人机4G补传系统等统…

介绍目标检测中mAP50和mAP50-95的区别

在目标检测任务中,mAP(mean Average Precision)是一个常用的性能评估指标,用于衡量模型在不同类别和不同IoU(Intersection over Union)阈值下的平均精度。mAP50和mAP50-95是mAP的两个特定版本,它…