A remote sensing image segmentation tool built on the SAM foundation model. It supports interactive annotation and ground-object recognition, and can be wrapped with Flask to serve as a web backend.


If this project helps you, please consider supporting it (GitHub - Lvbta/ImageSegmentationTool-SAM: An interactive annotation case developed based on SAM for remote sensing image annotation, which can generate corresponding segmentation results based on point, multi-point, and rectangular box prompts, and convert the recognition results into vector data shp.)

This project provides an image segmentation tool that applies the Segment Anything Model (SAM) to large satellite or aerial images. The tool supports segmentation from single-point, multi-point, or bounding-box prompts, and saves the results as a shapefile for further geospatial analysis.

Features

  • Single-point segmentation: segment from a single point prompt.
  • Multi-point segmentation: segment using several point prompts at once.
  • Bounding-box segmentation: segment within a specified bounding box.
  • Geospatial integration: reads georeferenced imagery with GDAL and converts segmentation masks to polygons (see the sketch after this list).
  • Shapefile export: saves segmentation results as a shapefile for easy integration with GIS tools.
  • Visualization: overlays segmentation results on the original image for verification and analysis.
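
Under the hood, the mask-to-polygon step maps pixel coordinates into map coordinates through GDAL's six-parameter affine geotransform. A minimal sketch, mirroring the pixel_to_geo logic in the full code below (the sample path and pixel are illustrative):

    from osgeo import gdal

    def pixel_to_geo(gt, col, row):
        # gt = (origin_x, pixel_width, row_rotation, origin_y, col_rotation, pixel_height)
        geox = gt[0] + col * gt[1] + row * gt[2]
        geoy = gt[3] + col * gt[4] + row * gt[5]
        return geox, geoy

    ds = gdal.Open('./sentinel2.tif')      # illustrative path
    print(pixel_to_geo(ds.GetGeoTransform(), 1248, 1507))
    ds = None                              # close the dataset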

Installation

  1. Clone the repository:

    git clone https://github.com/Lvbta/ImageSegmentationTool.git
    cd ImageSegmentationTool
  2. Download the SAM weights (see the download sketch after this list):

    • default or vit_h: ViT-H SAM model.
    • vit_l: ViT-L SAM model.
    • vit_b: ViT-B SAM model.
  3. Install the required dependencies:

    pip install -r requirements.txt
  4. Set environment variables:

    • The code already sets the KMP_DUPLICATE_LIB_OK variable to avoid duplicate-OpenMP-runtime conflicts.
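
For step 2, the checkpoint files are published in the official segment-anything repository. A minimal download sketch; the URLs are taken from that repository's README (verify them there if anything has moved):

    # Download a SAM checkpoint; URLs from the official segment-anything README.
    import os
    import urllib.request

    CHECKPOINTS = {
        "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
        "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
    }

    os.makedirs("./model", exist_ok=True)
    urllib.request.urlretrieve(CHECKPOINTS["vit_b"], "./model/sam_vit_b_01ec64.pth")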

Usage

Step 1: Prepare the data

  • Image: make sure you have a georeferenced satellite or aerial image in TIFF format (a quick check follows this list).
  • SAM model checkpoint: download the SAM model checkpoint file and place it in the project directory.
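
A quick way to confirm the image really is georeferenced, using the same GDAL calls that get_geoinfo in the code below relies on (path is illustrative):

    from osgeo import gdal

    ds = gdal.Open('./data/sentinel2.tif')
    print('geotransform:', ds.GetGeoTransform())   # identity-like values mean no georeferencing
    print('projection:', ds.GetProjection() or '<missing>')
    print('size:', ds.RasterXSize, 'x', ds.RasterYSize, 'bands:', ds.RasterCount)
    ds = None  # close the dataset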

Step 2: Configure parameters

Set the following parameters in the script (they are collected into a snippet after this list):

  • image_path: path to your georeferenced image file (e.g. ./sentinel2.tif).
  • sam_checkpoint: path to your SAM model checkpoint file (e.g. ./sam_vit_b_01ec64.pth).
  • model_type: the model type used for segmentation (vit_b, vit_l, etc.).
  • device: the device to run the model on (cpu or cuda).
  • output_shp: path where the output shapefile will be saved.
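
Put together, a configuration matching the values used elsewhere in this article looks like:

    image_path = r'./data/sentinel2.tif'               # georeferenced input image
    sam_checkpoint = "./model/sam_vit_b_01ec64.pth"    # SAM checkpoint file
    model_type = "vit_b"                               # must match the checkpoint
    device = "cpu"                                     # or "cuda" if a GPU is available
    output_shp = r'./result/segmentation_results.shp'  # output vector file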

Step 3: Run the segmentation

Choose a segmentation mode and provide the required input points or bounding box (a note on how the read window is derived from the prompt follows this list):

  • Single-point mode

    seg_mode = 'single_point'
    input_points = [[1248, 1507]]
    single_label = [1]
  • Multi-point mode

    seg_mode = 'multi_point'
    input_points = [[389, 1041],[411, 1094]]
    single_label = [1, 1]
  • Bounding-box mode

    seg_mode = 'box'
    input_box = [[0, 951, 1909, 2383]]
    single_label = [1]
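
For large rasters, the predict method does not load the whole image; it reads only a window around the prompt. A condensed sketch of the box-mode window logic from the class below:

    margin = 256                                    # context around the box, in pixels
    x_min, y_min, x_max, y_max = input_box[0]
    x_off = max(x_min - margin, 0)                  # window origin, clamped to the raster
    y_off = max(y_min - margin, 0)
    x_size = min(x_max - x_min + 2 * margin, 2048)  # window size, capped at 2048 px
    y_size = min(y_max - y_min + 2 * margin, 2048)
    # The prompt is then shifted into window coordinates before calling SAM:
    adjusted_box = [(x_min - x_off, y_min - y_off, x_max - x_off, y_max - y_off)]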

Step 4: Execute the script

Run the script to perform the segmentation:

python main.py

Step 5: Visualize and save the results

The segmentation masks are visualized, and the polygons are saved as a shapefile at the specified location.
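
To sanity-check the output, the shapefile can be read back with OGR. A minimal verification sketch, assuming the output path from the configuration above:

    from osgeo import ogr

    ds = ogr.Open('./result/segmentation_results.shp')
    layer = ds.GetLayer()
    print('features:', layer.GetFeatureCount())
    srs = layer.GetSpatialRef()
    print('srs:', srs.ExportToProj4() if srs else '<none>')
    for feature in layer:
        # area is reported in the layer's CRS units (e.g. degrees for geographic CRS)
        print('id:', feature.GetField('id'), 'area:', feature.GetGeometryRef().GetArea())
    ds = None  # close the datasource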

Example

To segment an image with a bounding box, configure the script as follows:

# Example configuration for bounding-box mode
seg_mode = 'box'
input_box = [[0, 951, 1909, 2383]]
single_label = [1]

segmenter = ImageSegmentation(image_path, sam_checkpoint, model_type, device)
masks, scores, x_off, y_off = segmenter.predict(mode=seg_mode, input_box=input_box, input_labels=single_label, multimask_output=True)
polygons = segmenter.masks_to_polygons(masks, x_off, y_off)
segmenter.save_polygons_gdal(polygons, output_shp)
image_chunk = segmenter.read_image_chunk(x_off, y_off, 512, 512)  # window used for visualization
segmenter.show_masks(seg_mode, masks, scores, x_off, y_off, input_box, single_label, image_chunk)
Full source of main.py:

import numpy as np
import torch
import cv2
import sys
from osgeo import gdal, ogr, osr
from shapely.geometry import Polygon
from shapely.wkb import dumps
import matplotlib.pyplot as plt
import os

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
plt.rcParams['font.sans-serif'] = 'SimHei'  # font able to render Chinese glyphs in plots
plt.rcParams['axes.unicode_minus'] = False
# plt.style.use('ggplot')


class ImageSegmentation:
    def __init__(self, image_path, sam_checkpoint, model_type='vit_b', device='cpu'):
        self.image_path = image_path
        self.sam_checkpoint = sam_checkpoint
        self.model_type = model_type
        self.device = device
        self.geo_transform, self.proj = self.get_geoinfo()
        self.sam = self.load_sam_model()
        self.predictor = self.init_predictor()

    def get_geoinfo(self):
        dataset = gdal.Open(self.image_path)
        geo_transform = dataset.GetGeoTransform()
        proj = dataset.GetProjection()
        dataset = None  # close the dataset
        return geo_transform, proj

    def read_image_chunk(self, x_off, y_off, x_size, y_size):
        dataset = gdal.Open(self.image_path)
        image = dataset.ReadAsArray(x_off, y_off, x_size, y_size)
        dataset = None  # close the dataset
        if len(image.shape) == 3:
            image = np.transpose(image, (1, 2, 0))  # GDAL reads in (bands, height, width) format
        else:
            image = np.stack([image] * 3, axis=-1)  # if it's a single-band image, stack to (height, width, 3)
        return image

    def load_sam_model(self):
        sys.path.append("..")
        from segment_anything import sam_model_registry
        sam = sam_model_registry[self.model_type](checkpoint=self.sam_checkpoint)
        sam.to(device=self.device)
        return sam

    def init_predictor(self):
        from segment_anything import SamPredictor
        predictor = SamPredictor(self.sam)
        return predictor

    def predict(self, mode='single_point', input_points=None, input_labels=None, input_box=None, multimask_output=None):
        if mode == 'single_point':
            assert input_points is not None and input_labels is not None, "Points and labels are required for single point mode."
            x, y = input_points[0]
            chunk_size = 512  # or any appropriate size
            x_off = max(x - chunk_size // 2, 0)
            y_off = max(y - chunk_size // 2, 0)
            x_size = y_size = chunk_size
            image_chunk = self.read_image_chunk(x_off, y_off, x_size, y_size)
            self.predictor.set_image(image_chunk)
            adjusted_points = [(x - x_off, y - y_off)]
            masks, scores, logits = self.predictor.predict(
                point_coords=np.array(adjusted_points),
                point_labels=np.array(input_labels),
                multimask_output=multimask_output,
            )
        elif mode == 'multi_point':
            assert input_points is not None and input_labels is not None, "Points and labels are required for multi point mode."
            # Determine bounding box of all points
            x_min = min(p[0] for p in input_points)
            y_min = min(p[1] for p in input_points)
            x_max = max(p[0] for p in input_points)
            y_max = max(p[1] for p in input_points)
            margin = 256  # or any appropriate margin
            x_off = max(x_min - margin, 0)
            y_off = max(y_min - margin, 0)
            x_size = min(x_max - x_min + 2 * margin, 2048)
            y_size = min(y_max - y_min + 2 * margin, 2048)
            image_chunk = self.read_image_chunk(x_off, y_off, x_size, y_size)
            self.predictor.set_image(image_chunk)
            adjusted_points = [(x - x_off, y - y_off) for x, y in input_points]
            masks, scores, logits = self.predictor.predict(
                point_coords=np.array(adjusted_points),
                point_labels=np.array(input_labels),
                multimask_output=multimask_output,
            )
        elif mode == 'box':
            assert input_box is not None, "Box coordinates are required for box mode."
            x_min, y_min, x_max, y_max = input_box[0]
            margin = 256  # or any appropriate margin
            x_off = max(x_min - margin, 0)
            y_off = max(y_min - margin, 0)
            x_size = min(x_max - x_min + 2 * margin, 2048)
            y_size = min(y_max - y_min + 2 * margin, 2048)
            image_chunk = self.read_image_chunk(x_off, y_off, x_size, y_size)
            self.predictor.set_image(image_chunk)
            adjusted_box = [(x_min - x_off, y_min - y_off, x_max - x_off, y_max - y_off)]
            masks, scores, logits = self.predictor.predict(
                box=np.array(adjusted_box).reshape(1, -1),
                multimask_output=multimask_output,
            )
        else:
            raise ValueError("Mode must be 'single_point', 'multi_point', or 'box'.")
        return masks, scores, x_off, y_off

    def masks_to_polygons(self, masks, x_off, y_off):
        polygons = []
        for mask in masks:
            contours, _ = cv2.findContours((mask > 0).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            for contour in contours:
                contour = contour.squeeze()
                if len(contour.shape) == 2 and len(contour) >= 3:  # valid polygon
                    geo_contour = [self.pixel_to_geo(x + x_off, y + y_off) for x, y in contour]
                    polygon = Polygon(geo_contour)
                    if polygon.is_valid:
                        polygons.append(polygon)
        return polygons

    def pixel_to_geo(self, x, y):
        geox = self.geo_transform[0] + x * self.geo_transform[1] + y * self.geo_transform[2]
        geoy = self.geo_transform[3] + x * self.geo_transform[4] + y * self.geo_transform[5]
        return geox, geoy

    def save_polygons_gdal(self, polygons, output_shp):
        driver = ogr.GetDriverByName("ESRI Shapefile")
        data_source = driver.CreateDataSource(output_shp)
        spatial_ref = osr.SpatialReference()
        spatial_ref.ImportFromWkt(self.proj)  # use the image's projection
        layer = data_source.CreateLayer("segmentation", spatial_ref, ogr.wkbPolygon)
        layer.CreateField(ogr.FieldDefn("id", ogr.OFTInteger))  # the field must exist before SetField
        layer_defn = layer.GetLayerDefn()
        for i, polygon in enumerate(polygons):
            feature = ogr.Feature(layer_defn)
            geom_wkb = dumps(polygon)  # convert the Shapely geometry to WKB
            ogr_geom = ogr.CreateGeometryFromWkb(geom_wkb)  # build an OGR geometry from the WKB
            feature.SetGeometry(ogr_geom)
            feature.SetField("id", i + 1)
            layer.CreateFeature(feature)
            feature = None
        data_source = None

    def show_masks(self, mode, masks, scores, x_off, y_off, input_point, input_label, image):
        for i, (mask, score) in enumerate(zip(masks, scores)):
            plt.figure(figsize=(10, 10))
            plt.imshow(image)
            self.show_mask(mask, plt.gca())
            if mode == 'box':
                self.show_box(np.array(input_point[0]), plt.gca(), x_off, y_off)
            else:
                self.show_points(np.array(input_point), np.array(input_label), plt.gca(), x_off, y_off)
            plt.title(f"{mode} mode {i + 1}, Score: {score:.3f}", fontsize=18)
            plt.axis('on')
            plt.show()

    def show_mask(self, mask, ax, x_off=0, y_off=0):
        mask_resized = np.zeros((mask.shape[0] + y_off, mask.shape[1] + x_off), dtype=np.uint8)
        mask_resized[y_off:y_off + mask.shape[0], x_off:x_off + mask.shape[1]] = mask.astype(np.uint8)
        contours, _ = cv2.findContours(mask_resized, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for contour in contours:
            contour[:, :, 0] += x_off
            contour[:, :, 1] += y_off
            ax.plot(contour[:, 0, 0], contour[:, 0, 1], color='lime', linewidth=2)

    def show_points(self, points, labels, ax, x_off, y_off):
        for point, label in zip(points, labels):
            x, y = point
            x -= x_off
            y -= y_off
            ax.scatter(x, y, c='red', marker='o', label=f'Label: {label}')

    @staticmethod
    def show_box(box, ax, x_off, y_off):
        x0, y0 = box[0] - x_off, box[1] - y_off
        w, h = box[2] - box[0], box[3] - box[1]
        ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='red', facecolor=(0, 0, 0, 0), lw=2))


if __name__ == '__main__':
    # Usage
    image_path = r'./data/sentinel2.tif'
    sam_checkpoint = "./model/sam_vit_b_01ec64.pth"
    model_type = "vit_b"
    device = "cpu"
    output_shp = r'./result/segmentation_results.shp'
    os.makedirs('./result', exist_ok=True)  # make sure the output directory exists

    # # Prediction mode: single point
    # seg_mode = 'single_point'
    # input_points = [[1248, 1507]]
    # single_label = [1]

    # # Prediction mode: multi point
    # seg_mode = 'multi_point'
    # input_points = [[389, 1041], [411, 1094]]
    # single_label = [1, 1]

    # Prediction mode: box
    seg_mode = 'box'
    input_box = [[0, 951, 1909, 2383]]
    single_label = [1]

    # Instantiate the segmenter
    segmenter = ImageSegmentation(image_path, sam_checkpoint, model_type, device)

    # # Run the SAM model (point modes)
    # masks, scores, x_off, y_off = segmenter.predict(mode=seg_mode, input_points=input_points,
    #                                                 input_labels=single_label, multimask_output=False)

    # box
    masks, scores, x_off, y_off = segmenter.predict(mode=seg_mode, input_box=input_box,
                                                    input_labels=single_label, multimask_output=True)
    # Convert predicted masks to vector polygons
    polygons = segmenter.masks_to_polygons(masks, x_off, y_off)
    # Save as a shapefile
    segmenter.save_polygons_gdal(polygons, output_shp)
    # Visualization; note this re-reads a fixed 512x512 window for display
    image_chunk = segmenter.read_image_chunk(x_off, y_off, 512, 512)
    # segmenter.show_masks(seg_mode, masks, scores, x_off, y_off, input_points, single_label, image_chunk)
    # box
    segmenter.show_masks(seg_mode, masks, scores, x_off, y_off, input_box, single_label, image_chunk)
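
The opening line of this article mentions wrapping the tool with Flask as a web backend, but the script above is where the repository stops. Below is a hypothetical sketch of what such a wrapper could look like; the endpoint name, request schema, and port are illustrative assumptions, not part of the project, and it assumes the ImageSegmentation class above is defined in (or imported into) the same module:

# Hypothetical Flask wrapper -- a sketch only, not part of the repository.
from flask import Flask, jsonify, request

app = Flask(__name__)

# Load the model once at startup; paths match the configuration used above.
segmenter = ImageSegmentation('./data/sentinel2.tif',
                              './model/sam_vit_b_01ec64.pth',
                              model_type='vit_b', device='cpu')

@app.route('/segment', methods=['POST'])  # endpoint name is an illustrative assumption
def segment():
    # Assumed request schema, e.g.:
    # {"mode": "box", "box": [[0, 951, 1909, 2383]], "labels": [1]}
    req = request.get_json()
    masks, scores, x_off, y_off = segmenter.predict(
        mode=req.get('mode', 'box'),
        input_points=req.get('points'),
        input_labels=req.get('labels'),
        input_box=req.get('box'),
        multimask_output=False,
    )
    polygons = segmenter.masks_to_polygons(masks, x_off, y_off)
    return jsonify({'scores': [float(s) for s in scores],
                    'polygons': [p.wkt for p in polygons]})

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)  # port is an assumption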

