生成对抗网络pix2pixGAN

news/2024/10/20 5:48:13/

1.介绍

论文:Image-to-Image Translation with Conditional Adversarial Networks

论文地址:https://arxiv.org/abs/1611.07004

图像处理的很多问题都是将一张输入的图片转变为一张对应的 输出图片,比如灰度图、彩色图之间的转换、图像自动上色等。

什么是 pix2pixGAN:pix2pixGAN主要用于图像之间的转换,又称图像翻译。作者证明了这种方法在从标签图合成照片(synthesizing photos from label map)、从边缘图重建对象(reconstructing objects from edge maps)以及给图像上色(colorizing images)等多种任务中是有效的。

与普通GAN的区别:普通GAN的生成器G输入的是随机向量(噪声),输出是图像; 判别器D接收的输入是图像(生成的或是真实的),输出是对或者错 。这样G和D联手就能输出真实的图像。Pix2pixGAN本质上是一个cGAN,图片x作为此cGAN的条件, 输入到生成器G中。G的输出是生成的图片G(x)。 D则需要分辨出{x,G(x)}和{x, y}。其中x是需要转换的图片,y是x对应的真实图片。

2.生成器与判别器的设计

生成器G的设计:生成器G采用了Encoder-Decoder模型,参考U-Net的结构。

判别器D的设计:D中要输入成对的图像。判别器D的输入与cGAN中的不同,因为除了要生成真实图像之外,还要保证生成的图像和输入图像是匹配的。Pix2Pix论文中将判别器D实现为Patch-D,所谓Patch,是指无论生成的图像有多大,将其切分为多个固定大小的Patch输入进D去判断。这样设计的好处是:D的输入变小,计算量小,训练速度快。

3.损失函数

D网络损失函数(使用二元交叉熵损失BCELoss)

输入真实的成对图像希望判定为1,即{x, y};输入原图与生成图像希望判定为0,即{x,G(x)}。

G网络损失函数(使用二元交叉熵损失BCELoss和L1loss)

L1loss保证输入和输出之间的一致性;

输入原图与生成图像希望判定为1,即{x,G(x)}。

4.模型搭建 

import torch
from PIL import Image
import os
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from torchvision.datasets import ImageFolder
from torch.utils.data.dataset import Dataset
import tqdm
import globimgs_path = glob.glob('D:\cnn\All_Classfication/base_data/train/*.jpg') #获取训练集中的.jpg图片
annos_path = glob.glob('D:\cnn\All_Classfication/base_data/train/*.png') #获取训练集中的.png图片transform = transforms.Compose([transforms.ToTensor(),transforms.Resize((256, 256)),transforms.Normalize(mean=0.5, std=0.5)]) #Normalize为转化到-1~1之间# 定义数据读取
class GANDataset(Dataset):def __init__(self, imgs_path, annos_path): #初始化super(GANDataset, self).__init__()self.imgs_path     = imgs_path #定义属性self.annos_path   = annos_path#定义属性def __len__(self):return len(self.imgs_path)def __getitem__(self, index): #对数据切片img_path        = self.imgs_path[index]anno_path = self.annos_path[index]# 从文件中读取图像jpg         = Image.open(img_path)jpg         = transform(jpg)png         = Image.open(anno_path)png         = png.convert('RGB') #因为anno_path为单通道图片,使用convert方法还原回三通道png         = transform(png)return jpg, pngtrain_dataset = GANDataset(imgs_path, annos_path) #创建dataset
dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True)jpg_batch, png_batch = next(iter(dataloader)) #查看,返回一个批次的训练数据
# print(jpg_bath.shape)
# print(png_bath.shape)# 查看训练集
# plt.figure(figsize=(8, 12))
# for i, (anno, img) in enumerate(zip(png_batch[:3], jpg_batch[:3])): #zip代表元组
#     # 因为dataset返回的数据是tensor,需要转为numpy格式,因为Normalize为转化到-1~1之间,所以加1再除以2将其转化到0~1之间
#     anno = (anno.permute(1, 2, 0).numpy() + 1) / 2
#     img = (img.permute(1, 2, 0).numpy() + 1) / 2
#     plt.subplot(3, 2, 2*i+1)
#     plt.title('input_img')
#     plt.imshow(anno)
#     plt.subplot(3, 2, 2*i+2)
#     plt.title('output_img')
#     plt.imshow(img)
# plt.show()#定义下采样模块
class Downsample(nn.Module):def __init__(self, in_channels, out_channels):super(Downsample, self).__init__()self.conv_relu = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),nn.LeakyReLU(inplace=True))self.bn = nn.BatchNorm2d(out_channels)def forward(self, x, is_bn=True): #is_bn用于确定是否使用bn层,默认为Truex = self.conv_relu(x)if is_bn:x = self.bn(x)return x#定义上采样模块
class Upsample(nn.Module):def __init__(self, in_channels, out_channels):super(Upsample, self).__init__()self.upconv_relu = nn.Sequential(nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),nn.LeakyReLU(inplace=True))self.bn = nn.BatchNorm2d(out_channels)def forward(self, x, is_drop=False): #is_drop用于确定是否使用drop层,默认为Falsex = self.upconv_relu(x)x = self.bn(x)if is_drop:x = F.dropout2d(x)return x# 定义生成器,包含6个下采样层,6个上采样层
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.down1 = Downsample(3, 64)     #3,256,256 -- 64,128,128self.down2 = Downsample(64, 128)   #64,128,128 -- 128,64,64self.down3 = Downsample(128, 256)  #128,64,64 -- 256,32,32self.down4 = Downsample(256, 512)  #256,32,32 -- 512,16,16self.down5 = Downsample(512, 512)  #512,16,16 -- 512,8,8self.down6 = Downsample(512, 512)  #512,8,8 -- 512,4,4self.up1 = Upsample(512, 512)      #512,4,4 -- 512,8,8self.up2 = Upsample(1024, 512)     #1024,8,8 -- 512,16,16self.up3 = Upsample(1024, 256)     #1024,16,16 -- 256,32,32self.up4 = Upsample(512, 128)      #512,32,32 -- 128,64,64self.up5 = Upsample(256, 64)       #256,64,64 -- 64,128,128#128,128,128 -- 3,256,256self.last = nn.ConvTranspose2d(128, 3, kernel_size=3, stride=2, padding=1, output_padding=1)def forward(self, x):x1 = self.down1(x)x2 = self.down2(x1)x3 = self.down3(x2)x4 = self.down4(x3)x5 = self.down5(x4)x6 = self.down6(x5)x6 = self.up1(x6, is_drop=True)x6 = torch.cat([x6, x5], dim=1)x6 = self.up2(x6, is_drop=True)x6 = torch.cat([x6, x4], dim=1)x6 = self.up3(x6, is_drop=True)x6 = torch.cat([x6, x3], dim=1)x6 = self.up4(x6)x6 = torch.cat([x6, x2], dim=1)x6 = self.up5(x6)x6 = torch.cat([x6, x1], dim=1)x6 = torch.tanh(self.last(x6))return x6# 定义判别器   将条件(anno)与图片(生成的或真实的)同时输入到判别器中进行判定  concat
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.down1 = Downsample(6, 64)self.down2 = Downsample(64, 128)self.conv1 = nn.Conv2d(128, 256, 3)self.bn = nn.BatchNorm2d(256)self.last = nn.Conv2d(256, 1, 3)# 判别器的输入为成对的图片,anno为结构图,img为真实的或生成的图片def forward(self, anno, img):x = torch.cat([anno, img], dim=1) #batch_size,6,256,256x = self.down1(x, is_bn=False) #batch_size,64,128,128x = self.down2(x) #batch_size,128,64,64x = F.dropout2d(self.bn(F.leaky_relu(self.conv1(x)))) #batch_size,256,62,62x = torch.sigmoid(self.last(x)) #batch_size,1,60,60return xdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")gen = Generator().to(device)
dis = Discriminator().to(device)# 判别器优化器
d_optimizer = torch.optim.Adam(dis.parameters(), lr=1e-4, betas=(0.5, 0.999)) #通过减小判别器的学习率降低其能力
# 生成器优化器
g_optimizer = torch.optim.Adam(gen.parameters(), lr=1e-3, betas=(0.5, 0.999))# 绘图函数,将每一个epoch中生成器生成的图片绘制
def gen_img_plot(model, epoch, test_anno, test_real): # model为Generator,test_anno为结构图,test_real为真实图片generate = model(test_anno).permute(0, 2, 3, 1).detach().cpu().numpy() #detach()截断梯度,将通道维度放在最后test_anno = test_anno.permute(0, 2, 3, 1).cpu().numpy() #1,3,256,256 -- 1,256,256,3test_real = test_real.permute(0, 2, 3, 1).cpu().numpy() #1,3,256,256 -- 1,256,256,3plt.figure(figsize=(10, 10))title = ['Input image', 'Ground truth', 'Generate image']display_list0 = [test_anno[0], test_real[0], generate[0]]for i in range(3):plt.subplot(3, 3, i + 1)plt.title(title[i])plt.imshow((display_list0[i]+1)/2) #从-1~1 --> 0~1plt.axis('off')display_list1 = [test_anno[1], test_real[1], generate[1]]for i in range(3,6):plt.subplot(3, 3, i + 1)# plt.title(title[i])plt.imshow((display_list1[i-3]+1)/2) #从-1~1 --> 0~1plt.axis('off')display_list2 = [test_anno[2], test_real[2], generate[2]]for i in range(6,9):plt.subplot(3, 3, i + 1)# plt.title(title[i])plt.imshow((display_list2[i-6]+1)/2) #从-1~1 --> 0~1plt.axis('off')# plt.show()plt.savefig('./imageP2P/image_at_{}.png'.format(epoch))test_imgs_path = glob.glob('D:\cnn\All_Classfication/base_data/val/*.jpg') #获取验证集中的.jpg图片
test_annos_path = glob.glob('D:\cnn\All_Classfication/base_data/val/*.png') #获取验证集中的.png图片test_dataset = GANDataset(test_imgs_path, test_annos_path) #创建dataset
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=4, shuffle=True)imgs_batch, annos_batch = next(iter(test_dataloader)) #查看,返回一个批次的测试数据
# print(jpg_bath.shape)
# print(png_bath.shape)# 查看测试集
# plt.figure(figsize=(8, 12))
# for i, (anno, img) in enumerate(zip(annos_batch[:3], imgs_batch[:3])): #zip代表元组
#     # 因为dataset返回的数据是tensor,需要转为numpy格式,因为Normalize为转化到-1~1之间,所以加1再除以2将其转化到0~1之间
#     anno = (anno.permute(1, 2, 0).numpy() + 1) / 2
#     img = (img.permute(1, 2, 0).numpy() + 1) / 2
#     plt.subplot(3, 2, 2*i+1)
#     plt.title('input_img')
#     plt.imshow(anno)
#     plt.subplot(3, 2, 2*i+2)
#     plt.title('output_img')
#     plt.imshow(img)
# plt.show()annos_batch, imgs_batch = annos_batch.to(device), imgs_batch.to(device)# 定义cGAN损失
loss_fn = torch.nn.BCELoss() # 二元交叉熵损失
LAMBDA = 7 #L1损失的权重# pix2pixGAN训练
D_loss = []
G_loss = []for epoch in range(100):D_epoch_loss = 0 #记录判别器每个epoch损失G_epoch_loss = 0 #记录生成器每个epoch损失count = len(dataloader) #len(dataloader)返回批次数count1 = len(train_dataset) #len(train_dataset)返回样本数for step, (imgs, annos) in enumerate(tqdm.tqdm(dataloader)): #注意dataloader输出的图片和标签的顺序annos = annos.to(device)imgs = imgs.to(device)#-------------------------------------## 判别器损失d_optimizer.zero_grad()disc_real_output = dis(annos, imgs) #输入真实的成对图像希望判定为1,即{x, y}d_real_loss = loss_fn(disc_real_output, torch.ones_like(disc_real_output, device=device))d_real_loss.backward()  # 反向传播gen_output = gen(annos) #结构图通过生成器生成图片disc_gen_output = dis(annos, gen_output.detach()) #输入原图与生成图像希望判定为0,即{x,G(x)}d_fake_loss = loss_fn(disc_gen_output, torch.zeros_like(disc_gen_output, device=device))d_fake_loss.backward()  # 反向传播# 判别器总损失disc_loss = d_real_loss + d_fake_lossd_optimizer.step() #优化# -------------------------------------## -------------------------------------## 生成器损失g_optimizer.zero_grad()disc_gen_out = dis(annos, gen_output) #输入原图与生成图像希望判定为1,即{x,G(x)}gen_loss_celoss = loss_fn(disc_gen_out, torch.ones_like(disc_gen_out, device=device))gen_l1_loss = torch.mean(torch.abs(gen_output - imgs)) #L1loss度量生成图像与原结构图之间的距离# 生成器总损失gen_loss = gen_loss_celoss + LAMBDA*gen_l1_lossgen_loss.backward()  #反向传播g_optimizer.step() #优化# -------------------------------------#with torch.no_grad():D_epoch_loss += disc_loss.item()  # 将每一个批次的loss累加G_epoch_loss += gen_loss.item()  # 将每一个批次的loss累加with torch.no_grad():D_epoch_loss /= count  # 求得每一轮的平均lossG_epoch_loss /= count  # 求得每一轮的平均lossD_loss.append(D_epoch_loss)G_loss.append(G_epoch_loss)print('epoch:', epoch)gen_img_plot(gen, epoch, annos_batch, imgs_batch)plt.figure(figsize=(10, 10))plt.plot(range(1, len(D_loss) + 1), D_loss, label='D_loss')plt.plot(range(1, len(G_loss) + 1), G_loss, label='G_loss')plt.xlabel('epoch')  # 横轴名称plt.legend()plt.savefig('./imageP2P/loss.png')  # 保存图片


http://www.ppmy.cn/news/53883.html

相关文章

计算机网络-如何寻找目标主机

视频参考链接:计算机网络-如何寻找目标计算机?_哔哩哔哩_bilibili 在互联网中如果使计算机A与计算机B如何进行通信,又是如何找到目标的计算机主机呢? 首先最简单的通信就是两台计算机中间加一根网线,那么这两台计算机…

软件测试项目去哪里找?我都给你整理好了【源码+操作视频】

目录 一、引言 二、测试任务 三、测试进度 四、测试资源 五、测试策略 六、测试完成标准 七、风险和约束 八、问题严重程度描述和响应时间规范 九、测试的主要角色和职责 ​有需要实战项目的评论区留言吧! 软件测试是使用人工或者自动的手段来运行或者测定…

悲观锁、乐观锁、自旋锁和读写锁

悲观锁和乐观锁 悲观锁:在每次取数据时,总是担心数据会被其他线程修改,所以会在取数据前先加锁(读锁,写锁,行 锁等),当其他线程想要访问数据时,被阻塞挂起。&#xff08…

创新案例|探索 Snyk 的 PLG 团队1.6倍年度 ARR 增长背后的策略

组织架构不匹配、权责分配不清晰以及团队协作无机制是推进PLG业务面临的三大核心挑战,而安全软件公司Snyk以其指数级营收和估值增长的成功实践证明,构建合适且高效团队是助力PLG创新实现高速增长的关键,其经验值得借鉴。本文将通过分析Synk如…

Bindiff工具使用-[GDOUCTF 2023]L!s!

目录 题目: 学到的点: 题目: 打了GDOUCTF的比赛(被暴打了hhh),学到很多新东西,这里总结一下 Diff的文件是ida数据库文件,选择i64或者idb文件进行Diff 打开附件,有两个文件,一个…

Golang校验字符串是否JSON格式方法json.Valid源码解析

上篇文章《Golang中如何校验字符串是否为JSON格式?》主要讲解了使用json.Valid校验字符串是否JSON格式的使用方法,本文来剖析一下json.Valid方法的源码。 json.Valid方法源码 json.Valid方法定义: // Valid reports whether data is a val…

Linux拓展:链接库

一.说明 本篇博客介绍Linux操作系统下的链接库相关知识,由于相关概念已在Windows下链接库一文中介绍,本篇博客直接上操作。 二.静态链接库的创建和使用 1.提前看 这里主要介绍的是C语言的链接库技术,而在Linux下实现C语言程序&#xff0c…

支持中英双语和多种插件的开源对话语言模型,160亿参数

一、开源项目简介 MOSS是一个支持中英双语和多种插件的开源对话语言模型,moss-moon系列模型具有160亿参数,在FP16精度下可在单张A100/A800或两张3090显卡运行,在INT4/8精度下可在单张3090显卡运行。MOSS基座语言模型在约七千亿中英文以及代码…