昇思学习打卡营第31天|深度解密 CycleGAN 图像风格迁移:从草图到线稿的无缝转化

news/2024/12/22 0:24:13/
1. 简介

        图像风格迁移是计算机视觉领域中的一个热门研究方向,其中 CycleGAN (循环对抗生成网络) 在无监督领域取得了显著的突破。与传统需要成对训练数据的模型如 Pix2Pix 不同,CycleGAN 不需要严格的成对数据,只需两类图片域数据,便可实现图像风格的迁移与互换。

        本篇博文将通过一个实际案例演示如何使用 CycleGAN 实现从草图到目标线稿图的图像风格迁移任务,并详细介绍 CycleGAN 的模型结构、数据处理及训练过程。

2. 模型介绍

        CycleGAN 的核心思想源自 "Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks" 论文。该模型在不需要成对示例的情况下,学习将源域 X 的图像转换到目标域 Y。其应用领域包括风格迁移、图像增强和域适应等任务。

2.1 CycleGAN 网络结构

        CycleGAN 由两个 GAN 模型组成,其对称的架构允许在不同的域之间来回转换图像。具体而言,CycleGAN 使用两个生成器(G 和 F)和两个判别器(D_X 和 D_Y),生成器负责将域 X 的图像转换到域 Y,并通过判别器对生成结果进行真假判断。

        模型架构如下:

  1. 生成器:生成器采用 ResNet 结构,由 9 个残差块组成,适合处理 256x256 尺寸的图片。
  2. 判别器:判别器通过 PatchGAN 模型检测图像的真实性,以保证生成的图像足够逼真。
2.2 循环一致性损失

        CycleGAN 通过 循环一致性损失 来保证从域 X 到域 Y,再从域 Y 转换回域 X 的图像应尽可能接近原始图像。这种损失机制确保模型不会丢失重要的图像特征。

3. 数据集

        本案例使用的数据集包含线稿图和草图图像,所有图片大小为 256x256 像素。数据集分为训练集和测试集,训练集包含 25654 张图片,测试集包含约 100 张线稿图片和 116 张草图图片。

4. 模型实现
4.1 生成器模型

        生成器模型基于 ResNet 结构,通过卷积、反卷积及残差块实现图像风格的转换。以下是生成器的代码实现:

import mindspore.nn as nnclass ResidualBlock(nn.Cell):def __init__(self, dim):super(ResidualBlock, self).__init__()self.conv_block = nn.SequentialCell(nn.Conv2d(dim, dim, kernel_size=3, padding=1, pad_mode="pad"),nn.BatchNorm2d(dim),nn.ReLU(),nn.Conv2d(dim, dim, kernel_size=3, padding=1, pad_mode="pad"),nn.BatchNorm2d(dim))def construct(self, x):return x + self.conv_block(x)class ResNetGenerator(nn.Cell):def __init__(self, input_nc, output_nc, n_residual_blocks=9):super(ResNetGenerator, self).__init__()model = [nn.Conv2d(input_nc, 64, kernel_size=7, padding=3, pad_mode="pad"),nn.BatchNorm2d(64),nn.ReLU()]# Downsamplingmodel += [nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),nn.BatchNorm2d(128),nn.ReLU(),nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),nn.BatchNorm2d(256),nn.ReLU()]# Residual blocksfor _ in range(n_residual_blocks):model += [ResidualBlock(256)]# Upsamplingmodel += [nn.Conv2dTranspose(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),nn.BatchNorm2d(128),nn.ReLU(),nn.Conv2dTranspose(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),nn.BatchNorm2d(64),nn.ReLU()]model += [nn.Conv2d(64, output_nc, kernel_size=7, padding=3, pad_mode="pad"),nn.Tanh()]self.model = nn.SequentialCell(model)def construct(self, x):return self.model(x)
4.2 判别器模型

        判别器基于 PatchGAN 的结构,通过卷积网络将输入图片划分为多个小的 patch,并分别进行真假判别。

class Discriminator(nn.Cell):def __init__(self, input_nc, ndf=64):super(Discriminator, self).__init__()self.model = nn.SequentialCell([nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1),nn.LeakyReLU(0.2),nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(ndf * 2),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(ndf * 4),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=1, padding=1),nn.BatchNorm2d(ndf * 8),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=1)])def construct(self, x):return self.model(x)
4.3 优化器与损失函数

        CycleGAN 采用对抗性损失和循环一致性损失的组合来训练生成器和判别器。优化器选择了 Adam 优化器,学习率设置为 0.0002。

import mindspore as ms# 定义损失函数和优化器
gan_loss = nn.BCELoss()
cycle_loss = nn.L1Loss()optimizer_G = nn.Adam(generator.parameters(), learning_rate=0.0002)
optimizer_D = nn.Adam(discriminator.parameters(), learning_rate=0.0002)
5. 训练与推理

        训练过程中,我们交替训练生成器和判别器。判别器通过真假样本的判别进行训练,而生成器则通过对抗判别和循环一致性进行优化。以下是一个训练步骤的实现:

def train_step(real_A, real_B):# 生成器前向计算fake_B = generator_A2B(real_A)fake_A = generator_B2A(real_B)# 判别器前向计算D_A_loss = gan_loss(discriminator_A(fake_A), Tensor(0)) + gan_loss(discriminator_A(real_A), Tensor(1))D_B_loss = gan_loss(discriminator_B(fake_B), Tensor(0)) + gan_loss(discriminator_B(real_B), Tensor(1))# 生成器损失计算cycle_A_loss = cycle_loss(generator_B2A(fake_B), real_A)cycle_B_loss = cycle_loss(generator_A2B(fake_A), real_B)G_loss = cycle_A_loss + cycle_B_loss + D_A_loss + D_B_lossoptimizer_G.step()optimizer_D.step()return G_loss, D_A_loss, D_B_loss

结语

        通过本次的CycleGAN模型实践,我们深入理解了图像风格迁移的基本原理,特别是在无监督情况下如何实现两个域之间的图像转换。CycleGAN的循环一致性损失在保持图像内容一致性的同时,又能实现风格的转换,这是其在域迁移任务中广泛应用的重要原因。在整个实现过程中,不仅对生成器和判别器的构建有了更清晰的理解,同时也进一步熟悉了损失函数的优化策略。

        这次实验的关键在于让模型具备在没有配对数据的情况下,也能够进行风格转换的能力。虽然实验需要较大的计算资源,但我们通过小规模数据集也能够体验到CycleGAN的强大之处。希望通过这个项目,我们不仅能掌握CycleGAN的基本原理,也能为以后的图像生成和风格迁移任务打下坚实的基础。

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!


http://www.ppmy.cn/news/1535897.html

相关文章

4.循环结构在存储过程中的应用(4/10)

引言 在数据库管理中,存储过程是一种强大的工具,它允许将一组SQL语句封装为一个独立的、可重用的单元。存储过程不仅可以提高数据处理的效率,还可以增强代码的安全性和可维护性。在复杂的数据库操作中,循环结构扮演着至关重要的角…

软考UML图 -- ( 类图,对象图,用例图,序列图,通信图,状态图,活动图,构件图,部署图)

文章目录 一、UML统一建模语言二、关系三、UML图1. 类图2. 对象图3. 用例图4. 序列图(顺序图)—— 交互图5. 通信图 —— 交互图6. 状态图7. 活动图8. 构件图(组件图)9. 部署图10. 总结 一、UML统一建模语言 UML由3个要素构成:UM…

JS高频手写题,看看你会多少

JS手写题 1. 手写 apply 方法 Function.prototype.myApply function (thisArg, argsArray) {// 设置调用函数的上下文对象,如果没有传入上下文,则使用全局对象thisArg thisArg || globalThis;// 创建一个唯一的 Symbol 以避免覆盖原有属性const fn …

Python项目文档生成常用工具对比

写在前面: 通过阅读本片文章,你将了解:主流的Python项目文档生成工具(Sphinx,MkDocs,pydoc,Pdoc)简介及对比,本文档不涉及相关工具的使用。 概述 近期,由于…

永洪科技第八届全国用户大会,释放数据价值!

永洪科技,作为“致力于打造全球领先的数据技术厂商”,将于【2024年11月1日】,在【北京东方君悦大酒店】盛大召开“第八届永洪科技全国用户大会”。旨在通过AIBI的深入融合,更加智能且精准的展现及预测未来的数据走向,展…

Robot Operating System——列有序的位姿(poses)

大纲 应用场景1. 机器人导航场景描述具体应用 2. 环境建模场景描述具体应用 3. 多机器人协作场景描述具体应用 4. 仿真环境场景描述具体应用 5. 传感器数据处理场景描述具体应用 定义字段解释 案例 nav_msgs::msg::Path 详细介绍 nav_msgs::msg::Path 是 ROS 2 中的一个消息类型…

Python和R及Julia妊娠相关疾病生物剖析算法

🎯要点 算法使用了矢量投影、现代优化线性代数、空间分区技术和大数据编程利用相应向量空间中标量积和欧几里得距离的紧密关系来计算使用妊娠相关疾病(先兆子痫)、健康妊娠和癌症测试算法模型使用相关性投影利用相关性和欧几里得距离之间的关…

Studying-多线程学习Part4 - 异步并发——async future、packaged_task、promise

异步并发——async future packaged_task promise 1.async、future 是C11引入的一个函数模版,用于异步执行一个函数,并返回一个future对象,表示异步操作的结果。使用 async 可以方便地进行异步编程,避免了手动创建线程和管理线程…