|
|
|
""" |
|
Generate predictions using the Segment Anything Model (SAM). |
|
|
|
SAM is an advanced image segmentation model offering features like promptable segmentation and zero-shot performance. |
|
This module contains the implementation of the prediction logic and auxiliary utilities required to perform segmentation |
|
using SAM. It forms an integral part of the Ultralytics framework and is designed for high-performance, real-time image |
|
segmentation tasks. |
|
""" |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
import torchvision |
|
|
|
from ultralytics.data.augment import LetterBox |
|
from ultralytics.engine.predictor import BasePredictor |
|
from ultralytics.engine.results import Results |
|
from ultralytics.utils import DEFAULT_CFG, ops |
|
from ultralytics.utils.torch_utils import select_device |
|
from .amg import ( |
|
batch_iterator, |
|
batched_mask_to_box, |
|
build_all_layer_point_grids, |
|
calculate_stability_score, |
|
generate_crop_boxes, |
|
is_box_near_crop_edge, |
|
remove_small_regions, |
|
uncrop_boxes_xyxy, |
|
uncrop_masks, |
|
) |
|
from .build import build_sam |
|
|
|
|
|
class Predictor(BasePredictor): |
|
""" |
|
Predictor class for the Segment Anything Model (SAM), extending BasePredictor. |
|
|
|
The class provides an interface for model inference tailored to image segmentation tasks. |
|
With advanced architecture and promptable segmentation capabilities, it facilitates flexible and real-time |
|
mask generation. The class is capable of working with various types of prompts such as bounding boxes, |
|
points, and low-resolution masks. |
|
|
|
Attributes: |
|
cfg (dict): Configuration dictionary specifying model and task-related parameters. |
|
overrides (dict): Dictionary containing values that override the default configuration. |
|
_callbacks (dict): Dictionary of user-defined callback functions to augment behavior. |
|
args (namespace): Namespace to hold command-line arguments or other operational variables. |
|
im (torch.Tensor): Preprocessed input image tensor. |
|
features (torch.Tensor): Extracted image features used for inference. |
|
prompts (dict): Collection of various prompt types, such as bounding boxes and points. |
|
segment_all (bool): Flag to control whether to segment all objects in the image or only specified ones. |
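
    Examples:

        A minimal usage sketch; the checkpoint name, image path, and prompt coordinates below are illustrative

        placeholders rather than values defined by this module:

        >>> predictor = Predictor(overrides=dict(model="sam_b.pt"))  # assumed local SAM checkpoint

        >>> predictor.set_image("path/to/image.jpg")  # placeholder path; a cv2/NumPy BGR array also works

        >>> predictor.set_prompts({"points": [[500, 375]]})  # pixel coordinates in the original image

        >>> results = predictor()  # prompts set above are consumed during inference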
|
""" |
|
|
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): |
|
""" |
|
Initialize the Predictor with configuration, overrides, and callbacks. |
|
|
|
The method sets up the Predictor object and applies any configuration overrides or callbacks provided. It |
|
initializes task-specific settings for SAM, such as retina_masks being set to True for optimal results. |
|
|
|
Args: |
|
cfg (dict): Configuration dictionary. |
|
overrides (dict, optional): Dictionary of values to override default configuration. |
|
_callbacks (dict, optional): Dictionary of callback functions to customize behavior. |
|
""" |
|
if overrides is None: |
|
overrides = {} |
|
overrides.update(dict(task="segment", mode="predict", imgsz=1024)) |
|
super().__init__(cfg, overrides, _callbacks) |
|
self.args.retina_masks = True |
|
self.im = None |
|
self.features = None |
|
self.prompts = {} |
|
self.segment_all = False |
|
|
|
def preprocess(self, im): |
|
""" |
|
Preprocess the input image for model inference. |
|
|
|
The method prepares the input image by applying transformations and normalization. |
|
It supports both torch.Tensor and list of np.ndarray as input formats. |
|
|
|
Args: |
|
im (torch.Tensor | List[np.ndarray]): BCHW tensor format or list of HWC numpy arrays. |
|
|
|
Returns: |
|
(torch.Tensor): The preprocessed image tensor. |
|
""" |
|
if self.im is not None: |
|
return self.im |
|
not_tensor = not isinstance(im, torch.Tensor) |
|
if not_tensor: |
|
            im = np.stack(self.pre_transform(im))

            im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW

            im = np.ascontiguousarray(im)  # ensure contiguous memory before torch.from_numpy

            im = torch.from_numpy(im)
|
|
|
im = im.to(self.device) |
|
im = im.half() if self.model.fp16 else im.float() |
|
if not_tensor: |
|
im = (im - self.mean) / self.std |
|
return im |
|
|
|
def pre_transform(self, im): |
|
""" |
|
Perform initial transformations on the input image for preprocessing. |
|
|
|
The method applies transformations such as resizing to prepare the image for further preprocessing. |
|
Currently, batched inference is not supported; hence the list length should be 1. |
|
|
|
Args: |
|
im (List[np.ndarray]): List containing images in HWC numpy array format. |
|
|
|
Returns: |
|
(List[np.ndarray]): List of transformed images. |
|
""" |
|
assert len(im) == 1, "SAM model does not currently support batched inference" |
|
letterbox = LetterBox(self.args.imgsz, auto=False, center=False) |
|
return [letterbox(image=x) for x in im] |
|
|
|
def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs): |
|
""" |
|
        Perform image segmentation inference based on the given input cues, using the currently loaded image. This

        method leverages the SAM architecture, consisting of an image encoder, a prompt encoder, and a mask decoder,

        for real-time, promptable segmentation.
|
|
|
Args: |
|
im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W). |
|
bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format. |
|
points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixel coordinates. |
|
labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 for foreground and 0 for background. |
|
masks (np.ndarray, optional): Low-resolution masks from previous predictions. Shape should be (N, H, W). For SAM, H=W=256. |
|
multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts. Defaults to False. |
|
|
|
        Returns:

            (tuple): When prompts are provided, the output of `prompt_inference` with two elements:

                - torch.Tensor: Predicted mask logits of shape (C, H, W), where C is the number of generated masks and H=W=256.

                - torch.Tensor: A tensor of length C with the model's predicted quality score for each mask.

                When no prompts are given, the output of `generate` is returned instead (masks, scores, and bounding boxes).
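
        Examples:

            An illustrative sketch; the image path and box coordinates are placeholder values, and prompts may also

            be supplied ahead of time via `set_prompts`:

            >>> predictor.set_image("path/to/image.jpg")  # placeholder path

            >>> results = predictor(bboxes=[[439, 437, 524, 709]])  # XYXY pixel box forwarded to this method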
|
""" |
|
|
|
bboxes = self.prompts.pop("bboxes", bboxes) |
|
points = self.prompts.pop("points", points) |
|
masks = self.prompts.pop("masks", masks) |
|
|
|
        # With no prompts, fall back to whole-image "segment everything" generation
        if all(i is None for i in [bboxes, points, masks]):
|
return self.generate(im, *args, **kwargs) |
|
|
|
return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output) |
|
|
|
def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False): |
|
""" |
|
Internal function for image segmentation inference based on cues like bounding boxes, points, and masks. |
|
Leverages SAM's specialized architecture for prompt-based, real-time segmentation. |
|
|
|
Args: |
|
im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W). |
|
bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format. |
|
points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixel coordinates. |
|
labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 for foreground and 0 for background. |
|
masks (np.ndarray, optional): Low-resolution masks from previous predictions. Shape should be (N, H, W). For SAM, H=W=256. |
|
multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts. Defaults to False. |
|
|
|
        Returns:

            (tuple): Contains the following two elements.

                - torch.Tensor: Predicted mask logits of shape (C, H, W), where C is the number of generated masks and H=W=256.

                - torch.Tensor: A tensor of length C with the model's predicted quality score for each mask.
|
""" |
|
features = self.model.image_encoder(im) if self.features is None else self.features |
|
|
|
        # Ratio for scaling prompt coordinates from the original image size to the model input size
        src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]

        r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
|
|
|
if points is not None: |
|
points = torch.as_tensor(points, dtype=torch.float32, device=self.device) |
|
points = points[None] if points.ndim == 1 else points |
|
|
|
            # Assume all prompt points are foreground if no labels are provided
            if labels is None:

                labels = np.ones(points.shape[0])

            labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)

            points *= r  # scale points to the model input size
|
|
|
            # (N, 2) --> (N, 1, 2), (N,) --> (N, 1)
            points, labels = points[:, None, :], labels[:, None]
|
if bboxes is not None: |
|
bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device) |
|
bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes |
|
bboxes *= r |
|
if masks is not None: |
|
masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1) |
|
|
|
points = (points, labels) if points is not None else None |
|
|
|
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks) |
|
|
|
|
|
pred_masks, pred_scores = self.model.mask_decoder( |
|
image_embeddings=features, |
|
image_pe=self.model.prompt_encoder.get_dense_pe(), |
|
sparse_prompt_embeddings=sparse_embeddings, |
|
dense_prompt_embeddings=dense_embeddings, |
|
multimask_output=multimask_output, |
|
) |
|
|
|
|
|
|
|
        # (N, d, H, W) --> (N*d, H, W) and (N, d) --> (N*d,), where d is 1 or 3 depending on `multimask_output`
        return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
|
|
|
def generate( |
|
self, |
|
im, |
|
crop_n_layers=0, |
|
crop_overlap_ratio=512 / 1500, |
|
crop_downscale_factor=1, |
|
point_grids=None, |
|
points_stride=32, |
|
points_batch_size=64, |
|
conf_thres=0.88, |
|
stability_score_thresh=0.95, |
|
stability_score_offset=0.95, |
|
crop_nms_thresh=0.7, |
|
): |
|
""" |
|
Perform image segmentation using the Segment Anything Model (SAM). |
|
|
|
This function segments an entire image into constituent parts by leveraging SAM's advanced architecture |
|
and real-time performance capabilities. It can optionally work on image crops for finer segmentation. |
|
|
|
Args: |
|
im (torch.Tensor): Input tensor representing the preprocessed image with dimensions (N, C, H, W). |
|
crop_n_layers (int): Specifies the number of layers for additional mask predictions on image crops. |
|
Each layer produces 2**i_layer number of image crops. |
|
crop_overlap_ratio (float): Determines the extent of overlap between crops. Scaled down in subsequent layers. |
|
crop_downscale_factor (int): Scaling factor for the number of sampled points-per-side in each layer. |
|
            point_grids (list[np.ndarray], optional): Custom grids for point sampling, normalized to [0,1].

                One grid is used per crop layer.

            points_stride (int, optional): Number of points to sample along each side of the image.

                Mutually exclusive with 'point_grids'.
|
points_batch_size (int): Batch size for the number of points processed simultaneously. |
|
conf_thres (float): Confidence threshold [0,1] for filtering based on the model's mask quality prediction. |
|
stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on mask stability. |
|
stability_score_offset (float): Offset value for calculating stability score. |
|
crop_nms_thresh (float): IoU cutoff for Non-Maximum Suppression (NMS) to remove duplicate masks between crops. |
|
|
|
Returns: |
|
(tuple): A tuple containing segmented masks, confidence scores, and bounding boxes. |
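
        Examples:

            An illustrative sketch (the image path is a placeholder). Calling the predictor without any prompts

            routes inference through this method, and in typical Ultralytics usage extra keyword arguments are

            forwarded to it:

            >>> results = predictor("path/to/image.jpg", crop_n_layers=1, points_stride=16)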
|
""" |
|
self.segment_all = True |
|
ih, iw = im.shape[2:] |
|
crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio) |
|
if point_grids is None: |
|
point_grids = build_all_layer_point_grids(points_stride, crop_n_layers, crop_downscale_factor) |
|
pred_masks, pred_scores, pred_bboxes, region_areas = [], [], [], [] |
|
for crop_region, layer_idx in zip(crop_regions, layer_idxs): |
|
x1, y1, x2, y2 = crop_region |
|
w, h = x2 - x1, y2 - y1 |
|
area = torch.tensor(w * h, device=im.device) |
|
points_scale = np.array([[w, h]]) |
|
|
|
            # Crop the image region and resize it back to the model input size
            crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode="bilinear", align_corners=False)
|
|
|
points_for_image = point_grids[layer_idx] * points_scale |
|
crop_masks, crop_scores, crop_bboxes = [], [], [] |
|
for (points,) in batch_iterator(points_batch_size, points_for_image): |
|
pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True) |
|
|
|
                # Interpolate predicted masks to the crop size
                pred_mask = F.interpolate(pred_mask[None], (h, w), mode="bilinear", align_corners=False)[0]

                # Keep only masks above the confidence threshold
                idx = pred_score > conf_thres

                pred_mask, pred_score = pred_mask[idx], pred_score[idx]
|
|
|
stability_score = calculate_stability_score( |
|
pred_mask, self.model.mask_threshold, stability_score_offset |
|
) |
|
                # Keep only masks above the stability threshold
                idx = stability_score > stability_score_thresh

                pred_mask, pred_score = pred_mask[idx], pred_score[idx]

                # Binarize mask logits at the model's mask threshold
                pred_mask = pred_mask > self.model.mask_threshold
|
|
|
pred_bbox = batched_mask_to_box(pred_mask).float() |
|
keep_mask = ~is_box_near_crop_edge(pred_bbox, crop_region, [0, 0, iw, ih]) |
|
if not torch.all(keep_mask): |
|
pred_bbox, pred_mask, pred_score = pred_bbox[keep_mask], pred_mask[keep_mask], pred_score[keep_mask] |
|
|
|
crop_masks.append(pred_mask) |
|
crop_bboxes.append(pred_bbox) |
|
crop_scores.append(pred_score) |
|
|
|
|
|
crop_masks = torch.cat(crop_masks) |
|
crop_bboxes = torch.cat(crop_bboxes) |
|
crop_scores = torch.cat(crop_scores) |
|
            keep = torchvision.ops.nms(crop_bboxes, crop_scores, self.args.iou)  # NMS within this crop
|
crop_bboxes = uncrop_boxes_xyxy(crop_bboxes[keep], crop_region) |
|
crop_masks = uncrop_masks(crop_masks[keep], crop_region, ih, iw) |
|
crop_scores = crop_scores[keep] |
|
|
|
pred_masks.append(crop_masks) |
|
pred_bboxes.append(crop_bboxes) |
|
pred_scores.append(crop_scores) |
|
region_areas.append(area.expand(len(crop_masks))) |
|
|
|
pred_masks = torch.cat(pred_masks) |
|
pred_bboxes = torch.cat(pred_bboxes) |
|
pred_scores = torch.cat(pred_scores) |
|
region_areas = torch.cat(region_areas) |
|
|
|
|
|
        # Remove duplicate masks between crops, preferring masks predicted from smaller crop regions
        if len(crop_regions) > 1:

            scores = 1 / region_areas  # smaller crop regions get higher scores

            keep = torchvision.ops.nms(pred_bboxes, scores, crop_nms_thresh)
|
pred_masks, pred_bboxes, pred_scores = pred_masks[keep], pred_bboxes[keep], pred_scores[keep] |
|
|
|
return pred_masks, pred_scores, pred_bboxes |
|
|
|
def setup_model(self, model, verbose=True): |
|
""" |
|
Initializes the Segment Anything Model (SAM) for inference. |
|
|
|
This method sets up the SAM model by allocating it to the appropriate device and initializing the necessary |
|
parameters for image normalization and other Ultralytics compatibility settings. |
|
|
|
Args: |
|
model (torch.nn.Module): A pre-trained SAM model. If None, a model will be built based on configuration. |
|
verbose (bool): If True, prints selected device information. |
|
|
|
Attributes: |
|
model (torch.nn.Module): The SAM model allocated to the chosen device for inference. |
|
device (torch.device): The device to which the model and tensors are allocated. |
|
mean (torch.Tensor): The mean values for image normalization. |
|
std (torch.Tensor): The standard deviation values for image normalization. |
|
""" |
|
device = select_device(self.args.device, verbose=verbose) |
|
if model is None: |
|
model = build_sam(self.args.model) |
|
model.eval() |
|
self.model = model.to(device) |
|
self.device = device |
|
self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device) |
|
self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device) |
|
|
|
|
|
        # Ultralytics compatibility settings
        self.model.pt = False
|
self.model.triton = False |
|
self.model.stride = 32 |
|
self.model.fp16 = False |
|
self.done_warmup = True |
|
|
|
def postprocess(self, preds, img, orig_imgs): |
|
""" |
|
Post-processes SAM's inference outputs to generate object detection masks and bounding boxes. |
|
|
|
        The method scales masks and boxes to the original image size and applies a threshold to the mask predictions.

        SAM's advanced, promptable architecture enables real-time segmentation performance.
|
|
|
Args: |
|
preds (tuple): The output from SAM model inference, containing masks, scores, and optional bounding boxes. |
|
img (torch.Tensor): The processed input image tensor. |
|
orig_imgs (list | torch.Tensor): The original, unprocessed images. |
|
|
|
Returns: |
|
(list): List of Results objects containing detection masks, bounding boxes, and other metadata. |
|
""" |
|
|
|
pred_masks, pred_scores = preds[:2] |
|
pred_bboxes = preds[2] if self.segment_all else None |
|
        names = dict(enumerate(str(i) for i in range(len(pred_masks))))  # one numeric class name per mask
|
|
|
if not isinstance(orig_imgs, list): |
|
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) |
|
|
|
results = [] |
|
for i, masks in enumerate([pred_masks]): |
|
orig_img = orig_imgs[i] |
|
if pred_bboxes is not None: |
|
pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False) |
|
cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device) |
|
pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1) |
|
|
|
masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0] |
|
masks = masks > self.model.mask_threshold |
|
img_path = self.batch[0][i] |
|
results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes)) |
|
|
|
self.segment_all = False |
|
return results |
|
|
|
def setup_source(self, source): |
|
""" |
|
Sets up the data source for inference. |
|
|
|
This method configures the data source from which images will be fetched for inference. The source could be a |
|
directory, a video file, or other types of image data sources. |
|
|
|
Args: |
|
source (str | Path): The path to the image data source for inference. |
|
""" |
|
if source is not None: |
|
super().setup_source(source) |
|
|
|
def set_image(self, image): |
|
""" |
|
Preprocesses and sets a single image for inference. |
|
|
|
This function sets up the model if not already initialized, configures the data source to the specified image, |
|
and preprocesses the image for feature extraction. Only one image can be set at a time. |
|
|
|
Args: |
|
image (str | np.ndarray): Image file path as a string, or a np.ndarray image read by cv2. |
|
|
|
Raises: |
|
AssertionError: If more than one image is set. |
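
        Examples:

            A short sketch (the image path below is a placeholder):

            >>> predictor.set_image("path/to/image.jpg")  # caches the image tensor and its encoder features

            >>> results = predictor(points=[[500, 375]])  # reuses the cached features for this prompt

            >>> predictor.reset_image()  # clear the cached image and features when done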
|
""" |
|
if self.model is None: |
|
model = build_sam(self.args.model) |
|
self.setup_model(model) |
|
self.setup_source(image) |
|
assert len(self.dataset) == 1, "`set_image` only supports setting one image!" |
|
for batch in self.dataset: |
|
im = self.preprocess(batch[1]) |
|
self.features = self.model.image_encoder(im) |
|
self.im = im |
|
break |
|
|
|
def set_prompts(self, prompts): |
|
"""Set prompts in advance.""" |
|
self.prompts = prompts |
|
|
|
def reset_image(self): |
|
"""Resets the image and its features to None.""" |
|
self.im = None |
|
self.features = None |
|
|
|
@staticmethod |
|
def remove_small_regions(masks, min_area=0, nms_thresh=0.7): |
|
""" |
|
Perform post-processing on segmentation masks generated by the Segment Anything Model (SAM). Specifically, this |
|
function removes small disconnected regions and holes from the input masks, and then performs Non-Maximum |
|
Suppression (NMS) to eliminate any newly created duplicate boxes. |
|
|
|
Args: |
|
masks (torch.Tensor): A tensor containing the masks to be processed. Shape should be (N, H, W), where N is |
|
the number of masks, H is height, and W is width. |
|
min_area (int): The minimum area below which disconnected regions and holes will be removed. Defaults to 0. |
|
nms_thresh (float): The IoU threshold for the NMS algorithm. Defaults to 0.7. |
|
|
|
Returns: |
|
            (tuple[torch.Tensor, List[int]]):
|
- new_masks (torch.Tensor): The processed masks with small regions removed. Shape is (N, H, W). |
|
- keep (List[int]): The indices of the remaining masks post-NMS, which can be used to filter the boxes. |
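
        Examples:

            A sketch with assumed inputs: `masks` is an (N, H, W) tensor of predicted masks and `boxes` is an

            aligned (N, 4) tensor of their bounding boxes (neither is defined in this module):

            >>> new_masks, keep = Predictor.remove_small_regions(masks, min_area=100, nms_thresh=0.7)

            >>> filtered_boxes = boxes[keep]  # keep the boxes whose masks survived NMS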
|
""" |
|
        if len(masks) == 0:

            return masks, []  # nothing to process; keep the (masks, keep) return signature consistent
|
|
|
|
|
new_masks = [] |
|
scores = [] |
|
for mask in masks: |
|
mask = mask.cpu().numpy().astype(np.uint8) |
|
mask, changed = remove_small_regions(mask, min_area, mode="holes") |
|
unchanged = not changed |
|
mask, changed = remove_small_regions(mask, min_area, mode="islands") |
|
unchanged = unchanged and not changed |
|
|
|
new_masks.append(torch.as_tensor(mask).unsqueeze(0)) |
|
|
|
scores.append(float(unchanged)) |
|
|
|
|
|
new_masks = torch.cat(new_masks, dim=0) |
|
boxes = batched_mask_to_box(new_masks) |
|
keep = torchvision.ops.nms(boxes.float(), torch.as_tensor(scores), nms_thresh) |
|
|
|
return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep |
|
|