【PyTorch】12 生成对抗网络实战——用GAN生成动漫头像

news/2024/11/8 14:48:32/

GAN 生成动漫头像

  • 1. 获取数据
  • 2. 用GAN生成
    • 2.1 Generator
    • 2.2 Discriminator
    • 2.3 其它细节
    • 2.4 训练思路
  • 3. 全部代码
  • 4. 结果展示与分析
  • 小结

深度卷积生成对抗网络(DCGAN):Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks

1. 获取数据

原来书里的下载链接即原来知乎的何之源分享链接失效了,一篇简书里有下载地址,一共是51223张图片,尺寸是96×96×3,总大小272 MB

利用python查看图片的大小:

法I:

from PIL import Image
image = Image.open(dir + path[0])
imgSize = image.size  #大小
print(imgSize)
(96, 96)

法II:

import cv2
img = cv2.imread(dir + path[0])
sp = img.shape
print(sp)
(96, 96, 3)

2. 用GAN生成

2.1 Generator

CONVTRANSPOSE2D

torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros')

官方文档,这是逆卷积,关于卷积可看之前的CNN猫狗二分类

在由多个输入平面组成的输入图像上应用二维转置卷积算子

这个模块可以看作是Conv2d相对于其输入的梯度。它也被称为分式卷积或解卷积(虽然它不是实际的解卷操作)

  • stride控制交叉相关的步幅
  • padding控制两边隐含的零填充量,以便进行dilation * (kernel_size - 1) - padding。详情请看下面的说明
  • output_padding 控制添加到输出形状一侧的额外尺寸。详情请看下面的说明
  • dilation控制核点之间的间距,也就是所谓的à trous算法。它比较难描述,但这个链接有一个很好的可视化的扩张作用
  • groups控制输入和输出之间的连接,in_channels和out_channels都必须被分组所除

对于本实验:

  • 输入维度:noiseSize × 1 × 1
  • kernel_size=4, stride=1, padding=0,第一次变化后:(n_generator_feature * 8) × 4 × 4
  • 当kernel_size=4, stride=2, padding=1时,输入的宽高刚好是第一次的两倍,第二次变化后:(n_generator_feature * 4) × 8 × 8
  • 第三次变化后:(n_generator_feature * 2) × 16 × 16
  • 第四次变化后:n_generator_feature × 32 × 32
  • 最后一层采用kernel_size=5, stride=3, padding=1,为了将32 × 32变为96 × 96
  • 最后用Tanh将输出图片的像素归一化到-1~1,如果希望归一化到0~1,需要使用Sigmoid
Generator = NetGenerator()
x = torch.rand(1, noiseSize, 1, 1)
y = Generator(x)
print(y.size())
torch.Size([1, 3, 96, 96])

2.2 Discriminator

LeakyReLU

官方手册,inplace=True表示进行覆盖运算

Discriminator = NetDiscriminator()
x = torch.rand(1, 3, 96, 96)
y = Discriminator(x)
print(y.size())
torch.Size([1])

可以看出判别器和生成器的额网络几乎是对称的,从卷积核大小到padding、stride等设置,需要注意的是生成器的激活函数是ReLU,而判别器使用的是LeakyLeRU,二者并无本质区别,这里的选择更多的是经验总结。每一个样本经过判别器后,输出一个0~1的数,表示的是这个样本是真图片的概率

2.3 其它细节

torch.utils.data.DataLoader官方文档

torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)中文文档
betas (Tuple[float, float], 可选) – 平滑常数:用于计算梯度以及梯度平方的运行平均值的系数(默认:0.9,0.999)

torch.nn.BCELoss():计算target 和output 间的二值交叉熵(Binary Cross Entropy)官方文档,计算公式可见此

关于(tqdm.tqdm)可见官方文档

2.4 训练思路

  • 训练判别器
    • 对于真图片,输出尽可能是1
    • 对于假图片,输出尽可能是0
  • 训练生成器
    • 对于假图片,输出尽可能是1

这里需要注意以下几点。

  • 训练生成器时,无须调整判别器的参数;训练判别器时,无须调整生成器的参数。
  • 在训练判别器时,需要对生成器生成的图片用detach操作进行计算图截断,避免反向传播将梯度传到生成器中。因为在训练判别器时我们不需要训练生成器,也就不需要生成器的梯度。
  • 在训练判别器时,需要反向传播两次,一次是希望把真图片判为1,一次是希望把假图片判为0。也可以将这两者的数据放到一个batch中,进行一次前向传播和一次反向传播即可。但是人们发现,在一个batch中只包含真图片或只包含假图片的做法最好。
  • 对于假图片,在训练判别器时,我们希望它输出0;而在训练生成器时,我们希望它输出1.因此可以看到一对看似矛盾的代码 error_d_fake = criterion(output, fake_labels)和error_g = criterion(output, true_labels)。其实这也很好理解,判别器希望能够把假图片判别为fake_label,而生成器则希望能把他判别为true_label,判别器和生成器互相对抗提升。

接下来就是一些可视化的代码。每次可视化使用的噪声都是固定的fix_noises,因为这样便于我们比较对于相同的输入,生成器生成的图片是如何一步步提升的。另外,由于我们对输入的图片进行了归一化处理(-1~1),在可视化时则需要将它还原成原来的scale(0~1)

3. 全部代码

# import os
import torch
import torch.nn as nn
import torchvision as tv
from torch.autograd import Variable
import tqdm
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # 显示中文标签
plt.rcParams['axes.unicode_minus'] = False# dir = '... your path/faces/'
dir = '/mnt/Data1/ysc/GAN'
# path = []
#
# for fileName in os.listdir(dir):
#     path.append(fileName)       # len(path)=51223noiseSize = 100     # 噪声维度
n_generator_feature = 64        # 生成器feature map数
n_discriminator_feature = 64        # 判别器feature map数
batch_size = 256
d_every = 1     # 每一个batch训练一次discriminator
g_every = 5     # 每五个batch训练一次generatorclass NetGenerator(nn.Module):def __init__(self):super(NetGenerator,self).__init__()self.main = nn.Sequential(      # 神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行nn.ConvTranspose2d(noiseSize, n_generator_feature * 8, kernel_size=4, stride=1, padding=0, bias=False),nn.BatchNorm2d(n_generator_feature * 8),nn.ReLU(True),       # (n_generator_feature * 8) × 4 × 4        (1-1)*1+1*(4-1)+0+1 = 4nn.ConvTranspose2d(n_generator_feature * 8, n_generator_feature * 4, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_generator_feature * 4),nn.ReLU(True),      # (n_generator_feature * 4) × 8 × 8     (4-1)*2-2*1+1*(4-1)+0+1 = 8nn.ConvTranspose2d(n_generator_feature * 4, n_generator_feature * 2, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_generator_feature * 2),nn.ReLU(True),  # (n_generator_feature * 2) × 16 × 16nn.ConvTranspose2d(n_generator_feature * 2, n_generator_feature, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_generator_feature),nn.ReLU(True),      # (n_generator_feature) × 32 × 32nn.ConvTranspose2d(n_generator_feature, 3, kernel_size=5, stride=3, padding=1, bias=False),nn.Tanh()       # 3 * 96 * 96)def forward(self, input):return self.main(input)class NetDiscriminator(nn.Module):def __init__(self):super(NetDiscriminator,self).__init__()self.main = nn.Sequential(nn.Conv2d(3, n_discriminator_feature, kernel_size=5, stride=3, padding=1, bias=False),nn.LeakyReLU(0.2, inplace=True),        # n_discriminator_feature * 32 * 32nn.Conv2d(n_discriminator_feature, n_discriminator_feature * 2, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_discriminator_feature * 2),nn.LeakyReLU(0.2, inplace=True),         # (n_discriminator_feature*2) * 16 * 16nn.Conv2d(n_discriminator_feature * 2, n_discriminator_feature * 4, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_discriminator_feature * 4),nn.LeakyReLU(0.2, inplace=True),  # (n_discriminator_feature*4) * 8 * 8nn.Conv2d(n_discriminator_feature * 4, n_discriminator_feature * 8, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(n_discriminator_feature * 8),nn.LeakyReLU(0.2, inplace=True),  # (n_discriminator_feature*8) * 4 * 4nn.Conv2d(n_discriminator_feature * 8, 1, kernel_size=4, stride=1, padding=0, bias=False),nn.Sigmoid()        # 输出一个概率)def forward(self, input):return self.main(input).view(-1)def train():for i, (image,_) in tqdm.tqdm(enumerate(dataloader)):       # type((image,_)) = <class 'list'>, len((image,_)) = 2 * 256 * 3 * 96 * 96real_image = Variable(image)real_image = real_image.cuda()if (i + 1) % d_every == 0:optimizer_d.zero_grad()output = Discriminator(real_image)      # 尽可能把真图片判为Trueerror_d_real = criterion(output, true_labels)error_d_real.backward()noises.data.copy_(torch.randn(batch_size, noiseSize, 1, 1))fake_img = Generator(noises).detach()       # 根据噪声生成假图fake_output = Discriminator(fake_img)       # 尽可能把假图片判为Falseerror_d_fake = criterion(fake_output, fake_labels)error_d_fake.backward()optimizer_d.step()if (i + 1) % g_every == 0:optimizer_g.zero_grad()noises.data.copy_(torch.randn(batch_size, noiseSize, 1, 1))fake_img = Generator(noises)        # 这里没有detachfake_output = Discriminator(fake_img)       # 尽可能让Discriminator把假图片判为Trueerror_g = criterion(fake_output, true_labels)error_g.backward()optimizer_g.step()def show(num):fix_fake_imags = Generator(fix_noises)fix_fake_imags = fix_fake_imags.data.cpu()[:64] * 0.5 + 0.5# x = torch.rand(64, 3, 96, 96)fig = plt.figure(1)i = 1for image in fix_fake_imags:ax = fig.add_subplot(8, 8, eval('%d' % i))# plt.xticks([]), plt.yticks([])  # 去除坐标轴plt.axis('off')plt.imshow(image.permute(1, 2, 0))i += 1plt.subplots_adjust(left=None,  # the left side of the subplots of the figureright=None,  # the right side of the subplots of the figurebottom=None,  # the bottom of the subplots of the figuretop=None,  # the top of the subplots of the figurewspace=0.05,  # the amount of width reserved for blank space between subplotshspace=0.05)  # the amount of height reserved for white space between subplots)plt.suptitle('第%d迭代结果' % num, y=0.91, fontsize=15)plt.show()if __name__ == '__main__':transform = tv.transforms.Compose([tv.transforms.Resize(96),     # 图片尺寸, transforms.Scale transform is deprecatedtv.transforms.CenterCrop(96),tv.transforms.ToTensor(),tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))       # 变成[-1,1]的数])dataset = tv.datasets.ImageFolder(dir, transform=transform)dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)   # module 'torch.utils.data' has no attribute 'DataLoder'print('数据加载完毕!')Generator = NetGenerator()Discriminator = NetDiscriminator()optimizer_g = torch.optim.Adam(Generator.parameters(), lr=2e-4, betas=(0.5, 0.999))optimizer_d = torch.optim.Adam(Discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))criterion = torch.nn.BCELoss()true_labels = Variable(torch.ones(batch_size))     # batch_sizefake_labels = Variable(torch.zeros(batch_size))fix_noises = Variable(torch.randn(batch_size, noiseSize, 1, 1))noises = Variable(torch.randn(batch_size, noiseSize, 1, 1))     # 均值为0,方差为1的正态分布if torch.cuda.is_available() == True:print('Cuda is available!')Generator.cuda()Discriminator.cuda()criterion.cuda()true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()fix_noises, noises = fix_noises.cuda(), noises.cuda()plot_epoch = [1,5,10,20,50,100,199]for i in range(200):        # 最大迭代次数train()print('迭代次数:{}'.format(i))if i in plot_epoch:show(i)

4. 结果展示与分析

在第1,5,10,20,50,100,199分别打印结果如下所示,这里第0代没有打印:
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

  • 刚开始训练的图像比较模糊(1个epoch),但是可以看出图像已经有面部轮廓
  • 继续训练数个epoch之后,生成的图多了很多细节信息,包括头发、颜色等,但是总体还是模糊
  • 训练数个epoch之后,细节继续完善,包括头发的纹理、眼睛的细节等,但还是有不少涂抹的痕迹
  • 训练数个epoch时,已经能看出明显的面部轮廓和细节,但还是有涂抹现象,并且有些细节不够合理,例如眼睛一大一小,面部轮廓扭曲严重
  • 当训练到最大epoch会后,图片的细节已经十分完善,线条更加流畅,轮廓更清晰,虽然还有一些不合理之处,但是已经有不少图片能够以假乱真了

类似的生成动漫头像的项目还有《用DRGAN生成高清的动漫头像》,效果很好,但遗憾的是,由于论文中使用的数据涉及版权问题,未能公开。这篇论文主要改进包括使用了更高质量的图片和更深、更复杂的模型

GAN可以应用到不同的生成图片场景中,只要将训练图片改成其他类型的图片即可,例如LSUN房客图片集、MNIST手写数据集或CIFAR10数据集等。事实上,上述模型还有很大的改进空间。在这里,我们使用的全卷积网络只有四层,模型比较浅,而在ResNet的论文发表之后,也有不少研究者尝试在GAN的网络结构中引入Residual Block结构,并取得了不错的视觉效果。感兴趣可以尝试将示例代码中的单层卷积改为Residual Block,相信可以取得不错的效果

今年来,GAN的一个重大突破在于理论研究。论文《Towards Principled Methods for Training Generative Adversarial Networks》从理论的角度分析了GAN为何难以训练,作者随后在另一篇论文《Wasserstein GAN》中针对性地提出了一个更好的解决方案。但是这篇论文在部分技术细节上的实现过于随意,所以随后又有人有针对性地提出了《Improved Training of Wasserstein GANs》,更好地训练WGAN。后面两篇论文分别用PyTorch和TensorFlow实现,代码可以在GitHub上搜索到。笔者当初也尝试用100行左右的代码实现了Wasserstein GAN,该兴趣可以去了解

随着GAN研究的逐渐成熟,人们也尝试把GAN用于工业实际问题之中,而在众多相关论文中,最令人深刻的就是《Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks》,论文中提出了一种新的GAN结构称为CycleGAN。CycleGAN利用GAN实现风格迁移、黑白图像彩色化,以及马和斑马互相转化等,效果十分出众。论文的作者用PyTorch实现了所有的代码,并开源在GitHub上,感兴趣可以自行查阅

小结

GAN生成的结果还是比较理想吧,就是一个简单的GAN的结构,其中的Generator的反卷积与Discriminator的卷积可以琢磨一下,模型只是简单的训练,并没有保存和测试


http://www.ppmy.cn/news/709574.html

相关文章

航海王燃烧意志服务器响应格式非法,航海王燃烧意志充值异常怎么处理 航海王燃烧意志充值异常申诉方法_斗蟹游戏网...

【斗蟹-航海王燃烧意志】航海王燃烧意志游戏中玩家在充值后发现没有到账&#xff0c;那要在怎么处理充值异常&#xff0c;下面小编带大家一起看看航海王燃烧意志充值异常申诉方法&#xff0c;希望能在游戏中帮到大家。 在航海王燃烧意志手游中&#xff0c;不少朋友在充值后彩钻…

这顶海贼王的帽子,我Python给你带上了 | 【人脸识别应用】

微信公众号&#xff1a;AI算法与图像处理如果你觉得对你有帮助&#xff0c;欢迎分享和转发哈 https://zhuanlan.zhihu.com/p/32299758?utm_sourcewechat_session&utm_mediumsocial&utm_oi704056637840695296 内容目录 故事起因思路与实现准备工作详细代码和效果总结1.…

程序员的发展之道---海贼王(山治)

对于日本动漫&#xff0c;我唯一喜欢&#xff0c;也是一直在追的就是海贼王&#xff0c;尤其喜欢里面的厨师山治&#xff0c;至于为什么喜欢他&#xff0c;也许是因为他绅士&#xff0c;儒雅&#xff0c;对梦想的执着 当然这只是我的个人看法&#xff0c;但是就是这样一个没有超…

【数据结构与算法】Huffman编码/译码(C/C++)

实践要求 1. 问题描述 利用哈夫曼编码进行信息通讯可以大大提高信道利用率&#xff0c;缩短信息传输时间&#xff0c;降低传输成本。但是&#xff0c;这要求在发送端通过一个编码系统对待传数据预先编码&#xff1b;在接收端将传来的数据进行译码(复原)。对于双工信道(即可以…

Vuejs 3.0 正式版发布!One Piece. 代号:海贼王

文末有送书福利 译者&#xff1a;夜尽天明 &#xff08;译者授权转载&#xff09; 原文地址&#xff1a;https://mp.weixin.qq.com/s/0oet-MTo__LWNZNYl5Fpqw Vue 团队于 2020 年 9 月 18 日晚 11 点半发布了 Vue 3.0 版本。 那个男人总喜欢在深夜给我们带来意外惊喜&#xff0…

如何查看 MySQL 建表时间

MySQL是一款性能良好&#xff0c;易于使用的关系型数据库管理系统。我们可以使用 SQL 语句查看 MySQL 建表时间&#xff0c;以便获取建立表时的更多信息。 1、 首先&#xff0c;在MySQL中执行以下命令&#xff0c;获取表的列表&#xff1a; SELECT create_time,table_name FR…

资料下载链接

大家好&#xff01;这是我整理的免费视频教程以及电子书&#xff0c;每天都会有更新&#xff0c;希望对大家能有帮助。 与大家共勉&#xff01;大家可以根据自己感兴趣的方向浏览下载哦!祝大家事业有成&#xff01;学习进步&#xff01; Linux&#xff1a; LAMP兄弟连Linux视频…

POC,黑客精神的一场回响

生命不息&#xff0c;破解不止。 一年一度&#xff0c;黑客们聚集韩国首尔&#xff0c;分享自己对这个世界最新的理解。 这就是 POC。 如果说 POC 黑客大会和其他黑客大会有怎样的区别&#xff0c;除了随处可见的韩国妹子之外&#xff0c;就是会上透露出来的浓浓的“Pwn”的气息…