Pixel2Pixel:卡通照片真人化

news/2024/10/17 17:27:39/

Pixel2Pixel:卡通照片真人化

项目链接:Pixel2Pixel:卡通照片真人化

前言:
之前PaddleGAN的趣味应用如雨后春笋般地出现,非常多的项目都是xxx动漫化。当时就有一个很普通的想法为什么大家都会去搞动漫化,这很可能是因为二次元文化的原因,又或者是动漫化的应用、商业价值。就突然蹦出一个想法,为什么没人弄动漫真人化呢,然后我就去项目搜了,结果确实貌似没有人做这个项目。刚开始我以为我这个想法实现起来很难,到后面和大神们讨论后,其实觉得实现原理也很简单,就是把人像动漫化的数据集里面的标签互换。比如人像卡通化,就是A to B(A是真人,B是动漫,B是标签)。那么此次这个项目卡通人像化就是B to A(A是真人,B是动漫,A是标签).

先来看看实现效果

实现效果:

真人原图:

实现效果:

真人原图:

可以看到效果已经很逼真了!

1.下载安装包

import paddle
import paddle.nn as nn
from paddle.io import Dataset, DataLoaderimport os
import cv2
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt%matplotlib inline

2.解压数据

数据准备:

  • 真人数据来自seeprettyface。
  • 数据预处理(详情见photo2cartoon项目)。
  • 使用photo2cartoon项目生成真人数据对应的卡通数据。
# 解压数据
!unzip -q data/data79149/cartoon_A2B.zip -d data/

3.数据可视化(已划分好数据集)

# 训练数据统计
train_names = os.listdir('data/cartoon_A2B/train')
print(f'训练集数据量: {len(train_names)}')# 测试数据统计
test_names = os.listdir('data/cartoon_A2B/test')
print(f'测试集数据量: {len(test_names)}')# 训练数据可视化
imgs = []
for img_name in np.random.choice(train_names, 3, replace=False):imgs.append(cv2.imread('data/cartoon_A2B/train/'+img_name))img_show = np.vstack(imgs)[:,:,::-1]
plt.figure(figsize=(10, 10))
plt.imshow(img_show)
plt.show()

注意:

A代表真人,B代表卡通。源参考代码 是A to B。本次实验项目是用 B to A

又因为数据集是把 真人照片和卡通图片拼接在一起,利用划分宽度来区别原图与标签。例如源程序 是用 宽度[ : 256]分成真人(即原图),[256 : ]分成卡通(即标签)

要实现这个项目因此要把他们调换过来。

class PairedData(Dataset):def __init__(self, phase):super(PairedData, self).__init__() self.img_path_list = self.load_A2B_data(phase)    # 获取数据列表self.num_samples = len(self.img_path_list)        # 数据量def __getitem__(self, idx):img_A2B = cv2.imread(self.img_path_list[idx])     # 读取数据img_A2B = img_A2B.astype('float32') / 127.5 - 1.  # 归一化img_A2B = img_A2B.transpose(2, 0, 1)              # HWC -> CHWimg_A = img_A2B[..., 256:]                        # 卡通图(原图)   img_B = img_A2B[..., :256]                        # 真人图(标签)return img_A, img_Bdef __len__(self):return self.num_samples@staticmethoddef load_A2B_data(phase):assert phase in ['train', 'test'], "phase should be set within ['train', 'test']"# 读取数据集,数据中每张图像包含照片和对应的卡通画。data_path = 'data/cartoon_A2B/'+phasereturn [os.path.join(data_path, x) for x in os.listdir(data_path)]
paired_dataset_train = PairedData('train')
paired_dataset_test = PairedData('test')

4.定义生成器

class UnetGenerator(nn.Layer):def __init__(self, input_nc=3, output_nc=3, ngf=64):super(UnetGenerator, self).__init__()self.down1 = nn.Conv2D(input_nc, ngf, kernel_size=4, stride=2, padding=1)self.down2 = Downsample(ngf, ngf*2)self.down3 = Downsample(ngf*2, ngf*4)self.down4 = Downsample(ngf*4, ngf*8)self.down5 = Downsample(ngf*8, ngf*8)self.down6 = Downsample(ngf*8, ngf*8)self.down7 = Downsample(ngf*8, ngf*8)self.center = Downsample(ngf*8, ngf*8)self.up7 = Upsample(ngf*8, ngf*8, use_dropout=True)self.up6 = Upsample(ngf*8*2, ngf*8, use_dropout=True)self.up5 = Upsample(ngf*8*2, ngf*8, use_dropout=True)self.up4 = Upsample(ngf*8*2, ngf*8)self.up3 = Upsample(ngf*8*2, ngf*4)self.up2 = Upsample(ngf*4*2, ngf*2)self.up1 = Upsample(ngf*2*2, ngf)self.output_block = nn.Sequential(nn.ReLU(),nn.Conv2DTranspose(ngf*2, output_nc, kernel_size=4, stride=2, padding=1),nn.Tanh())def forward(self, x):d1 = self.down1(x)d2 = self.down2(d1)d3 = self.down3(d2)d4 = self.down4(d3)d5 = self.down5(d4)d6 = self.down6(d5)d7 = self.down7(d6)c = self.center(d7)x = self.up7(c, d7)x = self.up6(x, d6)x = self.up5(x, d5)x = self.up4(x, d4)x = self.up3(x, d3)x = self.up2(x, d2)x = self.up1(x, d1)x = self.output_block(x)return xclass Downsample(nn.Layer):# LeakyReLU => conv => batch normdef __init__(self, in_dim, out_dim, kernel_size=4, stride=2, padding=1):super(Downsample, self).__init__()self.layers = nn.Sequential(nn.LeakyReLU(0.2),nn.Conv2D(in_dim, out_dim, kernel_size, stride, padding, bias_attr=False),nn.BatchNorm2D(out_dim))def forward(self, x):x = self.layers(x)return xclass Upsample(nn.Layer):# ReLU => deconv => batch norm => dropoutdef __init__(self, in_dim, out_dim, kernel_size=4, stride=2, padding=1, use_dropout=False):super(Upsample, self).__init__()sequence = [nn.ReLU(),nn.Conv2DTranspose(in_dim, out_dim, kernel_size, stride, padding, bias_attr=False),nn.BatchNorm2D(out_dim)]if use_dropout:sequence.append(nn.Dropout(p=0.5))self.layers = nn.Sequential(*sequence)def forward(self, x, skip):x = self.layers(x)x = paddle.concat([x, skip], axis=1)return x

5.定义鉴别器

class NLayerDiscriminator(nn.Layer):def __init__(self, input_nc=6, ndf=64):super(NLayerDiscriminator, self).__init__()self.layers = nn.Sequential(nn.Conv2D(input_nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2),ConvBlock(ndf, ndf*2),ConvBlock(ndf*2, ndf*4),ConvBlock(ndf*4, ndf*8, stride=1),nn.Conv2D(ndf*8, 1, kernel_size=4, stride=1, padding=1),nn.Sigmoid())def forward(self, input):return self.layers(input)class ConvBlock(nn.Layer):# conv => batch norm => LeakyReLUdef __init__(self, in_dim, out_dim, kernel_size=4, stride=2, padding=1):super(ConvBlock, self).__init__()self.layers = nn.Sequential(nn.Conv2D(in_dim, out_dim, kernel_size, stride, padding, bias_attr=False),nn.BatchNorm2D(out_dim),nn.LeakyReLU(0.2))def forward(self, x):x = self.layers(x)return x

实例化生成器,鉴别器

generator = UnetGenerator()
discriminator = NLayerDiscriminator()
out = generator(paddle.ones([1, 3, 256, 256]))
print('生成器输出尺寸:', out.shape)out = discriminator(paddle.ones([1, 6, 256, 256]))
print('鉴别器输出尺寸:', out.shape)

6.定义训练各项超参数

# 超参数
LR = 1e-4
BATCH_SIZE = 8
EPOCHS = 100# 优化器
optimizerG = paddle.optimizer.Adam(learning_rate=LR,parameters=generator.parameters(),beta1=0.5,beta2=0.999)optimizerD = paddle.optimizer.Adam(learning_rate=LR,parameters=discriminator.parameters(), beta1=0.5,beta2=0.999)# 损失函数
bce_loss = nn.BCELoss()
l1_loss = nn.L1Loss()# dataloader
data_loader_train = DataLoader(paired_dataset_train,batch_size=BATCH_SIZE,shuffle=True,drop_last=True)data_loader_test = DataLoader(paired_dataset_test,batch_size=BATCH_SIZE)

训练效果

第一列是卡通(原图),第二列是真人图片(标签),第三列是学习出来的结果

刚开始学到的效果:

100epochs的效果:

我们可以看出已经有很好的效果

results_save_path = 'work/results'
os.makedirs(results_save_path, exist_ok=True)  # 保存每个epoch的测试结果weights_save_path = 'work/weights'
os.makedirs(weights_save_path, exist_ok=True)  # 保存模型for epoch in range(EPOCHS):for data in tqdm(data_loader_train):real_A, real_B = dataoptimizerD.clear_grad()# D(real)real_AB = paddle.concat((real_A, real_B), 1)d_real_predict = discriminator(real_AB)d_real_loss = bce_loss(d_real_predict, paddle.ones_like(d_real_predict))# D(fake)fake_B = generator(real_A).detach()fake_AB = paddle.concat((real_A, fake_B), 1)d_fake_predict = discriminator(fake_AB)d_fake_loss = bce_loss(d_fake_predict, paddle.zeros_like(d_fake_predict))# train Dd_loss = (d_real_loss + d_fake_loss) / 2.d_loss.backward()optimizerD.step()optimizerG.clear_grad()# D(fake)fake_B = generator(real_A)fake_AB = paddle.concat((real_A, fake_B), 1)g_fake_predict = discriminator(fake_AB)g_bce_loss = bce_loss(g_fake_predict, paddle.ones_like(g_fake_predict))g_l1_loss = l1_loss(fake_B, real_B) * 100.g_loss = g_bce_loss + g_l1_loss# train Gg_loss.backward()optimizerG.step()print(f'Epoch [{epoch+1}/{EPOCHS}] Loss D: {d_loss.numpy()}, Loss G: {g_loss.numpy()}')if (epoch+1) % 10 == 0:paddle.save(generator.state_dict(), os.path.join(weights_save_path, 'epoch'+str(epoch+1).zfill(3)+'.pdparams'))# testgenerator.eval()with paddle.no_grad():for data in data_loader_test:real_A, real_B = databreakfake_B = generator(real_A)result = paddle.concat([real_A[:3], real_B[:3], fake_B[:3]], 3)result = result.detach().numpy().transpose(0, 2, 3, 1)result = np.vstack(result)result = (result * 127.5 + 127.5).astype(np.uint8)cv2.imwrite(os.path.join(results_save_path, 'epoch'+str(epoch+1).zfill(3)+'.png'), result)generator.train()

7.测试

# 为生成器加载权重
last_weights_path = os.path.join(weights_save_path, sorted(os.listdir(weights_save_path))[-1])
print('加载权重:', last_weights_path)model_state_dict = paddle.load(last_weights_path)
generator.load_dict(model_state_dict)
generator.eval()
读取数据
test_names = os.listdir('data/cartoon_A2B/test')
# img_name = np.random.choice(test_names)
img_name = '01481.png'
img_A2B = cv2.imread('data/cartoon_A2B/test/'+img_name)
img_A = img_A2B[:, 256:]                                  # 卡通图(即输入)
img_B = img_A2B[:, :256]                                  # 真人图(即预测结果)# img_A= cv2.imread('data/test4.png')
# img_A = img_A[:, 256:]g_input = img_A.astype('float32') / 127.5 - 1             # 归一化
g_input = g_input[np.newaxis, ...].transpose(0, 3, 1, 2)  # NHWC -> NCHW
g_input = paddle.to_tensor(g_input)                       # numpy -> tensorg_output = generator(g_input)
g_output = g_output.detach().numpy()                      # tensor -> numpy
g_output = g_output.transpose(0, 2, 3, 1)[0]              # NCHW -> NHWC
g_output = g_output * 127.5 + 127.5                       # 反归一化
g_output = g_output.astype(np.uint8)img_show = np.hstack([img_A, g_output])[:,:,::-1]
plt.figure(figsize=(8, 8))
plt.imshow(img_show)numpy
g_output = g_output.transpose(0, 2, 3, 1)[0]              # NCHW -> NHWC
g_output = g_output * 127.5 + 127.5                       # 反归一化
g_output = g_output.astype(np.uint8)img_show = np.hstack([img_A, g_output])[:,:,::-1]
plt.figure(figsize=(8, 8))
plt.imshow(img_show)
plt.show()

总结:

至此,动漫照片真人化项目就完成了,本次项目大部分基于参考项目,只是做了些许改动。

参考项目:

Pixel2Pixel:人像卡通化

关于作者

  • 个人主页
  • 感兴趣的方向为:目标检测,图像分类,图像分割等。
  • 不定期更新感兴趣的CV比赛baseline等
  • 个人荣誉:飞桨开发者技术专家(PPDE)
  • 欢迎大家有问题留言交流学习,共同进步成长。

http://www.ppmy.cn/news/465093.html

相关文章

ChatGPT或致全球3亿人失业,人工智能时代下教育会发生什么样的变革?

不久前,谷歌教育发布了一份关于未来教育的研究报告。该报告由谷歌公司和 Canvas8合作,对来自世界24个国家的94位教育专家进行了历时长达2年的调研,探讨了未来教育形态、教育在未来的作用、教育公平、全球人才需求、教学方式、学习生态、工作技…

【STM32】软件I2C

【STM32】软件I2C I2C简介 I2C总线是一种串行、半双工的总线,主要用于近距离、低速的芯片之间的通信。I2C总线有两根双向的信号线,一根数据线SDA用于收发数据,一根时钟线SCL用于通信双方时钟的同步。 在一个i2c通讯总线中,可连接…

SQL Server处于恢复已挂起状态的解决方法

win r 输入 services.msc,进入服务!找到SQL Server(MSSQLSERVER)点击重启

虚拟机状态显示“已关机”,但是提示“该虚拟机似乎正在使用中”

问题:想要删除虚拟机时,虚拟机状态显示“已关机”,但是提示“该虚拟机似乎正在使用中” 解决:删除虚拟机安装目录下的.lck文件夹 原因:据说是因为电脑关机或重启前没有关闭虚拟机(出现问题之前的确是这样的…

SqlServer处于恢复挂起状态

一. 出现问题: 二. 总结: 如果数据库处于一个恢复挂起的状态,并且对数据库做脱机和分离的操作,报出数据库文件不可访问的错误,可能是因为数据库的数据文件和日志文件在数据库正常连接的情况下,文件所在的磁…

取消Windows 2003登录及关机提示

很多使用过Windows 2003操作系统,常常会在登录和关机比较麻烦,比如登录时要按CtrlAltDelete这三个键才能出现登录面板,接着还要输入用户名及密码才能登录。关机时会出现一个关机原因提示框,需要输入原因才能正常关机。Windows 200…

Android 拨打电话以及跳转到拨打电话界面

1 Android直接拨打电话 Intent Intent new Intent(Intent.ACTION_CALL,Uri.parse("tel:" phoneNumber));//直接拨打电话 startActivity(Intent); 2 Android跳转到拨号界面 Intent Intent new Intent(Intent.ACTION_CALL_BUTTON);//跳转到拨号界面 startActiv…

电话状态说明

打电话和来电话时处理 媒体音量的问题。现将逻辑说明如下: 1.来电话 监听类 PhoneStateListener 第一种状态:响铃:TelephonyManager.CALL_STATE_RINGING 第二种状态:接起电话: TelephonyManager.CALL_STATE_OFFHOO…