使用PyTorch进行图像风格迁移:基于VGG19实现

ops/2024/10/21 3:18:00/

图像风格迁移(Neural Style Transfer, NST)是深度学习中一个令人着迷的应用,它能够将一张图像的风格应用到另一张图像上。例如,能够将梵高的画风应用到一张普通照片上。本文将详细解释如何使用PyTorch进行风格迁移,逐步分析代码,并讲解其中的关键技术。

1. 环境准备

在开始之前,确保安装了必要的库:

pip install torch torchvision pillow

2. 模型缓存目录设置

为了加速模型的加载,我们可以通过设置环境变量TORCH_HOME来指定模型缓存目录,避免每次运行代码时重新下载模型:

python">os.environ['TORCH_HOME'] = './model_directory'  # 你可以根据需要自定义目录

3. 加载图像

加载图像并进行预处理是风格迁移中的重要步骤。我们需要将图像转换为张量并进行归一化处理,以便与预训练的VGG19模型匹配:

python">def load_image(image_path, max_size=400):image = Image.open(image_path).convert('RGB')size = min(max_size, max(image.size))transform = transforms.Compose([transforms.Resize((size, size)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])return transform(image).unsqueeze(0)

在这里,我们将图像调整为不大于400像素的正方形,并将其转换为适合VGG19模型输入的格式。

4. VGG19模型的特征提取

风格迁移的核心思想是将内容图像的高层次特征与风格图像的低层次特征结合。我们使用VGG19模型的前21层来提取图像的特征:

python">class VGG(nn.Module):def __init__(self):super(VGG, self).__init__()self.features = vgg19(pretrained=True).features[:21].eval()def forward(self, x):features = []for i, layer in enumerate(self.features):x = layer(x)if i in {0, 5, 10, 19, 21}:features.append(x)return features

5. 内容与风格损失

内容损失衡量生成图像与内容图像的特征差异,而风格损失则是基于Gram矩阵来衡量生成图像与风格图像的差异。

  • 内容损失:
python">class ContentLoss(nn.Module):def __init__(self, target):super(ContentLoss, self).__init__()self.target = target.detach()def forward(self, input):return nn.functional.mse_loss(input, self.target)
  • 风格损失:
python">class StyleLoss(nn.Module):def __init__(self, target):super(StyleLoss, self).__init__()self.target = self.gram_matrix(target).detach()def gram_matrix(self, input):batch_size, channels, height, width = input.size()features = input.view(batch_size * channels, height * width)G = torch.mm(features, features.t())return G.div(batch_size * channels * height * width)def forward(self, input):G = self.gram_matrix(input)return nn.functional.mse_loss(G, self.target)

6. 图像风格迁移算法

核心算法将内容图像初始化为输入图像,并通过多次迭代优化,使其逐步接近目标风格图像,同时保持内容的完整性。我们使用LBFGS优化器来实现这一过程:

python">def style_transfer(content_img, style_img, num_steps=1000, style_weight=1e9, content_weight=1):device = torch.device("cuda" if torch.cuda.is_available() else "cpu")content_img = content_img.to(device)style_img = style_img.to(device)model = VGG().to(device)style_features = model(style_img)content_features = model(content_img)input_img = content_img.clone().requires_grad_(True).to(device)optimizer = optim.LBFGS([input_img])style_losses = []content_losses = []for sf, cf in zip(style_features, content_features):content_losses.append(ContentLoss(cf))style_losses.append(StyleLoss(sf))run = [0]while run[0] <= num_steps:def closure():optimizer.zero_grad()input_features = model(input_img)content_loss = 0style_loss = 0for cl, input_f in zip(content_losses, input_features):content_loss += content_weight * cl(input_f)for sl, input_f in zip(style_losses, input_features):style_loss += style_weight * sl(input_f)loss = content_loss + style_lossloss.backward()run[0] += 1if run[0] % 50 == 0:print(f'Step {run[0]}, Content Loss: {content_loss.item():4f}, Style Loss: {style_loss.item():4f}')return lossoptimizer.step(closure)return input_img

7. 结果保存

生成的图像需要去除归一化并保存为常规图片格式:

python">def save_image(tensor, path):image = tensor.clone().detach()image = image.squeeze(0)image = transforms.ToPILImage()(image)image.save(path)

8. 主函数执行

整个过程可以通过主函数来执行,加载图像、进行风格迁移并保存结果:

python">if __name__ == '__main__':content_image_path = 'content_image.png'style_image_path = 'style_image.png'output_image_path = 'output_image.jpg'content_img = load_image(content_image_path)style_img = load_image(style_image_path)result = style_transfer(content_img, style_img)save_image(result, output_image_path)print(f"风格迁移完成,图像已保存为 {output_image_path}")

总结

本文展示了如何使用PyTorch和VGG19模型实现图像风格迁移。通过合理设置内容和风格损失的权重,我们可以生成既保留内容图像结构又具有风格图像艺术风格的全新图像。

完整代码

github:https://github.com/Yolumia/Image_style_transfer_base_vgg19/


http://www.ppmy.cn/ops/111020.html

相关文章

黑链、黑帽、明链分别是什么意思

一、黑链 • 定义&#xff1a;黑链是指通过非法手段&#xff08;如黑客入侵等&#xff09;获取的隐藏链接。这些链接通常被隐藏在网站页面中&#xff0c;普通用户在浏览网页时难以察觉&#xff0c;但搜索引擎可以抓取到。 • 危害&#xff1a;黑链的存在会影响搜索引擎的公正…

从OracleCloudWorld和财报看Oracle的转变

2024年9月9-12日Oracle Cloud World在美国拉斯维加斯盛大开幕 押注AI和云 Oracle 创始人Larry Ellison做了对Oracle战略和未来愿景的主旨演讲&#xff0c;在演讲中Larry将AI技术和云战略推到了前所未有的高度&#xff0c;从新的Oracle 23c改名到Oracle23ai&#xff0c;到Oracl…

SQL server 日常运维命令

一、基础命令 查看当前数据库的版本 SELECT VERSION;查看服务器部分特殊信息 select SERVERPROPERTY(Nedition) as Edition --数据版本&#xff0c;如企业版、开发版等,SERVERPROPERTY(Ncollation) as Collation --数据库字符集,SERVERPROPERTY(Nservername) as Serve…

C语言-数据结构 有向图拓扑排序TopologicalSort(邻接表存储)

拓扑排序算法的实现还是比较简单的&#xff0c;我们需要用到一个顺序栈辅助&#xff0c;采用邻接表进行存储&#xff0c;顶点结点存储入度、顶点信息、指向邻接结点的指针&#xff0c;算法过程是&#xff1a;我们先将入度为0的顶点入栈&#xff0c;然后弹出栈顶结点&#xff0c…

Android 蓝牙服务启动

蓝牙是Android设备中非常常见的一个feature&#xff0c;设备厂家可以用BT来做RC、连接音箱、设备本身做Sink等常见功能。如果一些设备不需要BT功能&#xff0c;Android也可以通过配置来disable此模块&#xff0c;方便厂家为自己的设备做客制化。APP操作设备的蓝牙功能&#xff…

Python | Leetcode Python题解之第400题第N位数字

题目&#xff1a; 题解&#xff1a; class Solution:def findNthDigit(self, n: int) -> int:d, count 1, 9while n > d * count:n - d * countd 1count * 10index n - 1start 10 ** (d - 1)num start index // ddigitIndex index % dreturn num // 10 ** (d - d…

php 实现JWT

在 PHP 中&#xff0c;JSON Web Token (JWT) 是一种开放标准 (RFC 7519) 用于在各方之间作为 JSON 对象安全地传输信息。JWT 通常用于身份验证系统&#xff0c;如 OAuth2 或基于令牌的身份验证。 以下是一个基本的 PHP 实现 JWT 生成和验证的代码示例。 JWT 的组成部分 JWT …

【论文阅读】视觉分割新SOTA: Segment Anything(SAM)

导言 随着基于对比文本—图像对的预训练&#xff08;CLIP&#xff09;方法或者模型、聊天生成预训练转换器&#xff08;ChatGPT&#xff09;、生成预训练转换器-4&#xff08;GPT-4&#xff09;等基础大模型的出现&#xff0c;通用人工智能&#xff08; AGI&#xff09;的研究…