当前位置: 首页 > news >正文

LLM - CV 图像实例分割开源算法 SAM2(Segment Anything 2) 配置与推理 教程 (1)

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/143220597

免责声明:本文来源于个人知识与公开资料,仅用于学术交流,欢迎讨论,不支持转载。


SAM2

SAM2(Segment Anything Model 2) 视觉分割算法是 计算机视觉(CV) 的关键技术,精确的将图像中的不同对象区分开来,使用深度学习模型来分析图像中的像素分布,生成 Mask,标识出每个对象的边界。通过在多层裁剪和不同尺度下计算 Mask 的稳定性评分,确保结果的高精度和稳定性。

Paper: Segment Anything in Images and Videos

1. 环境配置

运行代码:

git clone https://github.com/facebookresearch/sam2.git

注意:项目文件比较大,可以直接使用 zip 包,或者使用 GitHub 代理。

构建环境:

conda create -n sam2 python=3.10
conda activate sam2

安装 PyTorch 包:

pip3 install torch torchvision torchaudiopythonimport torch
print(torch.__version__)  # 2.5.0+cu124
print(torch.cuda.is_available())  # True
exit()

环境依赖:Python ≥ 3.10PyTorch ≥ 2.3.1

配置 CUDA 环境变量:

export CUDA_HOME=/usr/local/cuda  # change to your CUDA toolkit path
echo $CUDA_HOME

安装 SAM2 项目:

pip install --no-build-isolation -e .
pip install --no-build-isolation -e ".[notebooks]"  # 适配 Jupyter

--no-build-isolation 是禁用构建隔离,避免 CUDA 无法访问。

将 conda 导入 Jupyter 环境:

pip install ipykernel
python -m ipykernel install --user --name sam2 --display-name "sam2"

环境变量 export PYTORCH_ENABLE_MPS_FALLBACK=1,PyTorch 将会在遇到 MPS 不支持的操作时,自动切换到 CPU 处理。

2. 测试推理

导入 Python 包:

import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
# os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "4"  # 自定义使用的卡

配置 Torch 运行设备,与运行精度,即:

# select the device for computation
if torch.cuda.is_available():device = torch.device("cuda")
elif torch.backends.mps.is_available():device = torch.device("mps")
else:device = torch.device("cpu")
print(f"using device: {device}")if device.type == "cuda":# use bfloat16 for the entire notebooktorch.autocast("cuda", dtype=torch.bfloat16).__enter__()# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)cuda_major = torch.cuda.get_device_properties(0).majorprint(f"[Info] cuda_major: {cuda_major}")if cuda_major >= 8:torch.backends.cuda.matmul.allow_tf32 = Truetorch.backends.cudnn.allow_tf32 = True
elif device.type == "mps":print("\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might ""give numerically different outputs and sometimes degraded performance on MPS. ""See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion.")

注意:

  • 启用 bfloat16 数据类型,自动混合精度计算,有助于提高模型训练的速度和效率,同时保持较高的精度。
    • torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
  • 获取第一个 CUDA 设备(通常是 GPU)的主要版本号,判断设备是否支持 Ampere 架构的 GPU (>=8 支持)
    • cuda_major = torch.cuda.get_device_properties(0).major
  • 如果检测 CUDA 设备的主要版本号 >= 8,即 Ampere 架构的 GPU,则启用 tfloat32 (TensorFloat-32),在 Ampere 设备上提高矩阵运算性能的优化技术。
    • torch.backends.cuda.matmul.allow_tf32 = True
    • torch.backends.cudnn.allow_tf32 = True

显示标注的 mask 信息 show_anns() 即:

np.random.seed(3)def show_anns(anns, borders=True):if len(anns) == 0:returnsorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)ax = plt.gca()ax.set_autoscale_on(False)img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))img[:, :, 3] = 0for ann in sorted_anns:m = ann['segmentation']color_mask = np.concatenate([np.random.random(3), [0.5]])img[m] = color_mask if borders:import cv2contours, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) # Try to smooth contourscontours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]cv2.drawContours(img, contours, -1, (0, 0, 1, 0.4), thickness=1) ax.imshow(img)

读取图像,显示图像:

# image = Image.open('notebooks/images/cars.jpg')
image = Image.open('[your path]/llm/vision_test_data/image2.png')
image = np.array(image.convert("RGB"))
# image.shape (569, 1138, 3)plt.figure(figsize=(20, 20))
plt.imshow(image)
plt.axis('on')  # 现在坐标
plt.show()

构建自动 Mask 生成器,使用默认参数,注意选择模型 sam2.1_hiera_large.pt ,以及配置参数 sam2.1_hiera_l.yaml,即:

from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGeneratorsam2_checkpoint = "sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)mask_generator = SAM2AutomaticMaskGenerator(sam2)

生成图像的 Mask:

masks = mask_generator.generate(image)
print(len(masks))
# 43
print(masks[0].keys())
# dict_keys(['segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box'])

图像预测效果:

plt.figure(figsize=(20, 20))
plt.imshow(image)
show_anns(masks)
plt.axis('on')
plt.show() 

默认效果:

Seg

自定义生成器:

mask_generator_2 = SAM2AutomaticMaskGenerator(model=sam2,points_per_side=64,points_per_batch=128,pred_iou_thresh=0.7,stability_score_thresh=0.92,stability_score_offset=0.7,crop_n_layers=1,box_nms_thresh=0.7,crop_n_points_downscale_factor=2,min_mask_region_area=25.0,use_m2m=True,
)

参数说明:

  1. model (Sam):用于生成 mask 预测的 SAM2 模型。
  2. points_per_side (int or None):沿图像一侧采样的点数。总点数为 points_per_side**2。如果为 None,则需要 point_grids 提供显式点采样。
  3. points_per_batch (int):模型同时处理的点数。点数越多,速度越快,占用更多 GPU 内存。
  4. pred_iou_thresh (float):使用模型预测的 mask 质量的过滤阈值,范围在 [0,1]
  5. stability_score_thresh (float):使用 mask 在二值化过程中,变化的稳定性作为过滤阈值,范围在 [0,1]
  6. stability_score_offset (float):计算稳定性评分时用于调整 mask 的偏移量。
  7. box_nms_thresh (float):非极大值抑制中使用的 Box IoU 阈值,用于过滤重复 mask。
  8. crop_n_layers (int):如果 >0,会在图像裁剪后,再次运行 mask 预测。设置要运行的层数,每层有 2**i_layer 个图像裁剪。
  9. crop_n_points_downscale_factor (int):第 n 层采样的每边点数按 crop_n_points_downscale_factor**n 缩放。
  10. min_mask_region_area (int):如果 >0,后处理将移除面积小于 min_mask_region_area 的分离区域和 mask 中的孔洞。需要 OpenCV。
  11. use_m2m (bool):是否使用以前的 mask 预测进行一步优化,即在 mask 中,继续进行分割 mask。

运行图像分割:

masks2 = mask_generator_2.generate(image)
plt.figure(figsize=(20, 20))
plt.imshow(image)
show_anns(masks2)
plt.axis('off')
plt.show() 

示例图像分割,更加细腻,即:

Seg


http://www.mrgr.cn/news/58293.html

相关文章:

  • 数据治理:确保数据资产健康的关键策略
  • 跨界创新|使用自定义YOLOv11和Ollama(Llama 3)增强OCR文本识别
  • 基于Python和OpenCV的疲劳检测系统设计与实现
  • 集群系统盘损坏后的服务恢复
  • 【C++智能指针深度解析】std::shared_ptr、std::unique_ptr与std::weak_ptr的构造、原理及应用实战
  • 1.5 ROS架构
  • 力扣之612.平面上的最近距离
  • softmax回归从零实现
  • 一文学会LLM参数量计算
  • qt中qjson存储的是string类型的数据时,对于““和null的区别
  • echarts 矩阵树图treemap
  • 当遇到 502 错误(Bad Gateway)怎么办
  • HarmonyOS 5.0应用开发——Navigation实现页面路由
  • 光谱指标-预测含水量-多种特征提取方式
  • 【数据结构和算法】一、算法复杂度:时间复杂度和空间复杂度)
  • Electron 是一个用于构建跨平台桌面应用程序的开源框架
  • Docker:容器化的革命
  • 【EndNote使用教程】创建文献库、导入文献、文献分类
  • DAY62WEB 攻防-PHP 反序列化CLI 框架类PHPGGC 生成器TPYiiLaravel 等利用
  • 设备管理智能化:中小企业的Spring Boot系统
  • 介绍一款Java开发的企业接口管理系统和开放平台
  • 27.8 把target做一致性哈希进行分发
  • 双十一电容笔选哪个好?!西圣、益博思、吉玛仕电容笔实测对比!
  • 区块链行业低迷的原因及未来发展展望
  • 【贪心算法】(第十四篇)
  • 落实安全左移迫在眉睫 | 伊朗APT34组织针对阿联酋及海湾关键基础设施发动攻击