计算机视觉|超详细!Meta视觉大模型Segment Anything(SAM)源码解剖

devtools/2025/3/15 3:22:18/

一、引言

计算机视觉领域,图像分割是一个核心且具有挑战性的任务,旨在将图像中的不同物体或区域进行划分和识别,广泛应用于自动驾驶、医学影像分析、安防监控等领域。Segment Anything Model(SAM)由 Meta AI 实验室发布,其引入了基于 Prompt 的交互式分割能力,显著提升了图像分割的灵活性和泛化能力。

SAM 通过在海量且多样化的数据集上训练,具备处理未见过对象类别和场景的能力。这一特性使其在学术界引发广泛研究,如模型轻量化、领域适应、多模态融合等方向;在工业界也迅速应用于医学图像分析中的肿瘤检测、遥感图像处理中的卫星图像分析以及视频处理中的目标跟踪等领域。

对于计算机视觉爱好者和开发者而言,深入剖析 SAM 的源码有助于理解其核心技术原理,为进一步创新和应用提供基础。本文将从原理到源码逐步解析 SAM 的工作机制,探索其实现高效图像编码、提示编码及掩码解码的过程。

二、SAM 原理速览

(一)核心概念

SAM 的核心在于其 “分割一切” 的能力,这一能力依赖于基于 Prompt 的分割策略。Prompt 可以是点(points)、框(boxes)、掩码(masks)或文本(text),为模型提供分割目标的关键信息。例如,点击图像中的一个点,SAM 能够识别并分割该点所在物体;给定一个框,SAM 则专注于框内物体的分割。

(二)模型架构

SAM 的架构包含三个核心组件:Image Encoder(图像编码器)Prompt Encoder(提示编码器)Mask Decoder(掩码解码器),其协作方式如下图所示:
图 1:SAM 模型架构示意图,展示图像编码、提示编码及掩码解码的协作流程

  • Image Encoder:将输入图像转换为高维特征表示,通常使用预训练的 Vision Transformer(ViT),如 ViT-H、ViT-L、ViT-B。例如,对于分辨率为 1024 × 1024 1024 \times 1024 1024×1024 的图像,经过 Patch Embedding 操作划分为 16 × 16 16 \times 16 16×16 的 patches,特征图尺寸缩小为原来的 1 16 \frac{1}{16} 161,通道数从 3 映射到 768。
  • Prompt Encoder:处理不同类型的 Prompt,将其编码为与图像嵌入兼容的特征。对点和框使用位置编码(Positional Encoding);对文本使用 CLIP 文本编码器;对掩码则通过轻量级卷积网络编码。
  • Mask Decoder:融合图像嵌入和提示嵌入,通过 Transformer 结构解码为分割掩码,默认生成 3 个候选掩码并按置信度排序。

三、源码结构总览

(一)代码目录解析

SAM 的代码目录结构如下:

segment-anything/
├── assets              # 示例图片等资源
├── demo                # 前端部署代码
├── notebooks           # Jupyter Notebook 示例
├── script              # 模型导出脚本
├── segment_anything    # 核心代码目录
│   ├── build_sam.py    # 模型构建脚本
│   ├── config.py       # 配置文件
│   ├── mask_decoder.py # 掩码解码器实现
│   ├── model_registry.py # 模型注册模块
│   ├── predictor.py    # 预测接口
│   ├── sam.py          # SAM 整体结构
│   ├── sam_arch.py     # 架构细节
│   ├── utils.py        # 工具函数
│   └── automatic_mask_generator.py # 自动掩码生成
└── setup.py            # 安装脚本

segment_anything 是核心目录,后续分析将聚焦于此。

(二)关键文件与模块

  • build_sam.py:定义模型构建函数,支持不同版本(如 vit_h、vit_l、vit_b)。示例代码:
def build_sam_vit_h(checkpoint=None):# 构建 vit_h 版本的 SAM 模型return _build_sam(encoder_embed_dim=1280,           # 编码器嵌入维度encoder_depth=32,                 # 编码器层数encoder_num_heads=16,             # 注意力头数encoder_global_attn_indexes=[7, 15, 23, 31],  # 全局注意力层索引checkpoint=checkpoint,            # 预训练权重文件路径)
  • predictor.py:提供预测接口,set_image 处理图像预处理,predict 根据提示生成掩码。核心代码:
def set_image(self, image: np.ndarray) -> None:# 检查输入图像维度和通道数是否符合要求if image.ndim != 3 or image.shape[2] not in [3, 4]:raise ValueError("Image must be 3D with 3 or 4 channels")# 应用图像变换(如缩放、归一化)input_image = self.transform.apply_image(image)# 转换为 PyTorch 张量并调整维度为 [1, C, H, W]input_image_torch = torch.as_tensor(input_image, device=self.device)self.set_torch_image(input_image_torch.permute(2, 0, 1)[None, :, :, :], image.shape[:2])
  • automatic_mask_generator.py:自动生成所有物体掩码,基于点提示网格。核心代码:
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:# 预处理输入图像input_image = self.model.preprocess(image)# 使用无梯度计算图像嵌入with torch.no_grad():image_embedding = self.model.image_encoder(input_image)# 生成点提示网格points = self._generate_points(image.shape[:2])all_masks, all_scores = [], []# 按批次处理点提示并预测掩码for i in range(0, len(points), self.points_per_batch):batch_points = points[i:i + self.points_per_batch]batch_masks, batch_scores = self._predict_masks(image_embedding, batch_points)all_masks.extend(batch_masks)all_scores.extend(batch_scores)# 返回掩码和对应置信度列表return [{"segmentation": m, "score": s} for m, s in zip(all_masks, all_scores)]

四、核心代码深度剖析

(一)Image Encoder 源码解析

  • Patch Embedding:将图像划分为 patches 并映射为特征向量:
class PatchEmbed(nn.Module):def __init__(self, kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=768):# 初始化 Patch Embedding 模块super().__init__()# 定义卷积层,将图像划分为 patches 并映射到嵌入维度self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size, stride)def forward(self, x: torch.Tensor) -> torch.Tensor:# 对输入图像进行卷积操作,生成特征图x = self.proj(x)# 调整维度顺序为 [B, H/16, W/16, C],适配 Transformer 输入return x.permute(0, 2, 3, 1)

输入图像 [ B , 3 , H , W ] [B, 3, H, W] [B,3,H,W] 转换为 [ B , H 16 , W 16 , 768 ] [B, \frac{H}{16}, \frac{W}{16}, 768] [B,16H,16W,768]

  • Transformer Encoder:堆叠多个 Transformer Block 提取特征:
class Block(nn.Module):def __init__(self, dim, num_heads, mlp_ratio):# 初始化 Transformer Blocksuper().__init__()# 第一层归一化self.norm1 = nn.LayerNorm(dim)# 多头注意力模块self.attn = Attention(dim, num_heads)# 第二层归一化self.norm2 = nn.LayerNorm(dim)# 多层感知机,隐藏层维度为 dim * mlp_ratioself.mlp = Mlp(dim, int(dim * mlp_ratio))def forward(self, x):# 自注意力计算并残差连接x = x + self.attn(self.norm1(x))# MLP 计算并残差连接x = x + self.mlp(self.norm2(x))return x

(二)Prompt Encoder 源码解析

编码不同类型提示:

class PromptEncoder(nn.Module):def __init__(self, embed_dim, image_embedding_size, input_image_size):# 初始化 Prompt Encoder 模块super().__init__()# 位置编码层,使用随机高斯矩阵生成self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)# 提示嵌入层,处理点和框提示self.prompt_embed_layer = PromptEmbedding(embed_dim, input_image_size, image_embedding_size)def forward(self, points=None, boxes=None, masks=None):# 初始化稀疏嵌入张量,维度为 [batch_size, 0, embed_dim]sparse_embed = torch.zeros(bs, 0, self.embed_dim, device=self.device)if points:# 提取点提示的坐标和标签coords, labels = points# 生成点嵌入并拼接到稀疏嵌入中sparse_embed = torch.cat([sparse_embed, self.prompt_embed_layer.point_embedding(coords, labels)], dim=1)# 返回稀疏嵌入和密集嵌入(未完全展示 masks 处理部分)return sparse_embed, dense_embed

五、实战演练

(一)环境搭建与配置

  1. 安装 Python:版本 ≥ 3.8。
  2. 安装 PyTorch
    pip install torch==2.0.1+cu117 torchvision==0.15.2+cu117 -f https://download.pytorch.org/whl/cu117
    
  3. 安装 SAM
    pip install -U "git+https://github.com/facebookresearch/segment-anything.git"
    

(二)代码运行与结果分析

示例代码:

from segment_anything import sam_model_registry, SamPredictor
import cv2
import numpy as np
import matplotlib.pyplot as plt# 加载 vit_b 版本的 SAM 模型并移动到 GPU
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth").to("cuda")
# 初始化预测器
predictor = SamPredictor(sam)
# 读取图像并转换为 RGB 格式
image = cv2.cvtColor(cv2.imread("test_image.jpg"), cv2.COLOR_BGR2RGB)
# 设置输入图像,进行预处理和特征提取
predictor.set_image(image)
# 定义点提示坐标和标签(1 表示前景)
masks, scores, _ = predictor.predict(point_coords=np.array([[500, 375]]), point_labels=np.array([1]), multimask_output=True)# 遍历掩码和分数,显示结果
for i, (mask, score) in enumerate(zip(masks, scores)):plt.imshow(image)  # 显示原始图像plt.imshow(mask, alpha=0.6)  # 显示掩码,透明度为 0.6plt.title(f"Mask {i+1}, Score: {score:.3f}")  # 设置标题,显示掩码编号和置信度plt.show()  # 显示图像

结果分析:SAM 根据点提示生成多个掩码,分数反映置信度。高分掩码通常更准确,低分掩码可能包含错误分割。
在这里插入图片描述

六、总结与展望

SAM 通过高效的图像编码、提示编码和掩码解码实现了灵活的图像分割。未来,其在医学影像、自动驾驶和视频处理中的应用潜力巨大。技术发展方向包括模型轻量化、多模态融合和领域适应,为计算机视觉带来更多可能性。


延伸阅读



http://www.ppmy.cn/devtools/167186.html

相关文章

微信小程序面试内容整理-生命周期函数

1. 小程序生命周期函数 这些生命周期函数是在整个小程序启动、显示、隐藏、崩溃等过程中调用的。它们控制小程序的全局行为。 ● onLaunch(options) 小程序初始化时调用一次。通常在此函数中进行小程序的初始化操作,如获取用户信息、初始化设置等。

Kubernetes学习笔记-移除Nacos迁移至K8s

项目服务的配置管理和服务注册发现由原先的Nacos全面迁移到Kubernetes上。 一、移除Nacos 移除Nacos组件依赖。 <dependency><groupId>com.alibaba.cloud</groupId><artifactId>spring-cloud-starter-alibaba-nacos-discovery</artifactId> <…

使用RestTemplate发送https请求-忽略ssl证书

RestTemplate调用https服务的时候&#xff0c;由于服务方的ssl证书并非正式证书&#xff0c;不被jdk接受&#xff0c;故会报类似&#xff1a;“No subject alternative names matching IP address xxxxxx found”的错误。网上找了一下&#xff0c;处理也比较简单&#xff0c;基…

使用Mermaid语法绘制的C语言程序从Linux移植到Windows的流程图

以下是使用Mermaid语法绘制的C语言程序从Linux移植到Windows的流程图&#xff1a; graph TDA[开始移植] --> B[代码兼容性检查]B --> C[检查系统调用差异\nfork/exec -> CreateProcess]B --> D[检查文件路径格式\n/ vs \\]B --> E[检查依赖库兼容性\nPOSIX vs …

深度学习-145-Text2SQL之基于官方提示词模板进行交互

文章目录 1 基于sqlite1.1 数据库Chinook1.1.1 创建并载入数据1.1.2 SQLDatabase1.2 数据库中的表1.2.1 获取表的字段1.2.2 翻译字段1.3 建表语句2 操作单表2.1 大语言模型2.2 数据库连接2.3 官方提示词模板2.3.1 一般输出2.3.2 结构化输出2.4 执行SQL查询2.5 大模型整理结果2.…

什么是nginx的强缓存和协商缓存

一、强缓存&#xff08;Strong Cache&#xff09; 1. 定义 • 强缓存直接告诉浏览器&#xff1a;在缓存过期前&#xff0c;无需与服务器通信&#xff0c;直接使用本地缓存。 • 由服务器通过响应头 Cache-Control 和 Expires 控制。 2. 响应头 • Cache-Control: max-age3600表…

Pygame实现记忆拼图游戏1

1 游戏介绍 记忆拼图游戏的英文名叫做“memory puzzle”&#xff0c;玩家通过记忆找到相同的图片&#xff0c;如图1所示。 图1 记忆拼图游戏 从图1中可以看出&#xff0c;玩家每次点击两张图片&#xff0c;如果这两个图片是相同的图案&#xff08;包括颜色和形状&#xff09;…

树莓科技(成都)集团:如何铸就第五代产业园标杆

树莓科技&#xff08;成都&#xff09;集团铸就第五代产业园标杆&#xff0c;主要体现在以下几个方面&#xff1a; 精准定位与前瞻布局 树莓科技并非盲目扩张&#xff0c;而是精准锚定数字经济发展方向。以成都为起点&#xff0c;迅速构建起全国性的园区版图&#xff0c;体现…