李沐49_样式迁移——自学笔记

样式迁移

将样式图片中的样式迁移到内容图片上,合成图片,例如将照片转换成漫画形式或者是油画风。

基于CNN的样式迁移

读取图片和样式风格

%matplotlib inline
import torch
import torchvision
from torch import nn
from d2l import torch as d2ld2l.set_figsize()
content_img = d2l.Image.open('rainier.jpg')
d2l.plt.imshow(content_img);

在这里插入图片描述

style_img = d2l.Image.open('autumn-oak.jpg')
d2l.plt.imshow(style_img);

在这里插入图片描述

预处理和后处理

预处理函数preprocess对输入图像在RGB三个通道分别做标准化,并将结果变换成卷积神经网络接受的输入格式。

后处理函数postprocess则将输出图像中的像素值还原回标准化之前的值。 由于图像打印函数要求每个像素的浮点数值在0~1之间,我们对小于0和大于1的值分别取0和1。

rgb_mean = torch.tensor([0.485, 0.456, 0.406])
rgb_std = torch.tensor([0.229, 0.224, 0.225])def preprocess(img, image_shape):transforms = torchvision.transforms.Compose([torchvision.transforms.Resize(image_shape),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)])return transforms(img).unsqueeze(0)def postprocess(img):img = img[0].to(rgb_std.device)img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1)return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))

抽取图像特征

基于ImageNet数据集预训练的VGG-19模型来抽取图像特征

pretrained_net = torchvision.models.vgg19(pretrained=True)
/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.warnings.warn(
/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG19_Weights.IMAGENET1K_V1`. You can also use `weights=VGG19_Weights.DEFAULT` to get the most up-to-date weights.warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:03<00:00, 155MB/s]

从VGG中选择不同层的输出来匹配局部和全局的风格,这些图层也称为风格层。VGG网络使用了5个卷积块。

实验中,我们选择第四卷积块的最后一个卷积层作为内容层,选择每个卷积块的第一个卷积层作为风格层。 这些层的索引可以通过打印pretrained_net实例获取。

style_layers, content_layers = [0, 5, 10, 19, 28], [25]

使用VGG层抽取特征时,我们只需要用到从输入层到最靠近输出层的内容层或风格层之间的所有层。 下面构建一个新的网络net,它只保留需要用到的VGG的所有层。

net = nn.Sequential(*[pretrained_net.features[i] for i inrange(max(content_layers + style_layers) + 1)])

给定输入X,如果我们简单地调用前向传播net(X),只能获得最后一层的输出。 由于我们还需要中间层的输出,因此这里我们逐层计算,并保留内容层和风格层的输出。

def extract_features(X, content_layers, style_layers):contents = []styles = []for i in range(len(net)):X = net[i](X)if i in style_layers:styles.append(X)if i in content_layers:contents.append(X)return contents, styles

get_contents函数对内容图像抽取内容特征;

get_styles函数对风格图像抽取风格特征。

因为在训练时无须改变预训练的VGG的模型参数,所以我们可以在训练开始之前就提取出内容特征和风格特征。由于合成图像是风格迁移所需迭代的模型参数,我们只能在训练过程中通过调用extract_features函数来抽取合成图像的内容特征和风格特征。

def get_contents(image_shape, device):content_X = preprocess(content_img, image_shape).to(device)contents_Y, _ = extract_features(content_X, content_layers, style_layers)return content_X, contents_Ydef get_styles(image_shape, device):style_X = preprocess(style_img, image_shape).to(device)_, styles_Y = extract_features(style_X, content_layers, style_layers)return style_X, styles_Y

损失函数

内容损失、风格损失和全变分损失3部分组成

1.内容损失

def content_loss(Y_hat, Y):# 我们从动态计算梯度的树中分离目标:# 这是一个规定的值,而不是一个变量。return torch.square(Y_hat - Y.detach()).mean()

2.风格损失

def gram(X):num_channels, n = X.shape[1], X.numel() // X.shape[1]X = X.reshape((num_channels, n))return torch.matmul(X, X.T) / (num_channels * n)
def style_loss(Y_hat, gram_Y):return torch.square(gram(Y_hat) - gram_Y.detach()).mean()

3.全变分损失

def tv_loss(Y_hat):return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())

损失函数

风格转移的损失函数是内容损失、风格损失和总变化损失的加权和。

content_weight, style_weight, tv_weight = 1, 1e3, 10def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):# 分别计算内容损失、风格损失和全变分损失contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(contents_Y_hat, contents_Y)]styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(styles_Y_hat, styles_Y_gram)]tv_l = tv_loss(X) * tv_weight# 对所有损失求和l = sum(10 * styles_l + contents_l + [tv_l])return contents_l, styles_l, tv_l, l

初始化合成图像

定义一个简单的模型SynthesizedImage,并将合成的图像视为模型参数。模型的前向传播只需返回模型参数即可

class SynthesizedImage(nn.Module):def __init__(self, img_shape, **kwargs):super(SynthesizedImage, self).__init__(**kwargs)self.weight = nn.Parameter(torch.rand(*img_shape))def forward(self):return self.weight

定义get_inits函数。该函数创建了合成图像的模型实例,并将其初始化为图像X。风格图像在各个风格层的格拉姆矩阵styles_Y_gram将在训练前预先计算好。

def get_inits(X, device, lr, styles_Y):gen_img = SynthesizedImage(X.shape).to(device)gen_img.weight.data.copy_(X.data)trainer = torch.optim.Adam(gen_img.parameters(), lr=lr)styles_Y_gram = [gram(Y) for Y in styles_Y]return gen_img(), styles_Y_gram, trainer

训练模型

def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)animator = d2l.Animator(xlabel='epoch', ylabel='loss',xlim=[10, num_epochs],legend=['content', 'style', 'TV'],ncols=2, figsize=(7, 2.5))for epoch in range(num_epochs):trainer.zero_grad()contents_Y_hat, styles_Y_hat = extract_features(X, content_layers, style_layers)contents_l, styles_l, tv_l, l = compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)l.backward()trainer.step()scheduler.step()if (epoch + 1) % 10 == 0:animator.axes[1].imshow(postprocess(X))animator.add(epoch + 1, [float(sum(contents_l)),float(sum(styles_l)), float(tv_l)])return X

我们训练模型: 首先将内容图像和风格图像的高和宽分别调整为300和450像素,用内容图像来初始化合成图像。

device, image_shape = d2l.try_gpu(), (300, 450)
net = net.to(device)
content_X, contents_Y = get_contents(image_shape, device)
_, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)

在这里插入图片描述


http://www.ppmy.cn/ops/2362.html

相关文章

【mac】【python】新建项目虚拟环境后,使用命令pip出现错误:zsh: command not found: pip

【mac】【python】新建项目虚拟环境后&#xff0c;使用命令pip出现错误&#xff1a;zsh: command not found: pip 问题描述&#xff1a; 拉取或者创建新的python项目时&#xff0c;为项目添加了新的解释器&#xff0c;创建啦虚拟环境&#xff0c;但是执行pip命令的时候找不到命…

python之flask安装以及使用

1 flask介绍 Flask是一个非常小的Python Web框架&#xff0c;被称为微型框架&#xff1b;只提供了一个稳健的核心&#xff0c;其他功能全部是通过扩展实现的&#xff1b;意思就是我们可以根据项目的需要量身定制&#xff0c;也意味着我们需要学习各种扩展库的使用。 2 python…

计算机组成原理【CO】Ch5 中央处理器

目录 大纲 一条指令的执行 取指令 执行指令 数据传送类&#xff08;mov、load、store&#xff09; 运算类指令&#xff08;加、减、乘、除、移位、与、或&#xff09; 转移类指令&#xff08;jmp、jxxx&#xff09; 如何看懂注释 袁版注释⻛格&#xff08;16年以后的真题&…

前端设置图片以中心向周围展开呈现出来,不压缩宽度

在一个限制了长宽的div设置背景图片或者添加一张图片时&#xff0c;往往会遇到图片被压缩&#xff0c;所以这个时候选择以图片中心出发裁剪出来呈现是一个不错的选择。 现有一个img标签&#xff0c;网页端的样式是图片是完全呈现出来的&#xff0c;手机端的样式会被进行压缩。 …

Uniapp+基于百度智能云完成AI视觉功能(附前端思路)

本博客使用uniapp百度智能云图像大模型中的AI视觉API&#xff08;本文以物体检测为例&#xff09;完成了一个简单的图像识别页面&#xff0c;调用百度智能云API可以实现快速训练模型并且部署的效果。 uniapp百度智能云AI视觉页面实现 先上效果图实现过程百度智能云Easy DL训练图…

Ubuntu22.04.4 - 安装后使用笔记目录-VMware

安装的话就傻瓜式盲点&#xff0c;根据自己需求进行处理&#xff0c;我是在ssh的地方勾选了一下选项&#xff0c;其他都是默认项&#xff0c;官网上有文档&#xff0c;就不赘述了 一、登录用户管理 二、系统命令 三、vim 四、网络配置 五、apt 六、SSH 七、MySQL8

SpringBoot中全局异常捕获与参数校验的优雅实现

一&#xff0c;为什么要用全局异常处理&#xff1f; 在日常开发中&#xff0c;为了不抛出异常堆栈信息给前端页面&#xff0c;每次编写Controller层代码都要尽可能的catch住所有service层、dao层等异常&#xff0c;代码耦合性较高&#xff0c;且不美观&#xff0c;不利于后期维…

力扣经典150题第三十题:长度最小的子数组

目录 力扣经典150题解析之三十&#xff1a;长度最小的子数组1. 介绍2. 问题描述3. 示例4. 解题思路方法一&#xff1a;滑动窗口 5. 算法实现6. 复杂度分析7. 测试与验证测试用例设计测试结果分析 8. 进阶9. 总结10. 参考文献感谢阅读 力扣经典150题解析之三十&#xff1a;长度最…

探索异常传播:深入剖析Python中的错误处理机制

文章目录 1. 异常传播的基本原理2. 复杂的异常传播场景3. 再次抛出异常的意义是什么&#xff1f;4. 最佳实践与异常处理策略 理解异常传播&#xff08;也称为异常冒泡&#xff09;的过程是至关重要的。这一机制确保当在程序执行中发生错误时&#xff0c;错误能被有效地捕获和处…

【蓝桥杯2025备赛】素数判断:从O(n^2)到O(n)学习之路

素数判断:从O( n 2 n^2 n2)到O(n)学习之路 背景:每一个初学计算机的人肯定避免不了碰到素数&#xff0c;素数是什么&#xff0c;怎么判断&#xff1f; 素数的概念不难理解:素数即质数&#xff0c;指的是在大于1的自然数中&#xff0c;除了1和它本身不再有其他因数的自然数。 …

Ubuntu Server 20.04 LTS 64bit安装ftp服务

1.安装vsftpd sudo apt install vsftpd2.配置vsftpd sudo vim /etc/vsftpd.conf write_enableYES # 启用任何形式的FTP写入命令&#xff0c;即可以修改文件local_umask022 # 本地用户创建文件的 umask 值&#xff0c;默认是被注释的connect_from_port_20YES # 针对 PORT 类型…

排序算法之选择排序

目录 一、简介二、代码实现三、应用场景 一、简介 算法平均时间复杂度最好时间复杂度最坏时间复杂度空间复杂度排序方式稳定性选择排序O(n^2 )O(n^2)O(n^2)O(1)In-place不稳定 稳定&#xff1a;如果A原本在B前面&#xff0c;而AB&#xff0c;排序之后A仍然在B的前面&#xff1…

MySQL表结构的操作

文章目录 1. 创建表2. 查看表3. 修改表4. 删除表 1. 创建表 create table table_name (field1 datatype,field2 datatype,field3 datatype )character set 字符集 collate 校验集 engine 存储引擎;field&#xff1a;列名datatype&#xff1a;列的类型character set&#xff1a…

OpenHarmony南向之编译构建框架

title: OpenHarmony南向之编译构建框架 categories:- OpenHarmony tags:- Ninja- GN author: shell cover: /img/oh/build_framework_ZN.png date: 2024-01-03 21:05:38概述 OpenHarmony编译子系统是以GN和Ninja构建为基座&#xff0c;对构建和配置粒度进行部件化抽象、对内建…

C++格式化输出开源库fmt入手教程

fmt项目快速上手指南 1. cmake环境配置 include(FetchContent) FetchContent_Declare(fmtGIT_REPOSITORY https://github.com/fmtlib/fmtGIT_TAG 10.0.0GIT_SHALLOW TRUE) # 1. 下载fmt库 FetchContent_MakeAvailable(fmt)add_executable(fmt_guide main.cpp) # 2. 链接fmt库…

性能优化工具

CPU 优化的各类工具 network netperf 服务端&#xff1a; $ netserver Starting netserver with host IN(6)ADDR_ANY port 12865 and family AF_UNSPEC$ cat netperf.sh #!/bin/bash count$1 for ((i1;i<count;i)) doecho "Instance:$i-------"# 下方命令可以…

C#创建背景色渐变窗体的方法:创建特殊窗体

目录 1.让背景渐变色的理论基础 2.让背景渐变色的方法 3.一个实施例 &#xff08;1&#xff09;Form1.Designer.cs &#xff08;2&#xff09;Form1.cs &#xff08;3&#xff09;渐变的蓝色背景 在窗体设计时&#xff0c;可以通过设置窗体的BackColor属性来改变窗口的背…

Linux系统编程---线程同步

一、同步概念 同步即协同步调&#xff0c;按预定的先后次序运行。 协同步调&#xff0c;对公共区域数据【按序】访问&#xff0c;防止数据混乱&#xff0c;产生与时间有关的错误。 数据混乱的原因&#xff1a; 资源共享(独享资源则不会)调度随机(意味着数据访问会出现竞争)线…

Android USB TP方向修改

搜集的一些关于Android USB TP的方向修改的代码&#xff0c;X to Y , X反转 &#xff0c; Y反转&#xff0c;双触屏配置&#xff0c;双屏异触等。 diff --git a/kernel/drivers/hid/hid-multitouch.c b/kernel/drivers/hid/hid-multitouch.c old mode 100644new mode 100755 i…

Docker安装部署Jenkins并发布NetCore应用

Docker安装Jenkins # 拉取镜像 docker pull jenkins/jenkins # 查看镜像 docker images # 运行jenkins # 8080端口为jenkins Web 界面的默认端口 13152是映射到外部 &#xff1a;前面的是映射外部 # 50000端口为jenkins 的默认代理节点&#xff08;Agent&#xff09;通信端口…