基于少样本(小样本)的图像分割

news/2024/12/22 14:15:19/

在少样本学习(尤其是少样本分割)中,query branchsupport branch 分别承担不同的任务,目的是通过少量标注的样本来实现对新图像的分割。它们的具体作用如下:

  1. Support Branch(支持分支)

    • 作用:用于处理少量已标注的样本图像(称为支持集),提取该图像中的语义特征并生成类原型(prototypes)。类原型是该类的特征表示,通常是对支持图像中特定类别的特征进行聚合,得到紧凑且具有代表性的特征向量。
    • 目标:从支持集中学习到每个类别的特征表示(类原型),为后续的分割任务提供参考。
  2. Query Branch(查询分支)

    • 作用:用于处理未标注的查询图像,提取其特征并通过与支持分支中的类原型进行匹配,来实现对查询图像的像素级分割。
    • 目标:通过将查询图像中的每个像素与支持分支提供的类原型进行比较,确定该像素属于哪一个类别,进而完成图像分割任务。

总的来说,support branch 负责从少量标注样本中学习类原型,而 query branch 则利用这些类原型对新的查询图像进行分割。两者的协作使得模型能够在少样本的情况下有效泛化。

-------------------------

具体是怎么匹配的?

--------------------------

在少样本学习的图像分割任务中,query branch(查询分支)support branch(支持分支)匹配过程主要基于度量学习(metric learning)思想。下面具体解释这个匹配过程:

  1. 特征提取

    • 支持分支(Support branch)和查询分支(Query branch)都会先经过一个共享的特征提取器(通常是预训练的卷积神经网络,如VGG或ResNet)。这个过程将支持图像和查询图像都映射到相同的嵌入空间(embedding space)中,以获得高维特征表示。
  2. 原型生成

    • 在支持分支中,每个类别的支持样本通过其语义标签(支持图像中的前景和背景)生成相应类别的类原型。类原型是通过平均池化支持样本中的相同类别像素特征来生成的,这些类原型代表了每个类别的中心特征,起到类的“代表性”作用。
  3. 匹配过程

    • 在查询分支中,查询图像的每个像素也映射为一个特征向量。为了进行分割,系统会逐像素地将查询图像中的像素特征与支持分支生成的类原型进行匹配。这个匹配过程常常采用余弦相似度(cosine similarity)或欧氏距离(Euclidean distance)来计算特征之间的相似度。
    • 查询图像中的每个像素将被分配给最相似的类原型(即距离最小的原型),从而确定这个像素属于哪一个类别。
  4. 监督学习与正则化

    • 除了通过支持分支生成类原型进行匹配外,还会引入监督学习机制。在部分方法(如PANet)中,提出了一种原型对齐正则化(Prototype Alignment Regularization),通过反向匹配,即将查询图像的预测分割结果当作新的支持集来进一步微调类原型,以提高模型的泛化能力。

简单示例:

假设有一个少样本学习任务,目标是分割“马”和“草地”,支持图像中标注了“马”和“草地”类。支持分支提取出“马”和“草地”的类原型,查询图像中的每个像素会计算与这两个类原型的相似度,并根据最高相似度为该像素分配“马”或“草地”的标签。

这个匹配过程使得即便只有少量标注样本,模型也能通过这种原型匹配有效地进行图像分割。

------------------------------------

详细的告诉我类原型是如何生成的,比较如如何进行的。可以用代码来解释。

------------------------------------------

类原型的生成与比较是少样本学习中关键的步骤,尤其是在少样本图像分割任务中。类原型(prototype)是从支持图像中提取到的类特征的集中表示,通过比较查询图像的像素特征与这些类原型来完成分割任务。接下来,我将详细解释类原型的生成以及如何进行比较,并通过代码进行示例。

1. 类原型生成

类原型通常是支持图像中每个类别的特征向量的平均值。首先,从支持图像中提取出特征,并使用其标注的掩码信息计算出每个类别的类原型。

步骤:

通过这种方式,少样本学习可以在仅有少量标注样本的情况下实现图像分割任务。

----------------------------

这个高维空间应该指的不是二维空间吧?相似性度量肯定不是在二维空间中那个像素离着原型的距离更近就把他划分为哪个类别这么简单吧?

--------------------------------

是的,您理解得很准确。在少样本学习中,特别是图像分割任务中,相似性度量并不是在简单的二维空间中进行的,而是在高维嵌入空间中完成的。这意味着,我们比较的是特征向量,而不是直接在原始图像的二维像素空间中做简单的几何距离计算。

  1. 提取支持图像的特征: 将支持图像通过预训练的深度网络(如ResNet、VGG等),生成每个像素的特征向量。

  2. 使用掩码信息生成类原型: 对于每个类别,使用支持图像中的掩码(mask)来筛选出对应类别的像素特征,将这些特征的平均值作为该类别的类原型。

  3. import torch
    import torch.nn.functional as Fdef extract_prototype(support_features, support_masks, num_classes):"""生成每个类别的类原型。参数:- support_features: 支持图像的特征,维度为 [batch_size, channels, height, width]- support_masks: 支持图像的掩码,维度为 [batch_size, num_classes, height, width]- num_classes: 类别数返回:- prototypes: 每个类别的类原型,维度为 [num_classes, channels]"""batch_size, channels, height, width = support_features.shape# 初始化类原型prototypes = torch.zeros((num_classes, channels), device=support_features.device)for class_id in range(num_classes):# 获取当前类别的掩码,维度为 [batch_size, height, width]class_mask = support_masks[:, class_id, :, :].unsqueeze(1)  # [batch_size, 1, height, width]# 获取该类的所有像素特征,使用掩码筛选class_features = support_features * class_mask  # [batch_size, channels, height, width]# 计算类原型(求平均值,忽略掩码中为0的地方)sum_features = class_features.sum(dim=(0, 2, 3))  # 求和num_pixels = class_mask.sum(dim=(0, 2, 3)) + 1e-8  # 防止除以0prototypes[class_id] = sum_features / num_pixels  # 平均化得到类原型return prototypes# 示例输入
    support_features = torch.randn((2, 256, 32, 32))  # 假设2张支持图像,256维特征,32x32像素
    support_masks = torch.randint(0, 2, (2, 2, 32, 32)).float()  # 2类掩码,2张图片
    num_classes = 2  # 两类:前景和背景prototypes = extract_prototype(support_features, support_masks, num_classes)
    print(prototypes.shape)  # 输出维度应为 [num_classes, channels] -> [2, 256]
    

    2. 比较过程:查询像素与类原型的匹配

    在获取类原型之后,我们需要将查询图像中的每个像素与这些类原型进行比较,通常使用余弦相似度欧氏距离。具体来说,对于查询图像中的每个像素,将其特征向量与每个类原型进行比较,选择最相似的类原型来预测该像素的类别。

    比较方法1:余弦相似度
  4. def cosine_similarity(a, b):"""计算余弦相似度。参数:- a: 查询图像的像素特征 [height * width, channels]- b: 类原型 [num_classes, channels]返回:- similarities: 查询像素与每个类原型的余弦相似度,维度 [height * width, num_classes]"""a_norm = F.normalize(a, p=2, dim=1)  # 归一化查询像素特征b_norm = F.normalize(b, p=2, dim=1)  # 归一化类原型similarities = torch.mm(a_norm, b_norm.t())  # 计算余弦相似度矩阵return similarities# 示例输入
    query_features = torch.randn((1024, 256))  # 查询图像有1024个像素,每个像素256维特征
    similarities = cosine_similarity(query_features, prototypes)
    predicted_classes = similarities.argmax(dim=1)  # 获取相似度最高的类作为预测类别
    

    比较方法2:欧氏距离

  5. def euclidean_distance(a, b):"""计算欧氏距离。参数:- a: 查询图像的像素特征 [height * width, channels]- b: 类原型 [num_classes, channels]返回:- distances: 查询像素与每个类原型的欧氏距离,维度 [height * width, num_classes]"""a_exp = a.unsqueeze(1)  # [height * width, 1, channels]b_exp = b.unsqueeze(0)  # [1, num_classes, channels]distances = torch.norm(a_exp - b_exp, dim=2)  # 计算每个查询像素到类原型的距离return distances# 示例输入
    distances = euclidean_distance(query_features, prototypes)
    predicted_classes = distances.argmin(dim=1)  # 距离最小的类作为预测类别
    

    3. 总结

  6. 类原型生成:通过支持图像中的特征和其类别掩码,将每个类别的像素特征取平均,生成类原型。
  7. 比较过程:对于查询图像的每个像素,通过计算其特征向量与类原型的相似性(例如余弦相似度或欧氏距离),确定该像素的类别。

高维嵌入空间的作用

在图像分割的少样本学习中,图像的像素并没有被直接用于相似性比较,而是先通过一个深度神经网络(如VGG、ResNet)将图像的每个像素转化为一个高维的特征向量。这个特征向量通常会有数百到上千个维度。这些高维特征是从低层像素信息中提取出的抽象表示,它们能够更好地捕捉图像中的语义信息,比如物体的形状、纹理等。

通过这种方式,系统能够基于深度特征来衡量不同像素之间的相似性,而不仅仅是依赖于像素的原始RGB值或其二维空间位置。

高维空间中的相似性度量

  1. 高维特征向量的比较:在高维嵌入空间中,每个像素通过深度神经网络映射成一个高维向量,代表该像素的特征。在这个特征空间中,不同类别的像素将会聚集在相似的位置,而不同类别的像素会尽量远离。通过将这些特征向量与类原型进行比较(例如通过余弦相似度或欧氏距离),我们可以判断一个查询像素与哪个类最相似。

  2. 度量的复杂性:正如您所提到的,在二维空间中简单的几何距离并不适用。在高维空间中,由于每个特征向量包含的信息非常丰富,特征间的距离不仅反映空间上的接近性,还包括了复杂的语义相似性。例如,余弦相似度可以通过比较两个向量的方向而忽略它们的尺度,适合用于高维特征向量的比较。


http://www.ppmy.cn/news/1536176.html

相关文章

低代码可视化-UniApp二维码可视化-代码生成器

市面上提供了各种各样的二维码组件,做了一简单的uniapp二维码组件,二维码实现依赖davidshimjs/qrcodejs。 组件特点 跨浏览器支持:利用Canvas元素实现二维码的跨浏览器兼容性,兼容微信小程序、h5、app。 无依赖性:QR…

【PyTorch】图像分割

图像分割是什么 Image Segmentation 将图像每一个像素分类 图像分割分类 超像素分割:少量超像素代替大量像素,常用于图像预处理语义分割:逐像素分类,无法区分个体实例分割:对个体目标进行分割全景分割:…

Observability:使用 OpenTelemetry 自动检测 Go 应用程序

作者:来自 Elastic Damien Mathieu 使用 OpenTelemetry 检测 Go 应用程序可以深入了解应用程序的性能、依赖项和错误。我们将向你展示如何使用 Docker 自动检测 Go 应用程序,而无需更改应用程序代码。 在快节奏的软件开发领域,尤其是在云原生…

LSTM--详解

LSTM LSTM(Long Short-Term Memory,长短期记忆网络)是一种特殊的递归神经网络(RNN),最早由Hochreiter和Schmidhuber于1997年提出,旨在解决标准RNN在处理长序列时容易出现的梯度消失和梯度爆炸问…

Oracle数据恢复—异常断电导致Oracle数据库报错的数据恢复案例

Oracle数据库故障: 机房异常断电后,Oracle数据库启库报错:“system01.dbf需要更多的恢复来保持一致性,数据库无法打开”。数据库没有备份,归档日志不连续。用户方提供了Oracle数据库的在线文件,需要恢复zxf…

爬虫入门之爬虫原理以及请求响应

爬虫入门之爬虫原理以及请求响应 爬虫需要用到的库, 叫requests. 在导入requests库之前, 需要安装它, 打开cmd: 输入pip install 库名 pip install requests后面出现successful或requirement already就说明已经下载成功了!!! 下载出现的问题: 1.有报错或者是下载慢 修改镜像…

【爬虫】网站反debugger、内存爆破以及网站限制开发者工具

【爬虫】网站反debugger、内存爆破以及网站直接限制开发者工具 声明:本文中所有内容仅供学习交流使用,不用于其他任何目的,不提供完整代码,敏感网址、数据接口等均已做脱敏处理,严禁用于商业用途和非法用途&#xff0…

力扣16~20题

题16&#xff08;中等&#xff09;&#xff1a; 思路&#xff1a; 双指针法&#xff0c;和15题差不多&#xff0c;就是要排除了&#xff0c;如果total<target则排除了更小的&#xff08;left右移&#xff09;&#xff0c;如果total>target则排除了更大的&#xff08;rig…