深度学习:CycleGAN图像风格迁移转换

news/2024/12/21 23:10:58/

目录

基础概念

模型工作流程

循环一致性

几个基本概念

假图像(Fake Image)

重建图像(Reconstructed Image)

身份映射图像(Identity Mapping Image)

CyclyGAN损失函数

对抗损失

身份鉴别损失

CycleGAN的应用

基于MindSpore的CycleGAN

数据集

生成器的基本架构

构建生成器基本块

 定义ResNet的残差块

定义基于ResNet的生成器

定义判别器

定义优化器和损失函数

前向计算

梯度计算和反向传播

模型训练

模型推理


基础概念

CycleGAN是一种GAN的变体,它被设计用来在没有成对训练数据的情况下学习两种不同域之间的图像到图像的转换,不需要同一场景或物体在两个不同域中的对应图像。

CycleGAN由Jun-Yan Zhu等人在2017年提出。

CycleGAN的模型架构主要由两组生成器和判别器组成,每组负责一个方向上的图像转换。

具体来说,假设我们有两个不同的图像领域X(比如马的照片)和Y(比如斑马的照片),那么CycleGAN将包含以下组件:

  1. 生成器G:负责将图像从领域X转换到领域Y。
  2. 生成器F:负责将图像从领域Y转换回领域X。
  3. 判别器DY:用于区分领域Y中的真实图像与通过生成器G从领域X转换来的假图像。
  4. 判别器DX:用于区分领域X中的真实图像与通过生成器F从领域Y转换来的假图像。

模型工作流程

  • 当一张来自领域X的图片x被输入到生成器G时,它会产生一张看起来像是属于领域Y的图片G(x)。
  • 判别器DY会尝试判断G(x)是否是真实的领域Y图片。
  • 同样地,当一张来自领域Y的图片y被输入到生成器F时,它会产生一张看起来像是属于领域X的图片F(y)。
  • 判别器DX会尝试判断F(y)是否是真实的领域X图片。

循环一致性

为了确保生成器G和F不仅能够成功地进行单向转换,而且还能保持原始图像的信息不丢失,CycleGAN引入了循环一致性的概念。

前向循环一致性

对于源域中的图像x,首先通过生成器G生成转换图像G(x),随后通过生成器F将G(x)转换回源域F(G(x))。循环一致性损失计算F(G(x))与原始图像x之间的差异。

反向循环一致性

对于目标域中的图像y,首先通过生成器F生成一个转换后图像F(y),然后通过生成器G将F(y)转换回目标域G(F(y))。计算G(F(y))与原始图像y之间的差异。

对抗性损失

生成器G和F需要生成足够真实的图片七篇对应的判别器DY和DX。

几个基本概念

假图像(Fake Image)

假图像是通过生成器网络将一个域的图像转换成另一个域的图像。例如,在人脸年龄变化的任务中,如果有一个年轻人的脸部图片(属于年轻域),生成器可以生成一张看起来更老的脸部图片(属于年老域)。这个新生成的老年脸部图片就是假图像。

在接下来的代码中,fake_a 是从域 B 的真实图像 img_b 通过生成器 net_rg_b 生成的假图像,而 fake_b 是从域 A 的真实图像 img_a 通过生成器 net_rg_a 生成的假图像。

重建图像(Reconstructed Image)

重建图像是指将假图像再次通过相应的生成器网络转换回原始域的过程。这样做是为了确保图像在跨域转换后仍然能够恢复其原始特征。

例如,如果 fake_b 是从 img_a 生成的,那么再用 net_rg_b 将 fake_b 转换回域 A 得到的图像 rec_a 应该尽可能地接近 img_a

这种循环一致性损失有助于保持图像内容的一致性,即使在跨域转换过程中也不会丢失重要信息。

在接下来的代码中,rec_a 是由 fake_b 通过 net_rg_b 重新转换得到的图像,而 rec_b 是由 fake_a 通过 net_rg_a 重新转换得到的图像。

身份映射图像(Identity Mapping Image)

身份映射图像是指将一个域的真实图像直接输入到对应域的生成器网络中,期望输出与输入相同或非常相似的图像。这用于训练生成器学习如何在不改变图像的情况下保持图像不变。

这种损失被称为身份损失,它鼓励生成器在不需要进行跨域转换时保持图像不变。

在接下来的代码中,identity_a 是将域 A 的真实图像 img_a 直接通过 net_rg_b 得到的输出,而 identity_b 是将域 B 的真实图像 img_b 直接通过 net_rg_a 得到的输出。

CyclyGAN损失函数

CycleGAN 的损失函数设计得比较复杂,旨在解决无监督图像到图像的转换问题。它的损失函数由主要两部分组成:对抗损失(Adversarial Loss)和循环一致性损失(Cycle Consistency Loss)。同时可以包括身份鉴别损失(Identity Mapping Loss)

对抗损失

对抗损失来源于生成对抗网络(GANs)的基本概念。它包括生成器(G)和判别器(D)两个部分。

生成器 G 尝试生成看起来像目标域 Y 的图像,而判别器 D 则试图区分真实的目标域 Y 图像与生成的假图像。

对于 CycleGAN 来说,有两个生成器 G:X→Y 和 F:Y→X,以及两个对应的判别器 DY和 DX。

对抗损失可以表示为:

同样地,对于另一个方向也有一个类似的损失: 

循环一致性损

循环一致性损失是为了保证从一个域转换到另一个域后,再转回原域时,图像应该尽可能接近原始输入。

这个损失鼓励 G(F(y))≈y 和 F(G(x))≈x。

循环一致性损失表示为:

身份鉴别损失

除了上述两种损失外,CycleGAN有时还会引入一种额外的损失来增强模型的表现,即身份映射损失。

这种损失鼓励生成器保留那些已经属于目标域的图像不变。如果将一个目标域的图像输入到对应的生成器中,输出应该和输入相同。

 身份鉴别损失表示为:

综合这些损失,CycleGAN的整体损失函数通常是这样构成的: 

L(G, F, D_X, D_Y) = L_{GAN}(G, D_Y, X, Y) + L_{GAN}(F, D_X, Y, X) + \lambda (L_{cyc}(G, F) + L_{cyc}(F, G)) + \lambda_{id} (L_{identity}(G, F, X, Y))

其中 λ 和 λid是超参数,用于平衡不同损失项的重要性。 

CycleGAN的应用

风格迁移:讲真实照片变为莫奈风格的艺术作品

物体转换:将马变成斑马、将苹果变成橘子

基于MindSpore的CycleGAN

数据集

# 数据集
'''
本案例使用的数据集里面的图片来源于ImageNet,该数据集共有17个数据包,本文只使用了其中的苹果橘子部分。
图像被统一缩放为256×256像素大小,
其中用于训练的苹果图片996张、橘子图片1020张,用于测试的苹果图片266张、橘子图片248张。对数据进行了随机裁剪、水平随机翻转和归一化的预处理,
为了将重点聚焦到模型,此处将数据预处理后的结果转换为 MindRecord 格式的数据,
以省略大部分数据预处理的代码。
'''
from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/CycleGAN_apple2orange.zip"download(url, ".", kind="zip", replace=True)# 数据集
'''
本案例使用的数据集里面的图片来源于ImageNet,该数据集共有17个数据包,本文只使用了其中的苹果橘子部分。
图像被统一缩放为256×256像素大小,
其中用于训练的苹果图片996张、橘子图片1020张,用于测试的苹果图片266张、橘子图片248张。对数据进行了随机裁剪、水平随机翻转和归一化的预处理,
为了将重点聚焦到模型,此处将数据预处理后的结果转换为 MindRecord 格式的数据,
以省略大部分数据预处理的代码。
'''
from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/CycleGAN_apple2orange.zip"download(url, ".", kind="zip", replace=True)# 数据集可视化
import numpy as np
import matplotlib.pyplot as pltmean = 0.5 * 255
std = 0.5 * 255plt.figure(figsize=(12, 5), dpi=60)
for i, data in enumerate(dataset.create_dict_iterator()):if i < 5:show_images_a = data["image_A"].asnumpy()show_images_b = data["image_B"].asnumpy()plt.subplot(2, 5, i+1)show_images_a = (show_images_a[0] * std + mean).astype(np.uint8).transpose((1, 2, 0))plt.imshow(show_images_a)plt.axis("off")plt.subplot(2, 5, i+6)show_images_b = (show_images_b[0] * std + mean).astype(np.uint8).transpose((1, 2, 0))plt.imshow(show_images_b)plt.axis("off")else:break
plt.show()

生成器的基本架构

构建生成器基本块

# 构建生成器
# 生成器采用ResNet模型结构
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Normal
import mindspore as ms
# 初始化权重的方法
weight_init = Normal(sigma=0.01)# 定义ConvNormReLU块
class ConvNormReLU(nn.Cell):def __init__(self, input_channel, out_planes, kernel_size=4, stride=2, alpha=0.2, norm_mode='instance',pad_mode='CONSTANT', use_relu=True, padding=None, transpose=False):super(ConvNormReLU, self).__init__()norm = nn.BatchNorm2d(out_planes)if norm_mode == 'instance':# 参数affine用于控制是否对归一化后的数据应用可学习的仿射变换(即缩放和平移)。# 当设置affine=False时,不会对归一化后的数据进行任何线性变换。norm = nn.BatchNorm2d(out_planes, affine=False)has_bias = (norm_mode == 'instance')if padding is None:padding = (kernel_size - 1) // 2if pad_mode == 'CONSTANT':# 如果需要转置卷积(上采样)构建转置卷积层if transpose:conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='same',has_bias=has_bias, weight_init=weight_init)else:# 无需转置卷积(下采样)conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',has_bias=has_bias, padding=padding, weight_init=weight_init)# 组合卷积层和正则化层layers = [conv, norm]else:# 创建了一个四元组列表,每个元组表示一个维度上的前后填充量。# (0, 0) 对应于批量大小和通道数维度,意味着在这两个维度上不做任何填充。# (padding, padding) 分别对应高度和宽度维度,在这两个维度上都会添加相同数量的填充。# 高度和宽度的两侧都会各增加1个像素的填充。paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))# nn.Pad类创建了一个填充层实例。# paddings 参数指定了具体的填充方式,按照上面定义的paddings变量。pad = nn.Pad(paddings=paddings, mode=pad_mode)if transpose:conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='pad',has_bias=has_bias, weight_init=weight_init)else:conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',has_bias=has_bias, weight_init=weight_init)layers = [pad, conv, norm]# 如果需要激活函数,并判断是哪种激活函数if use_relu:relu = nn.ReLU()if alpha > 0:relu = nn.LeakyReLU(alpha)layers.append(relu)# 组装模型self.features = nn.SequentialCell(layers)def construct(self, x):output = self.features(x)return output

 定义ResNet的残差块

# 定义ResNet的残差块
class ResidualBlock(nn.Cell):def __init__(self, dim, norm_mode='instance', dropout=False, pad_mode='CONSTANT'):super(ResidualBlock, self).__init__()self.conv1 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode)self.conv2 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode, use_relu=False)self.dropout = dropoutif dropout:self.dropout = nn.Dropout(p=0.5)def construct(self, x):out = self.conv1(x)if self.dropout:out = self.dropout(out)out = self.conv2(out)# 返回 x + out 的做法是实现残差学习的关键。这个设计是为了让网络能够更容易地学习到恒等映射(identity mapping)# 从而帮助解决深层网络训练中的梯度消失问题,并允许网络构建得更深而不会导致性能下降。return x + out

定义基于ResNet的生成器

# 定义基于ResNet的生成器
class ResNetGenerator(nn.Cell):def __init__(self, input_channel=3, output_channel=64, n_layers=9, alpha=0.2, norm_mode='instance', dropout=False,pad_mode="CONSTANT"):super(ResNetGenerator, self).__init__()# 数据集图像输入后经过的第一个网络self.conv_in = ConvNormReLU(input_channel, output_channel, 7, 1, alpha, norm_mode, pad_mode=pad_mode)# 随后对数据进行两次下采样self.down_1 = ConvNormReLU(output_channel, output_channel * 2, 3, 2, alpha, norm_mode)self.down_2 = ConvNormReLU(output_channel * 2, output_channel * 4, 3, 2, alpha, norm_mode)# 残差网络有9个残差块layers = [ResidualBlock(output_channel * 4, norm_mode, dropout=dropout, pad_mode=pad_mode)] * n_layers# 组装残差网络self.residuals = nn.SequentialCell(layers)# 再将图片进行上采样(转置卷积)self.up_1 = ConvNormReLU(output_channel * 2, output_channel, 3, 2, alpha, norm_mode, transpose=True)self.up_2 = ConvNormReLU(output_channel * 4, output_channel * 2, 3, 2, alpha, norm_mode, transpose=True)# 定义输出层if pad_mode == 'CONSTANT':self.conv_out = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad',padding=3, weight_init=weight_init)else:pad = nn.Pad(paddings=((0, 0), (0, 0), (3, 3), (3, 3)), mode=pad_mode)conv = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad', weight_init=weight_init)self.conv_out = nn.SequentialCell([pad, conv])def construct(self, x):x = self.conv_in(x)x = self.down_1(x)x = self.down_2(x)x = self.residuals(x)x = self.up_2(x)x = self.up_1(x)output = self.conv_out(x)# 将输出压制(-1, 1)return ops.tanh(output)# 实例化生成器
# 创建生成器G和F
net_rg_a = ResNetGenerator()
net_rg_a.update_parameters_name('net_rg_a.')net_rg_b = ResNetGenerator()
net_rg_b.update_parameters_name('net_rg_b.')

定义判别器

# 创建判别器
# 判别器其实是一个二分类网络模型,输出判定该图像为真实图的概率。
# 网络模型使用的是 Patch 大小为 70x70 的 PatchGANs 模型。
class Discriminator(nn.Cell):def __init__(self, input_channel=3, output_channel=64, n_layers=3, alpha=0.2, norm_mode='instance'):super(Discriminator, self).__init__()# 定义卷积核大小kernel_size = 4layers = [nn.Conv2d(input_channel, output_channel, kernel_size, 2, pad_mode='pad', padding=1, weight_init=weight_init),nn.LeakyReLU(alpha)]# 初始化倍增因子nf_mult = output_channel# 使用倍增因子逐步增大通道数for i in range(1, n_layers):nf_mult_prev = nf_multnf_mult = min(2 ** i, 8) * output_channellayers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))nf_mult_prev = nf_multnf_mult = min(2 ** n_layers, 8) * output_channellayers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))# 输出层layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1, weight_init=weight_init))# 组装模型self.features = nn.SequentialCell(layers)def construct(self, x):output = self.features(x)return output# 判别器初始化
# 初始化两个判别器
net_d_a = Discriminator()
net_d_a.update_parameters_name('net_d_a.')net_d_b = Discriminator()
net_d_b.update_parameters_name('net_d_b.')

定义优化器和损失函数

# 构建生成器,判别器优化器
optimizer_rg_a = nn.Adam(net_rg_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_rg_b = nn.Adam(net_rg_b.trainable_params(), learning_rate=0.0002, beta1=0.5)optimizer_d_a = nn.Adam(net_d_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_d_b = nn.Adam(net_d_b.trainable_params(), learning_rate=0.0002, beta1=0.5)
# 两个损失函数
loss_fn = nn.MSELoss(reduction='mean')
l1_loss = nn.L1Loss('mean')def gan_loss(predict, target):# 全一表示真实数据target = ops.ones_like(predict) * targetloss = loss_fn(predict, target)return loss

前向计算

# 前向计算def generator(img_a, img_b):# img_a 是来自域 A 的真实图像# img_b 是来自域 B 的真实图像# 使用网络 net_rg_b 将域 B 的图像 img_b 转换为域 A 的假图像 fake_afake_a = net_rg_b(img_b)# 使用网络 net_rg_a 将域 A 的图像 img_a 转换为域 B 的假图像 fake_bfake_b = net_rg_a(img_a)# 再次使用网络 net_rg_b 将生成的假图像 fake_b 重新转换回域 A 的重建图像 rec_arec_a = net_rg_b(fake_b)# 再次使用网络 net_rg_a 将生成的假图像 fake_a 重新转换回域 B 的重建图像 rec_brec_b = net_rg_a(fake_a)# 使用网络 net_rg_b 直接处理域 A 的图像 img_a,期望输出与输入相同或相似,这是为了保持同一性identity_a = net_rg_b(img_a)# 使用网络 net_rg_a 直接处理域 B 的图像 img_b,期望输出与输入相同或相似,这也是为了保持同一性identity_b = net_rg_a(img_b)# 返回生成的假图像、重建图像和身份映射图像# 用于计算循环一致性return fake_a, fake_b, rec_a, rec_b, identity_a, identity_b# 定义不同类型的损失权重
lambda_a = 10.0  # 循环一致性损失 A 到 B 的权重
lambda_b = 10.0  # 循环一致性损失 B 到 A 的权重
lambda_idt = 0.5  # 身份映射损失的权重def generator_forward(img_a, img_b):# 创建一个表示真实的标签 Tensortrue = Tensor(True, dtype.bool_)# 调用先前定义的 generator 函数来获取生成的图像及其重建版本fake_a, fake_b, rec_a, rec_b, identity_a, identity_b = generator(img_a, img_b)# 判别器损失loss_g_a = gan_loss(net_d_b(fake_b), true)loss_g_b = gan_loss(net_d_a(fake_a), true)# 循环一致性损失loss_c_a = l1_loss(rec_a, img_a) * lambda_aloss_c_b = l1_loss(rec_b, img_b) * lambda_b# 身份映射损失loss_idt_a = l1_loss(identity_a, img_a) * lambda_a * lambda_idtloss_idt_b = l1_loss(identity_b, img_b) * lambda_b * lambda_idt# 整合损失loss_g = loss_g_a + loss_g_b + loss_c_a + loss_c_b + loss_idt_a + loss_idt_b# 通过这种方式,生成器不仅学习如何欺骗判别器,还要保证图像经过跨域转换后能够准确地恢复原样(循环一致性),以及在不改变域的情况下尽可能保留原始图像(身份映射)。return fake_a, fake_b, loss_g, loss_g_a, loss_g_b, loss_c_a, loss_c_b, loss_idt_a, loss_idt_b
# 获取生成器的总损失
def generator_forward_grad(img_a, img_b):_, _, loss_g, _, _, _, _, _, _ = generator_forward(img_a, img_b)return loss_g# 这个函数同时处理来自域 A 和域 B 的图像,并计算两个判别器的总损失。
def discriminator_forward(img_a, img_b, fake_a, fake_b):# 假图像标签false = Tensor(False, dtype.bool_)# 真图像标签true = Tensor(True, dtype.bool_)# 判别器ad_fake_a = net_d_a(fake_a)d_img_a = net_d_a(img_a)# 判别器bd_fake_b = net_d_b(fake_b)d_img_b = net_d_b(img_b)# 计算判别器a的损失loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)# 计算判别器b的损失loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)# 加权计算总损失loss_d = (loss_d_a + loss_d_b) * 0.5return loss_d
# 只处理域 A 的图像,计算 net_d_a 判别器的损失。
def discriminator_forward_a(img_a, fake_a):false = Tensor(False, dtype.bool_)true = Tensor(True, dtype.bool_)d_fake_a = net_d_a(fake_a)d_img_a = net_d_a(img_a)loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)return loss_d_a
# 只处理域 B 的图像,计算 net_d_b 判别器的损失。
def discriminator_forward_b(img_b, fake_b):false = Tensor(False, dtype.bool_)true = Tensor(True, dtype.bool_)d_fake_b = net_d_b(fake_b)d_img_b = net_d_b(img_b)loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)return loss_d_b# 保留了一个图像缓冲区,用来存储之前创建的50个图像
'''
为了减少模型振荡,遵循 Shrivastava 等人的策略[,
使用生成器生成图像的历史数据而不是生成器生成的最新图像数据来更新鉴别器。
'''
pool_size = 50
def image_pool(images):num_imgs = 0image1 = []if isinstance(images, Tensor):images = images.asnumpy()return_images = []for image in images:if num_imgs < pool_size:num_imgs = num_imgs + 1image1.append(image)return_images.append(image)else:if random.uniform(0, 1) > 0.5:random_id = random.randint(0, pool_size - 1)tmp = image1[random_id].copy()image1[random_id] = imagereturn_images.append(tmp)else:return_images.append(image)output = Tensor(return_images, ms.float32)if output.ndim != 4:raise ValueError("img should be 4d, but get shape {}".format(output.shape))return output

梯度计算和反向传播

from mindspore import value_and_grad
# 梯度计算和反向传播
# 实例化求梯度的方法
# 生成器a梯度
grad_g_a = value_and_grad(generator_forward_grad, None, net_rg_a.trainable_params())
# 生成器b梯度
grad_g_b = value_and_grad(generator_forward_grad, None, net_rg_b.trainable_params())
# 判别器a梯度
grad_d_a = value_and_grad(discriminator_forward_a, None, net_d_a.trainable_params())
# 判别器d梯度
grad_d_b = value_and_grad(discriminator_forward_b, None, net_d_b.trainable_params())# 计算生成器的梯度,反向传播更新参数
def train_step_g(img_a, img_b):# 对于 net_d 网络中的所有参数,停止计算它们的梯度。net_d_a.set_grad(False)net_d_b.set_grad(False)fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib = generator_forward(img_a, img_b)_, grads_g_a = grad_g_a(img_a, img_b)_, grads_g_b = grad_g_b(img_a, img_b)optimizer_rg_a(grads_g_a)optimizer_rg_b(grads_g_b)return fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib# 计算判别器的梯度,反向传播更新参数
def train_step_d(img_a, img_b, fake_a, fake_b):net_d_a.set_grad(True)net_d_b.set_grad(True)loss_d_a, grads_d_a = grad_d_a(img_a, fake_a)loss_d_b, grads_d_b = grad_d_b(img_b, fake_b)loss_d = (loss_d_a + loss_d_b) * 0.5optimizer_d_a(grads_d_a)optimizer_d_b(grads_d_b)return loss_d

模型训练

import os  # 操作系统接口模块
import time  # 时间处理模块
import random  # 用于生成随机数
import numpy as np  # 数值计算库
from PIL import Image  # 图像处理库
from mindspore import Tensor, save_checkpoint  # MindSpore 库中的张量和保存检查点功能
from mindspore import dtype  # MindSpore 库中的数据类型定义# 由于时间原因,epochs设置为1,可根据需求进行调整
epochs = 1  # 训练轮次
save_step_num = 80  # 每隔多少步打印一次信息
save_checkpoint_epochs = 1  # 每隔多少个epoch保存一次模型
save_ckpt_dir = './train_ckpt_outputs/'  # 保存模型检查点的目录print('Start training!')  # 打印开始训练的信息for epoch in range(epochs):  # 对每个epoch进行迭代g_loss = []  # 初始化生成器损失列表d_loss = []  # 初始化判别器损失列表start_time_e = time.time()  # 记录当前epoch开始的时间for step, data in enumerate(dataset.create_dict_iterator()):  # 对数据集中的每一步进行迭代start_time_s = time.time()  # 记录当前步开始的时间img_a = data["image_A"]  # 从数据中获取域A的图像img_b = data["image_B"]  # 从数据中获取域B的图像res_g = train_step_g(img_a, img_b)  # 调用生成器的训练步骤并获取结果fake_a = res_g[0]  # 获取生成的假图像Afake_b = res_g[1]  # 获取生成的假图像B# 调用判别器的训练步骤,使用图像池来存储假图像,并传递给判别器res_d = train_step_d(img_a, img_b, image_pool(fake_a), image_pool(fake_b))loss_d = float(res_d.asnumpy())  # 将判别器的损失转换为浮点数step_time = time.time() - start_time_s  # 计算当前步的耗时# 将生成器的其他损失项转换为浮点数res = []for item in res_g[2:]:res.append(float(item.asnumpy()))g_loss.append(res[0])  # 添加总的生成器损失到列表d_loss.append(loss_d)  # 添加判别器损失到列表if step % save_step_num == 0:  # 如果是需要打印信息的步数print(f"Epoch:[{int(epoch + 1):>3d}/{int(epochs):>3d}], "  # 打印当前epoch/总epochf"step:[{int(step):>4d}/{int(datasize):>4d}], "  # 打印当前步/总步数f"time:{step_time:>3f}s,\n"  # 打印当前步耗时f"loss_g:{res[0]:.2f}, loss_d:{loss_d:.2f}, "  # 打印生成器和判别器的损失f"loss_g_a: {res[1]:.2f}, loss_g_b: {res[2]:.2f}, "  # 打印生成器A和B的GAN损失f"loss_c_a: {res[3]:.2f}, loss_c_b: {res[4]:.2f}, "  # 打印循环一致性损失f"loss_idt_a: {res[5]:.2f}, loss_idt_b: {res[6]:.2f}")  # 打印身份映射损失epoch_cost = time.time() - start_time_e  # 计算当前epoch的总耗时per_step_time = epoch_cost / datasize  # 计算每步的平均耗时mean_loss_d, mean_loss_g = sum(d_loss) / datasize, sum(g_loss) / datasize  # 计算平均损失# 打印当前epoch的平均损失和耗时print(f"Epoch:[{int(epoch + 1):>3d}/{int(epochs):>3d}], "f"epoch time:{epoch_cost:.2f}s, per step time:{per_step_time:.2f}, "f"mean_g_loss:{mean_loss_g:.2f}, mean_d_loss:{mean_loss_d :.2f}")if epoch % save_checkpoint_epochs == 0:  # 如果是需要保存检查点的epochos.makedirs(save_ckpt_dir, exist_ok=True)  # 确保保存目录存在# 保存生成器和判别器的模型检查点save_checkpoint(net_rg_a, os.path.join(save_ckpt_dir, f"g_a_{epoch}.ckpt"))save_checkpoint(net_rg_b, os.path.join(save_ckpt_dir, f"g_b_{epoch}.ckpt"))save_checkpoint(net_d_a, os.path.join(save_ckpt_dir, f"d_a_{epoch}.ckpt"))save_checkpoint(net_d_b, os.path.join(save_ckpt_dir, f"d_b_{epoch}.ckpt"))print('End of training!')  # 打印训练结束的信息

模型推理

import os
from PIL import Image
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
from mindspore import load_checkpoint, load_param_into_net# 加载权重文件
def load_ckpt(net, ckpt_dir):param_GA = load_checkpoint(ckpt_dir)load_param_into_net(net, param_GA)g_a_ckpt = './CycleGAN_apple2orange/ckpt/g_a.ckpt'
g_b_ckpt = './CycleGAN_apple2orange/ckpt/g_b.ckpt'load_ckpt(net_rg_a, g_a_ckpt)
load_ckpt(net_rg_b, g_b_ckpt)# 图片推理
fig = plt.figure(figsize=(11, 2.5), dpi=100)
def eval_data(dir_path, net, a):def read_img():for dir in os.listdir(dir_path):path = os.path.join(dir_path, dir)img = Image.open(path).convert('RGB')yield img, dirdataset = ds.GeneratorDataset(read_img, column_names=["image", "image_name"])trans = [vision.Resize((256, 256)), vision.Normalize(mean=[0.5 * 255] * 3, std=[0.5 * 255] * 3), vision.HWC2CHW()]dataset = dataset.map(operations=trans, input_columns=["image"])dataset = dataset.batch(1)for i, data in enumerate(dataset.create_dict_iterator()):img = data["image"]fake = net(img)fake = (fake[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((1, 2, 0))img = (img[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((1, 2, 0))fig.add_subplot(2, 8, i+1+a)plt.axis("off")plt.imshow(img.asnumpy())fig.add_subplot(2, 8, i+9+a)plt.axis("off")plt.imshow(fake.asnumpy())eval_data('./CycleGAN_apple2orange/predict/apple', net_rg_a, 0)
eval_data('./CycleGAN_apple2orange/predict/orange', net_rg_b, 4)
plt.show()

结果如下:

更多CycleGAN的内容可参考MindSpore官方的教学视频:

CycleGAN图像风格迁移转换_哔哩哔哩_bilibili


http://www.ppmy.cn/news/1536345.html

相关文章

ssm图书管理系统的设计与实现

系统包含&#xff1a;源码论文 所用技术&#xff1a;SpringBootVueSSMMybatisMysql 免费提供给大家参考或者学习&#xff0c;获取源码请私聊我 需要定制请私聊 目 录 摘 要 I Abstract II 第1章 绪论 1 1.1 课题研究背景 1 1.2课题研究现状 1 1.3课题实现目的和意义 …

基于SSM的坚果金融投资管理系统、坚果金融投资管理平台的设计与开发、智慧金融投资管理系统的设计与实现、坚果金融投资管理系统的设计与应用研究(源码+定制+开发)

博主介绍&#xff1a; ✌我是阿龙&#xff0c;一名专注于Java技术领域的程序员&#xff0c;全网拥有10W粉丝。作为CSDN特邀作者、博客专家、新星计划导师&#xff0c;我在计算机毕业设计开发方面积累了丰富的经验。同时&#xff0c;我也是掘金、华为云、阿里云、InfoQ等平台…

17 链表——21. 合并两个有序链表 ★

17 链表 21. 合并两个有序链表 将两个升序链表合并为一个新的升序链表并返回。新链表是通过拼接给定的两个链表的所有节点组成的。 示例 1: 输入:l1 = [1,2,4], l2 = [1,3,4] 输出:[1,1,2,3,4,4] 算法设计: 合并两个有序链表,并保持有序性,可以采用迭代法和递归法两种…

Hive优化操作(二)

Hive 数据倾斜优化 在使用 Hive 进行大数据处理时&#xff0c;数据倾斜是一个常见的问题。本文将详细介绍数据倾斜的概念、表现、常见场景及其解决方案。 1. 什么是数据倾斜&#xff1f; 数据倾斜是指由于数据分布不均匀&#xff0c;导致大量数据集中到某个节点或任务中&…

【Python】文件及目录

文章目录 概要一、文件对象的函数1.1 open()函数1.2 文件对象的函数1.3 with语句 二、基于os和os.path模块的目录操作三、基于Pandas的文件处理3.1 Pandas读写各种类型文件 其他章节的内容 概要 本文主要将了打开文件的函数open()的参数&#xff0c;以及文件对象的函数&#x…

[大语言模型-算法优化] 微调技术-LoRA算法原理及优化应用详解

[大语言模型-算法优化] 微调技术-LoRA算法原理及优化应用详解 前言: 古人云: 得卧龙者&#xff0c;得天下。 然在当今大语言模型流行的时代&#xff0c;同样有一句普世之言: 会微调技术者&#xff0c;得私域大模型部署之道&#xff01; 在众多微调技术中&#xff0c;LoRA (…

前端编程艺术(4)---JavaScript进阶(vue前置知识)

目录 1.变量和常量 2.模版字符串 3.对象 4.解构赋值 1.数组的解构 2.对象的解构 5.箭头函数 6.数组和对象的方法 7.扩展运算符 8.Web存储 9.Promise 10.AsyncAwait 11.模块化 1.变量和常量 JavaScript 中的变量和常量是用于存储数据的标识符。变量可以被重新赋值&am…

力扣977.有序数组的平方

题目链接&#xff1a;977. 有序数组的平方 - 力扣&#xff08;LeetCode&#xff09; 给你一个按 非递减顺序 排序的整数数组 nums&#xff0c;返回 每个数字的平方 组成的新数组&#xff0c;要求也按 非递减顺序 排序。 示例 1&#xff1a; 输入&#xff1a;nums [-4,-1,0,…