MVANet——小范围内捕捉高分辨率细节而在大范围内不损失精度的强大的背景消除模型

devtools/2025/2/3 23:17:18/

一、概述

前景提取(背景去除)是现代计算机视觉的关键挑战之一,在各种应用中的重要性与日俱增。在图像编辑和视频制作中有效地去除背景不仅能提高美学价值,还能提高工作流程的效率。在要求精确度的领域,如医学图像分析和自动驾驶技术中的物体识别,背景去除也发挥着重要作用。主要的挑战是在高分辨率图像中捕捉小区域的精细细节,同时保持大区域的精确度。迄今为止,还没有一种方法能将细节再现与整体精度相结合。然而,一种名为 MVANet 的新方法为这一挑战提供了创新的解决方案。

MVANet 采用的独特方法受到人类视觉的启发。正如人类从多个角度观察物体一样,MVANet 也从多个角度分析物体。这种方法可以在不丢失细节的情况下提高整体精度。此外,多视角的整合还可实现远距离视觉交互,这是传统方法难以实现的。

市场营销、娱乐、医疗保健和安全等各行各业对背景消除技术的需求与日俱增。在网上购物中,它可使产品的前景更加突出,从而提高购买意愿。它对于使用虚拟背景的视频会议应用以及视频制作中绿屏的替代技术也很重要。随着所有这些应用成为焦点,前景提取性能的提高将对整个行业产生重大影响。

这种新方法已经证明了它的有效性。特别是在 DIS-5K 数据集上,它在精度和速度上都优于目前的 SOTA;MVANet 有潜力成为前景提取任务的新标准,并有望在未来获得更广泛的应用。

二、算法架构

图 1:MVANet 概述。

MVANet 的整体结构与 UNet 类似,如图 1 所示。编码器使用一个远景(G)和一个近景(Lm)作为输入,远景和近景由 M(本文中为 M=4)不重叠的局部斑块组成。

G 和 Lm 构成一个多视角补丁序列,分批输入特征提取器,生成多级特征图 Ei(i=1,2,3,4,5)。每个 Ei 包含远景和近景的表示。最高级别的特征图 E5 沿批次维度被分成两组不同的全局和局部特征,并被输入多视图完成定位模块(MCLM,图 2-a)。2-a),并将其输入 MCLM(MCLM,图 2-a)。

该解码器类似于 FPN(Lin et.al, 2017)架构,但在每个解码阶段都插入了一个即时多视图完成细化模块(MCRM,图 2-b)。每个阶段的输出用于重建 SDO 地图(只有前景的地图)和计算损失。图 1 的右下方显示了多视角整合。局部特征合并后输入到 Conv Head,以便与全局特征进行细化和串联。

图 2:MCLM 和 MCRM 架构。

学习的损失函数

如图 1 所示,解码器每一层的输出和最终预测都加入了监督。

具体来说,前者由三个部分组成:ll、lg 和 la,分别代表细化模块中的组合局部表征、全局表征和标记注意图。每个侧输出都需要一个单独的卷积层来获得单通道预测。后者用 lf 表示。这些组件结合使用了二元交叉熵(BCE)损失和加权 IoU 损失,这在大多数分割任务中都很常用。

最终的学习损失函数如下式所示。本文设置 λg=0.3,λh=0.3。

三、试验

数据集和评估指标

数据集

本文使用 DIS5K 基准数据集进行实验。该数据集包含 225 个类别的 5,470 张高分辨率图像(2K、4K 或更大尺寸)。数据集分为三个部分

  • DIS-TR:3 000 幅训练图像。
  • DIS-VD:470 幅验证图像。
  • DIS-TE:2,000 张测试图像,分为四个子集(DIS-TE1、2、3 和 4),每个子集有 500 张图像,几何复杂度依次增加

DIS5K 数据集因其高分辨率图像、详细的结构和出色的注释质量,比其他分割数据集更具挑战性,需要先进的模型来捕捉复杂的细节。

评估指标

采用以下指标评估绩效

  • 最大 F 值:测量准确性和重复性的最大得分,β² 设置为 0.3。
  • 加权 F 值:与 F 值类似,但已加权。
  • 结构相似性测量(Sm):评估预测值与真实值之间的结构相似性,同时考虑领域和对象识别。
  • 电子测量:用于评估像素与图像之间的匹配程度。
  • 平均绝对误差 (MAE):计算预测地图与真实值之间的平均误差。

这些指标有助于了解该模型在识别和分割 DIS5K 数据集中具有复杂结构的物体方面的性能。

实验结果

定量评估

表 1 将拟议的 MVANet 与其他 11 个著名的相关模型(F3Net、GCPANet、PFNet、BSANet、ISDNet、IFA、IS-Net、FPDIS、UDUN、PGNet 和 InSPyReNet)进行了比较。为进行公平比较,输入大小标准化为 1024 × 1024。结果表明,在所有数据集的不同指数上,MVANet 都明显优于其他模型。特别是在 F、Em、Sm 和 MAE 方面,MVANet 分别比 InSPyReNet 高出 2.5%、2.1%、0.5% 和 0.4%。

此外,还评估了 InSPyReNet 和 MVANet 的推理速度。两者都在英伟达 RTX 3090 GPU 上进行了测试。由于采用了简单的单流设计,MVANet 的推理速度达到了 4.6 FPS,而 InSPyReNet 为 2.2 FPS。

表 1.DIS5K 的定量评估。

定性评估

为了直观地展示所提方法的高预测准确性,我们将测试集中所选图像的输出结果可视化。如图 3 所示,即使在复杂的场景中,建议的方法也能准确定位物体并捕捉边缘细节。特别是,建议的方法能够准确区分椅子的完整分割和每个网格的内部,而其他方法则会受到明显的黄色纱布和阴影的干扰(见下行)。

图 3.DIS5K 中的定性评估。

四、代码测试

下载源码

git clone https://github.com/qianyu-dlut/MVANet.git
cd MVANet

环境配置

conda create -n mvanet python==3.8
conda activate mvanet
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
pip install -U openmim
mim install mmcv-full==1.3.17
pip install -r requirements.txt

测试代码:

import os
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from torch.autograd import no_grad
from torchvision import transforms
from model.MVANet import inf_MVANet
import ttach as tta# 参数设置
model_path = 'saved_model/Model_80.pth'  # 修改为你的模型路径
image_directory = 'data/images'  # 修改为你的图像目录路径
output_directory = 'datamasks'  # 预测结果保存路径# 图像变换
img_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])# 定义 TTA 变换
tta_transforms = tta.Compose([tta.HorizontalFlip(),tta.Scale(scales=[0.75, 1, 1.25], interpolation='bilinear', align_corners=False),
])def load_model(model_path):net = inf_MVANet().cuda()# 加载模型参数pretrained_dict = torch.load(model_path, map_location='cuda')model_dict = net.state_dict()pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}model_dict.update(pretrained_dict)net.load_state_dict(model_dict)net.eval()return netdef predict_image(net, img_path):# 加载图像并进行预处理img = Image.open(img_path).convert('RGB')w_, h_ = img.sizeimg_resize = img.resize([1024, 1024], Image.BILINEAR)img_var = img_transform(img_resize).unsqueeze(0).cuda()# 预测结果masks = []with no_grad():for transformer in tta_transforms:img_transformed = transformer.augment_image(img_var)model_output = net(img_transformed)deaug_mask = transformer.deaugment_mask(model_output)masks.append(deaug_mask)prediction = torch.mean(torch.stack(masks, dim=0), dim=0).sigmoid()# 将预测结果转换为图像prediction_img = transforms.ToPILImage()(prediction.data.squeeze(0).cpu())prediction_img = prediction_img.resize((w_, h_), Image.BILINEAR)return img, prediction_imgdef process_directory(net, image_dir, output_dir):# 创建保存目录if not os.path.exists(output_dir):os.makedirs(output_dir)# 遍历目录中的所有图像image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]for idx, image_file in enumerate(image_files):img_path = os.path.join(image_dir, image_file)print(f"Processing {idx + 1}/{len(image_files)}: {img_path}")# 预测并显示结果original_img, prediction_img = predict_image(net, img_path)# 保存预测结果prediction_path = os.path.join(output_dir, f"prediction_{image_file}")prediction_img.save(prediction_path)# 显示图像fig, axs = plt.subplots(1, 2, figsize=(10, 5))axs[0].imshow(original_img)axs[0].set_title("Original Image")axs[0].axis('off')axs[1].imshow(prediction_img, cmap='gray')axs[1].set_title("Predicted Mask")axs[1].axis('off')plt.show()if __name__ == '__main__':# 加载模型与处理图像目录model = load_model(model_path)process_directory(model, image_directory, output_directory)

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

五、总结

在这篇评论文章中,我们将高精度前景提取(背景去除)建模为一个多视角物体识别问题,提供了一个高效、简单的多视角聚合网络。这样做的目的是更好地平衡模型设计、准确性和推理速度。
为解决多视图的目标对准问题,提出了多视图完成定位模块,以联合计算目标的共同关注区域。此外,提出的多视图完成细化模块被嵌入到每个解码器块中,以充分整合互补的本地信息,减少单视图补丁中语义的缺失。这样,只需一个卷积层就能实现最终的视图细化。
广泛的实验表明,所提出的方法性能良好。MVANet 有潜力成为前景提取任务的新标准,并有望在未来得到更广泛的应用。

源码下载地址:https://download.csdn.net/download/matt45m/90335556


http://www.ppmy.cn/devtools/155848.html

相关文章

Java基础——分层解耦——IOC和DI入门

目录 三层架构 Controller Service Dao ​编辑 调用过程 面向接口编程 分层解耦 耦合 内聚 软件设计原则 控制反转 依赖注入 Bean对象 如何将类产生的对象交给IOC容器管理? 容器怎样才能提供依赖的bean对象呢? 三层架构 Controller 控制…

DeepSeek R1与OpenAI o1深度对比

文章目录 引言技术原理DeepSeek R1OpenAI o1 性能表现官方数据推理任务知识密集型任务通用能力 价格对比应用场景科研与技术开发自然语言处理(NLP)企业智能化升级教育与培训数据分析与智能决策 部署与集成DeepSeek R1OpenAI o1 伦理考量DeepSeek R1OpenA…

Spark Streaming的背压机制的原理与实现代码及分析

Spark Streaming的背压机制是一种根据JobScheduler反馈的作业执行信息来动态调整Receiver数据接收率的机制。 在Spark 1.5.0及以上版本中,可以通过设置spark.streaming.backpressure.enabled为true来启用背压机制。当启用背压机制时,Spark Streaming会自…

Springboot集成WebFlux响应式开发详解

下面从Spring WebFlux集成依赖开始WebFlux实际应用场景WebFlux完整示例(以POST方式为例)总结Spring Framework 中包含的原始 Web 框架 Spring Web MVC 是专门为 Servlet API 和 Servlet 容器构建的。响应式堆栈 Web 框架 Spring WebFlux 是在 5.0 版本中添加的。它是完全非阻…

OpenAI的真正对手?DeepSeek-R1如何用强化学习重构LLM能力边界——DeepSeek-R1论文精读

2025年1月20日,DeepSeek-R1 发布,并同步开源模型权重。截至目前,DeepSeek 发布的 iOS 应用甚至超越了 ChatGPT 的官方应用,直接登顶 AppStore。 DeepSeek-R1 一经发布,各种资讯已经铺天盖地,那就让我们一起…

第九章:内存池的调整与测试

目录 第一节:线程私有ThreadCache 第二节:线程申请/释放内存的函数 2-1.ConcurrentAlloc 2-2.ConcurrentFree 第三节:测试优化 第四节:基数树优化 第五节:再次测试 第六节:下期预告 第一节&#xff1…

使用CSS实现一个加载的进度条

文章目录 使用CSS实现一个加载的进度条一、引言二、步骤一:HTML结构与CSS基础样式1、HTML结构2、CSS基础样式 三、步骤二:添加动画效果1、使用CSS动画2、结合JavaScript控制动画 四、使用示例五、总结 使用CSS实现一个加载的进度条 一、引言 在现代网页…

Rust 的基本类型有哪些,他们存在堆上还是栈上,是否可以COPY?

Rust 的基本类型主要包括以下几类: 1. 整数类型(Integer) Rust 提供了有符号和无符号的整数类型: 有符号整数(i8, i16, i32, i64, i128, isize)无符号整数(u8, u16, u32, u64, u128, usize&a…