如有帮助,支持一下(GitHub - Lvbta/ImageSegmentationTool-SAM: An interactive annotation case developed based on SAM for remote sensing image annotation, which can generate corresponding segmentation results based on point, multi-point, and rectangular box prompts, and convert the recognition results into vector data shp.)
本项目提供了一个图像分割工具,利用 Segment Anything Model (SAM) 对大规模的卫星或航拍图像进行分割。该工具支持通过单点、多点或边界框输入进行图像分割,并将分割结果保存为 shapefile,以便进一步进行地理空间分析。
功能特点
- 单点分割:支持基于单个点的输入进行分割。
- 多点分割:支持使用多个点进行分割。
- 边界框分割:支持在指定的边界框内进行分割。
- 地理空间集成:使用 GDAL 读取地理空间图像,并将分割的掩膜转换为多边形。
- Shapefile 导出:将分割结果保存为 shapefile,方便与 GIS 工具集成。
- 可视化:在原始图像上可视化分割结果,便于验证和分析。
安装
- 克隆仓库:
  ```bash
  git clone https://github.com/Lvbta/ImageSegmentationTool-SAM.git
  cd ImageSegmentationTool-SAM
  ```
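- 下载 SAM 权重:从 segment-anything 官方仓库(https://github.com/facebookresearch/segment-anything#model-checkpoints)下载与 `model_type` 对应的模型检查点(例如 `vit_b` 对应 `sam_vit_b_01ec64.pth`),并放到 `./model/` 目录下。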
- 安装所需的依赖:
  ```bash
  pip install -r requirements.txt
  ```
- 环境变量:无需手动设置。代码中已设置 `KMP_DUPLICATE_LIB_OK=TRUE`,以避免多个 OpenMP 运行库同时加载时产生的冲突。
使用方法
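步骤 1:准备数据
将地理参考影像(如 GeoTIFF)放到 `./data/` 目录下,例如 `./data/sentinel2.tif`。SAM 以 8 位三通道(RGB)图像作为输入,因此建议使用 8 位 RGB 影像;单波段影像会被自动复制为三个通道。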
步骤 2:配置参数
在脚本中设置以下参数:
- `image_path`:地理参考影像文件的路径(例如 `./data/sentinel2.tif`)。
- `sam_checkpoint`:SAM 模型检查点文件的路径(例如 `./model/sam_vit_b_01ec64.pth`)。
- `model_type`:与检查点匹配的模型类型(`vit_b`、`vit_l` 等)。
- `device`:运行模型的设备(`cpu` 或 `cuda`)。
- `output_shp`:输出 shapefile 的保存路径(例如 `./result/segmentation_results.shp`)。
步骤 3:运行分割
选择分割模式并指定必要的输入点或边界框。坐标均为整幅影像中的像素坐标:点为 `[x(列), y(行)]`,边界框为 `[x_min, y_min, x_max, y_max]`;标签 `1` 表示前景点,`0` 表示背景点。
- 单点模式:
  ```python
  seg_mode = 'single_point'
  input_points = [[1248, 1507]]
  single_label = [1]
  ```
- 多点模式:
  ```python
  seg_mode = 'multi_point'
  input_points = [[389, 1041], [411, 1094]]
  single_label = [1, 1]
  ```
- 边界框模式:
  ```python
  seg_mode = 'box'
  input_box = [[0, 951, 1909, 2383]]
  single_label = [1]
  ```
步骤 4:执行脚本
运行脚本以进行分割:
python main.py
步骤 5:可视化并保存结果
分割的掩膜将被可视化,多边形将作为 shapefile 保存到指定位置。
示例
使用边界框对图像进行分割,脚本配置如下:
```python
# 边界框模式示例配置
seg_mode = 'box'
input_box = [[0, 951, 1909, 2383]]
single_label = [1]

segmenter = ImageSegmentation(image_path, sam_checkpoint, model_type, device)
masks, scores, x_off, y_off = segmenter.predict(mode=seg_mode, input_box=input_box,
                                                input_labels=single_label, multimask_output=True)
polygons = segmenter.masks_to_polygons(masks, x_off, y_off)
segmenter.save_polygons_gdal(polygons, output_shp)

# 读取与预测时相同的影像窗口用于可视化(masks 的形状为 (C, H, W))
image_chunk = segmenter.read_image_chunk(x_off, y_off, masks.shape[-1], masks.shape[-2])
segmenter.show_masks(seg_mode, masks, scores, x_off, y_off, input_box, single_label, image_chunk)
```
import os

# Set before torch / numpy load their OpenMP runtimes: this env var is the
# documented escape hatch for the "OMP: Error #15 ... already initialized"
# abort that occurs when two copies of Intel's libiomp get loaded.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import sys  # noqa: E402  (imports intentionally follow the env setup above)

import cv2  # noqa: E402
import matplotlib.pyplot as plt  # noqa: E402
import numpy as np  # noqa: E402
import torch  # noqa: E402,F401  (SAM runs on torch; kept for early, explicit loading)
from osgeo import gdal, ogr, osr  # noqa: E402
from shapely.geometry import Polygon  # noqa: E402
from shapely.wkb import dumps  # noqa: E402

# Render the Chinese plot titles with SimHei, keeping matplotlib's default
# sans-serif fonts as fallbacks for machines where SimHei is not installed.
plt.rcParams['font.sans-serif'] = ['SimHei'] + list(plt.rcParams['font.sans-serif'])
# Use an ASCII hyphen for negative tick labels; CJK fonts such as SimHei often
# lack the Unicode minus glyph.
plt.rcParams['axes.unicode_minus'] = False
# plt.style.use('ggplot')


class ImageSegmentation:
    """Prompt-driven SAM segmentation of a window of a georeferenced raster.

    The raster is read with GDAL one window ("chunk") at a time around the
    prompt and segmented with a ``SamPredictor``. The resulting masks can be
    vectorised into polygons in the raster's georeferenced coordinates and
    written to an ESRI Shapefile.

    All prompt coordinates (points and boxes) are pixel coordinates
    ``(x=column, y=row)`` in the full raster.
    """

    def __init__(self, image_path, sam_checkpoint, model_type='vit_b', device='cpu'):
        """Read raster metadata and load the SAM model.

        Args:
            image_path: Path to a GDAL-readable raster.
            sam_checkpoint: Path to the SAM checkpoint file.
            model_type: Key into ``sam_model_registry`` (e.g. ``'vit_b'``).
            device: Torch device string passed to ``sam.to`` (e.g. ``'cpu'``).

        Raises:
            RuntimeError: If GDAL cannot open ``image_path``.
        """
        self.image_path = image_path
        self.sam_checkpoint = sam_checkpoint
        self.model_type = model_type
        self.device = device
        self.geo_transform, self.proj = self.get_geoinfo()
        # Raster extent in pixels; used to keep read windows inside the image.
        self.raster_x_size, self.raster_y_size = self._get_raster_size()
        self.sam = self.load_sam_model()
        self.predictor = self.init_predictor()

    def _open_dataset(self):
        """Open ``self.image_path`` with GDAL, raising instead of returning None."""
        dataset = gdal.Open(self.image_path)
        if dataset is None:
            raise RuntimeError(f"GDAL could not open image: {self.image_path}")
        return dataset

    def get_geoinfo(self):
        """Return ``(geo_transform, projection_wkt)`` of the raster."""
        dataset = self._open_dataset()
        geo_transform = dataset.GetGeoTransform()
        proj = dataset.GetProjection()
        dataset = None  # close the GDAL dataset
        return geo_transform, proj

    def _get_raster_size(self):
        """Return the raster size as ``(columns, rows)``."""
        dataset = self._open_dataset()
        size = (dataset.RasterXSize, dataset.RasterYSize)
        dataset = None  # close the GDAL dataset
        return size

    def read_image_chunk(self, x_off, y_off, x_size, y_size):
        """Read a pixel window as a channels-last ``(y_size, x_size, C)`` array.

        GDAL returns ``(bands, rows, cols)`` for multi-band rasters, which is
        transposed to channels-last. A single-band raster is repeated into three
        identical channels. No dtype conversion or band selection is done.
        SAM's ``set_image`` expects 3-channel uint8 data, so the raster is
        assumed to already be 8-bit RGB (TODO confirm for multi-band inputs).

        Raises:
            RuntimeError: If the raster cannot be opened.
            ValueError: If GDAL returns no data for the window (e.g. it lies
                outside the raster).
        """
        dataset = self._open_dataset()
        image = dataset.ReadAsArray(int(x_off), int(y_off), int(x_size), int(y_size))
        dataset = None  # close the GDAL dataset
        if image is None:
            raise ValueError(
                f"Could not read window x_off={x_off}, y_off={y_off}, "
                f"x_size={x_size}, y_size={y_size} from {self.image_path}"
            )
        if image.ndim == 3:
            return np.transpose(image, (1, 2, 0))
        return np.stack([image] * 3, axis=-1)

    def load_sam_model(self):
        """Build the SAM model from ``sam_checkpoint`` and move it to ``device``."""
        # Allows importing ``segment_anything`` from a checkout in the parent
        # directory when it is not installed as a package (added only once).
        if ".." not in sys.path:
            sys.path.append("..")
        from segment_anything import sam_model_registry
        sam = sam_model_registry[self.model_type](checkpoint=self.sam_checkpoint)
        sam.to(device=self.device)
        return sam

    def init_predictor(self):
        """Wrap the loaded model in a ``SamPredictor``."""
        from segment_anything import SamPredictor
        return SamPredictor(self.sam)

    def _clip_window(self, x_off, y_off, x_size, y_size):
        """Clip a read window to the raster extent.

        Returns:
            ``(x_off, y_off, x_size, y_size)`` as ints, with the size reduced so
            the window ends at the raster edge at most.

        Raises:
            ValueError: If no part of the window lies inside the raster.
        """
        x_off, y_off = max(int(x_off), 0), max(int(y_off), 0)
        x_size = min(int(x_size), self.raster_x_size - x_off)
        y_size = min(int(y_size), self.raster_y_size - y_off)
        if x_size <= 0 or y_size <= 0:
            raise ValueError(
                f"Window at ({x_off}, {y_off}) lies outside the "
                f"{self.raster_x_size}x{self.raster_y_size} raster."
            )
        return x_off, y_off, x_size, y_size

    def _chunk_window(self, x_min, y_min, x_max, y_max, margin=256, max_size=2048):
        """Return the clipped read window around a prompt bounding box.

        The box is grown by ``margin`` pixels on each side (a single point gives
        a ``2 * margin`` square), each side is capped at ``max_size``, and the
        result is clipped to the raster.
        """
        x_off = max(x_min - margin, 0)
        y_off = max(y_min - margin, 0)
        x_size = min(x_max - x_min + 2 * margin, max_size)
        y_size = min(y_max - y_min + 2 * margin, max_size)
        return self._clip_window(x_off, y_off, x_size, y_size)

    def predict(self, mode='single_point', input_points=None, input_labels=None,
                input_box=None, multimask_output=None):
        """Run SAM on a raster window around the prompt.

        Args:
            mode: ``'single_point'`` (uses only ``input_points[0]``),
                ``'multi_point'`` or ``'box'``.
            input_points: ``[[x, y], ...]`` full-raster pixel coordinates
                (point modes).
            input_labels: One label per point, forwarded to SAM as
                ``point_labels`` (point modes; ignored in box mode).
            input_box: ``[[x_min, y_min, x_max, y_max]]`` in full-raster pixels
                (box mode; only the first box is used).
            multimask_output: Forwarded to ``SamPredictor.predict``; ``None`` is
                treated as ``False``.

        Returns:
            ``(masks, scores, x_off, y_off)``. ``masks`` and ``scores`` come
            from ``SamPredictor.predict`` for the window, and ``(x_off, y_off)``
            is the window's top-left pixel in the full raster.

        Raises:
            AssertionError: If the prompt required by ``mode`` is missing.
            ValueError: If ``mode`` is unknown or the prompt lies outside the
                raster.
        """
        if mode == 'single_point':
            assert input_points is not None and input_labels is not None, \
                "Points and labels are required for single point mode."
            points = [tuple(input_points[0])]
        elif mode == 'multi_point':
            assert input_points is not None and input_labels is not None, \
                "Points and labels are required for multi point mode."
            points = [tuple(p) for p in input_points]
        elif mode == 'box':
            assert input_box is not None, "Box coordinates are required for box mode."
            points = None
        else:
            raise ValueError("Mode must be 'single_point', 'multi_point', or 'box'.")

        if points is None:
            x_min, y_min, x_max, y_max = input_box[0]
        else:
            xs = [p[0] for p in points]
            ys = [p[1] for p in points]
            x_min, y_min, x_max, y_max = min(xs), min(ys), max(xs), max(ys)

        # 256 px margin around the prompt (a 512x512 window for one point),
        # capped at 2048 px per side and clipped to the raster extent. Note:
        # prompts wider than 2048 - 512 px are not fully covered by the window.
        x_off, y_off, x_size, y_size = self._chunk_window(x_min, y_min, x_max, y_max)
        image_chunk = self.read_image_chunk(x_off, y_off, x_size, y_size)
        self.predictor.set_image(image_chunk)

        # SAM works in window-local pixel coordinates.
        if points is None:
            adjusted_box = np.array([x_min - x_off, y_min - y_off, x_max - x_off, y_max - y_off])
            masks, scores, _ = self.predictor.predict(
                box=adjusted_box.reshape(1, -1),
                multimask_output=bool(multimask_output),
            )
        else:
            adjusted_points = np.array([(x - x_off, y - y_off) for x, y in points])
            masks, scores, _ = self.predictor.predict(
                point_coords=adjusted_points,
                point_labels=np.array(input_labels),
                multimask_output=bool(multimask_output),
            )
        return masks, scores, x_off, y_off

    @staticmethod
    def _polygon_parts(geometry):
        """Return the non-empty, valid ``Polygon`` parts of a shapely geometry."""
        if geometry.geom_type == 'Polygon':
            parts = [geometry]
        elif geometry.geom_type in ('MultiPolygon', 'GeometryCollection'):
            parts = [g for g in geometry.geoms if g.geom_type == 'Polygon']
        else:
            parts = []
        return [p for p in parts if p.is_valid and not p.is_empty]

    def masks_to_polygons(self, masks, x_off, y_off):
        """Vectorise masks into georeferenced shapely ``Polygon`` objects.

        Only external contours are traced (holes are not kept). Contours with
        fewer than 3 vertices are skipped. Invalid (e.g. self-touching) rings
        are repaired with ``buffer(0)`` instead of being dropped.

        Args:
            masks: Iterable of 2-D masks over the predicted window.
            x_off, y_off: Window origin in full-raster pixels (from ``predict``).

        Returns:
            A flat list of valid ``Polygon`` objects in the raster's CRS.
        """
        polygons = []
        for mask in masks:
            contours, _ = cv2.findContours((mask > 0).astype(np.uint8), cv2.RETR_EXTERNAL,
                                           cv2.CHAIN_APPROX_SIMPLE)
            for contour in contours:
                vertices = contour.reshape(-1, 2)
                if len(vertices) < 3:
                    continue
                # Contour vertices are pixel indices. +0.5 maps them to pixel
                # centres; the geotransform maps integers to pixel corners.
                geo_contour = [self.pixel_to_geo(x + x_off + 0.5, y + y_off + 0.5)
                               for x, y in vertices]
                polygon = Polygon(geo_contour)
                if not polygon.is_valid:
                    polygon = polygon.buffer(0)
                polygons.extend(self._polygon_parts(polygon))
        return polygons

    def pixel_to_geo(self, x, y):
        """Map pixel/line ``(x, y)`` to georeferenced coordinates.

        Applies the GDAL affine geotransform. Integer ``(x, y)`` refer to the
        top-left corner of that pixel.
        """
        gt = self.geo_transform
        geox = gt[0] + x * gt[1] + y * gt[2]
        geoy = gt[3] + x * gt[4] + y * gt[5]
        return geox, geoy

    def save_polygons_gdal(self, polygons, output_shp):
        """Write polygons to an ESRI Shapefile with a 1-based integer ``id`` field.

        Creates the output directory if needed and overwrites an existing
        shapefile at ``output_shp``. The layer uses the raster's projection,
        or none if the raster has no projection.

        Raises:
            RuntimeError: If the shapefile cannot be created.
        """
        driver = ogr.GetDriverByName("ESRI Shapefile")
        out_dir = os.path.dirname(output_shp)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        if os.path.exists(output_shp):
            driver.DeleteDataSource(output_shp)
        data_source = driver.CreateDataSource(output_shp)
        if data_source is None:
            raise RuntimeError(f"Could not create shapefile: {output_shp}")

        spatial_ref = None
        if self.proj:  # use the image's projection when it has one
            spatial_ref = osr.SpatialReference()
            spatial_ref.ImportFromWkt(self.proj)
        layer = data_source.CreateLayer("segmentation", spatial_ref, ogr.wkbPolygon)
        # The "id" attribute must exist on the layer before features set it.
        layer.CreateField(ogr.FieldDefn("id", ogr.OFTInteger))
        layer_defn = layer.GetLayerDefn()

        for feature_id, polygon in enumerate(polygons, start=1):
            feature = ogr.Feature(layer_defn)
            # Shapely geometry -> WKB -> OGR geometry.
            feature.SetGeometry(ogr.CreateGeometryFromWkb(dumps(polygon)))
            feature.SetField("id", feature_id)
            layer.CreateFeature(feature)
            feature = None
        data_source = None  # flush and close the shapefile

    def show_masks(self, mode, masks, scores, x_off, y_off, input_point, input_label, image):
        """Show one figure per mask: window image, mask outline, prompt, score.

        ``input_point`` holds the box list in box mode and the point list
        otherwise, in full-raster pixels. ``image`` should be the window that
        was predicted on, read at ``(x_off, y_off)``.
        """
        for i, (mask, score) in enumerate(zip(masks, scores)):
            plt.figure(figsize=(10, 10))
            plt.imshow(image)
            self.show_mask(mask, plt.gca())
            if mode == 'box':
                self.show_box(np.array(input_point[0]), plt.gca(), x_off, y_off)
            else:
                self.show_points(np.array(input_point), np.array(input_label), plt.gca(), x_off, y_off)
            plt.title(f"{mode}模式 {i + 1}, Score: {score:.3f}", fontsize=18)
            plt.axis('on')
            plt.show()

    def show_mask(self, mask, ax, x_off=0, y_off=0):
        """Outline each connected region of ``mask`` on ``ax`` in lime.

        ``x_off``/``y_off`` shift the outline by that many pixels, applied
        once. Leave them at 0 when ``ax`` shows the window the mask was
        predicted on.
        """
        contours, _ = cv2.findContours(np.asarray(mask).astype(np.uint8), cv2.RETR_EXTERNAL,
                                       cv2.CHAIN_APPROX_SIMPLE)
        for contour in contours:
            xy = contour.reshape(-1, 2)
            xy = np.vstack([xy, xy[:1]])  # repeat the first vertex to close the outline
            ax.plot(xy[:, 0] + x_off, xy[:, 1] + y_off, color='lime', linewidth=2)

    def show_points(self, points, labels, ax, x_off, y_off):
        """Scatter full-raster prompt points on ``ax`` in window coordinates."""
        for point, label in zip(points, labels):
            x, y = point
            ax.scatter(x - x_off, y - y_off, c='red', marker='o', label=f'Label: {label}')

    @staticmethod
    def show_box(box, ax, x_off, y_off):
        """Draw a full-raster ``[x_min, y_min, x_max, y_max]`` box in window coordinates."""
        x0, y0 = box[0] - x_off, box[1] - y_off
        w, h = box[2] - box[0], box[3] - box[1]
        ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='red', facecolor=(0, 0, 0, 0), lw=2))


if __name__ == '__main__':
    # Usage
    image_path = r'./data/sentinel2.tif'
    sam_checkpoint = "./model/sam_vit_b_01ec64.pth"
    model_type = "vit_b"
    device = "cpu"
    output_shp = r'./result/segmentation_results.shp'

    # Example prompts per mode (full-raster pixel coordinates); pick one via seg_mode.
    example_prompts = {
        'single_point': {'input_points': [[1248, 1507]], 'input_labels': [1],
                         'multimask_output': False},
        'multi_point': {'input_points': [[389, 1041], [411, 1094]], 'input_labels': [1, 1],
                        'multimask_output': False},
        'box': {'input_box': [[0, 951, 1909, 2383]], 'input_labels': [1],
                'multimask_output': True},
    }
    seg_mode = 'box'
    prompt = example_prompts[seg_mode]

    segmenter = ImageSegmentation(image_path, sam_checkpoint, model_type, device)
    masks, scores, x_off, y_off = segmenter.predict(mode=seg_mode, **prompt)

    # Vectorise the masks and save them as a shapefile.
    polygons = segmenter.masks_to_polygons(masks, x_off, y_off)
    segmenter.save_polygons_gdal(polygons, output_shp)

    # Visualise on exactly the window SAM saw: masks are (C, H, W) over it.
    image_chunk = segmenter.read_image_chunk(x_off, y_off, masks.shape[-1], masks.shape[-2])
    prompt_geometry = prompt['input_box'] if seg_mode == 'box' else prompt['input_points']
    segmenter.show_masks(seg_mode, masks, scores, x_off, y_off, prompt_geometry,
                         prompt['input_labels'], image_chunk)