Spaces:
Runtime error
Runtime error
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import numpy as np | |
from typing import List, Optional | |
from segment_anything import SamAutomaticMaskGenerator | |
from segment_anything.utils.amg import build_all_layer_point_grids | |
from .predictor import SamPredictorHQ | |
class SamAutomaticMaskGeneratorHQ(SamAutomaticMaskGenerator):
    """Whole-image mask generator that runs on a ``SamPredictorHQ``.

    Exposes the same configuration options as
    ``segment_anything.SamAutomaticMaskGenerator``. The constructor receives a
    ready-made predictor and stores it unchanged as ``self.predictor``; it does
    not receive a raw SAM model.
    """

    def __init__(
        self,
        model: SamPredictorHQ,
        points_per_side: Optional[int] = 32,
        points_per_batch: int = 64,
        pred_iou_thresh: float = 0.88,
        stability_score_thresh: float = 0.95,
        stability_score_offset: float = 1.0,
        box_nms_thresh: float = 0.7,
        crop_n_layers: int = 0,
        crop_nms_thresh: float = 0.7,
        crop_overlap_ratio: float = 512 / 1500,
        crop_n_points_downscale_factor: int = 1,
        point_grids: Optional[List[np.ndarray]] = None,
        min_mask_region_area: int = 0,
        output_mode: str = "binary_mask",
    ) -> None:
        """
        Configure automatic mask generation over a grid of point prompts.

        Generates a grid of point prompts over the image, then filters
        low quality and duplicate masks. The default settings are chosen
        for SAM with a ViT-H backbone.

        Arguments:
          model (SamPredictorHQ): The HQ-SAM predictor (not a bare model) used
            for mask prediction. It is stored unchanged as ``self.predictor``.
            NOTE(review): presumably it must offer the same interface as
            ``segment_anything.SamPredictor``, which the inherited generation
            code normally uses — confirm.
          points_per_side (int or None): The number of points to be sampled
            along one side of the image. The total number of points is
            points_per_side**2. If None, 'point_grids' must provide explicit
            point sampling.
          points_per_batch (int): Sets the number of points run simultaneously
            by the model. Higher numbers may be faster but use more GPU memory.
          pred_iou_thresh (float): A filtering threshold in [0,1], using the
            model's predicted mask quality.
          stability_score_thresh (float): A filtering threshold in [0,1], using
            the stability of the mask under changes to the cutoff used to
            binarize the model's mask predictions.
          stability_score_offset (float): The amount to shift the cutoff when
            calculating the stability score.
          box_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks.
          crop_n_layers (int): If >0, mask prediction will be run again on
            crops of the image. Sets the number of layers to run, where each
            layer has 2**i_layer number of image crops.
          crop_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks between different crops.
          crop_overlap_ratio (float): Sets the degree to which crops overlap.
            In the first crop layer, crops will overlap by this fraction of
            the image length. Later layers with more crops scale down this
            overlap.
          crop_n_points_downscale_factor (int): The number of points-per-side
            sampled in layer n is scaled down by
            crop_n_points_downscale_factor**n.
          point_grids (list(np.ndarray) or None): A list over explicit grids
            of points used for sampling, normalized to [0,1]. The nth grid in
            the list is used in the nth crop layer. Exclusive with
            points_per_side.
          min_mask_region_area (int): If >0, postprocessing will be applied
            to remove disconnected regions and holes in masks with area
            smaller than min_mask_region_area. Requires opencv.
          output_mode (str): The form masks are returned in. Can be
            'binary_mask', 'uncompressed_rle', or 'coco_rle'. 'coco_rle'
            requires pycocotools. For large resolutions, 'binary_mask' may
            consume large amounts of memory.

        Raises:
          AssertionError: If not exactly one of ``points_per_side`` and
            ``point_grids`` is given, or if ``output_mode`` is not one of the
            supported values.
          ImportError: If ``output_mode == 'coco_rle'`` and pycocotools is not
            installed, or ``min_mask_region_area > 0`` and OpenCV (cv2) is not
            installed.
        """
        # Use explicit ``raise AssertionError`` instead of ``assert``.
        # ``python -O`` strips assert statements, but these checks must still
        # run. Raising the same exception type keeps callers that catch
        # AssertionError working.
        if (points_per_side is None) == (point_grids is None):
            raise AssertionError(
                "Exactly one of points_per_side or point_grids must be provided."
            )
        if points_per_side is not None:
            # Build the per-crop-layer point grids from points_per_side.
            self.point_grids = build_all_layer_point_grids(
                points_per_side,
                crop_n_layers,
                crop_n_points_downscale_factor,
            )
        else:
            # The exclusivity check above guarantees point_grids is not None.
            self.point_grids = point_grids

        if output_mode not in ("binary_mask", "uncompressed_rle", "coco_rle"):
            raise AssertionError(f"Unknown output_mode {output_mode}.")

        # Fail fast on missing optional dependencies. The imported modules
        # themselves are not used in this constructor.
        if output_mode == "coco_rle":
            from pycocotools import mask as mask_utils  # type: ignore # noqa: F401
        if min_mask_region_area > 0:
            import cv2  # type: ignore # noqa: F401

        # NOTE: the base-class __init__ is deliberately not called. This
        # constructor sets the same attributes itself so that ``model`` is used
        # directly as the predictor. Presumably the base constructor would wrap
        # a raw model in its own predictor instead — confirm against
        # segment_anything.
        self.predictor = model
        self.points_per_batch = points_per_batch
        self.pred_iou_thresh = pred_iou_thresh
        self.stability_score_thresh = stability_score_thresh
        self.stability_score_offset = stability_score_offset
        self.box_nms_thresh = box_nms_thresh
        self.crop_n_layers = crop_n_layers
        self.crop_nms_thresh = crop_nms_thresh
        self.crop_overlap_ratio = crop_overlap_ratio
        self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
        self.min_mask_region_area = min_mask_region_area
        self.output_mode = output_mode