|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
from typing import List, Optional
|
|
|
|
from segment_anything import SamAutomaticMaskGenerator
|
|
from segment_anything.utils.amg import build_all_layer_point_grids
|
|
from .predictor import SamPredictorHQ
|
|
|
|
|
|
class SamAutomaticMaskGeneratorHQ(SamAutomaticMaskGenerator):
    """Automatic mask generator driven by a ``SamPredictorHQ``.

    Subclasses ``SamAutomaticMaskGenerator`` but takes an already-constructed
    ``SamPredictorHQ``, which is stored directly as ``self.predictor``.
    """

    def __init__(
        self,
        model: SamPredictorHQ,
        points_per_side: Optional[int] = 32,
        points_per_batch: int = 64,
        pred_iou_thresh: float = 0.88,
        stability_score_thresh: float = 0.95,
        stability_score_offset: float = 1.0,
        box_nms_thresh: float = 0.7,
        crop_n_layers: int = 0,
        crop_nms_thresh: float = 0.7,
        crop_overlap_ratio: float = 512 / 1500,
        crop_n_points_downscale_factor: int = 1,
        point_grids: Optional[List[np.ndarray]] = None,
        min_mask_region_area: int = 0,
        output_mode: str = "binary_mask",
    ) -> None:
        """
        Using a SAM-HQ predictor, generates masks for the entire image.
        Generates a grid of point prompts over the image, then filters
        low quality and duplicate masks. The default settings are chosen
        for SAM with a ViT-H backbone.

        Arguments:
          model (SamPredictorHQ): The predictor used for mask prediction.
            It is stored as-is on ``self.predictor``.
          points_per_side (int or None): The number of points to be sampled
            along one side of the image. The total number of points is
            points_per_side**2. If None, 'point_grids' must provide explicit
            point sampling.
          points_per_batch (int): Sets the number of points run simultaneously
            by the model. Higher numbers may be faster but use more GPU memory.
          pred_iou_thresh (float): A filtering threshold in [0,1], using the
            model's predicted mask quality.
          stability_score_thresh (float): A filtering threshold in [0,1], using
            the stability of the mask under changes to the cutoff used to binarize
            the model's mask predictions.
          stability_score_offset (float): The amount to shift the cutoff when
            calculating the stability score.
          box_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks.
          crop_n_layers (int): If >0, mask prediction will be run again on
            crops of the image. Sets the number of layers to run, where each
            layer has 2**i_layer number of image crops.
          crop_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks between different crops.
          crop_overlap_ratio (float): Sets the degree to which crops overlap.
            In the first crop layer, crops will overlap by this fraction of
            the image length. Later layers with more crops scale down this overlap.
          crop_n_points_downscale_factor (int): The number of points-per-side
            sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
          point_grids (list(np.ndarray) or None): A list over explicit grids
            of points used for sampling, normalized to [0,1]. The nth grid in the
            list is used in the nth crop layer. Exclusive with points_per_side.
          min_mask_region_area (int): If >0, postprocessing will be applied
            to remove disconnected regions and holes in masks with area smaller
            than min_mask_region_area. Requires opencv.
          output_mode (str): The form masks are returned in. Can be 'binary_mask',
            'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
            For large resolutions, 'binary_mask' may consume large amounts of
            memory.

        Raises:
          ValueError: If neither or both of ``points_per_side`` and
            ``point_grids`` are given, or if ``output_mode`` is not one of the
            supported values.
          ImportError: If ``output_mode == "coco_rle"`` and pycocotools is not
            installed, or if ``min_mask_region_area > 0`` and opencv is not
            installed.
        """
        # NOTE(review): SamAutomaticMaskGenerator.__init__ is intentionally not
        # called; presumably because it expects a raw SAM model rather than a
        # predictor. Every attribute the parent relies on must therefore be set
        # below -- confirm this stays in sync with the installed
        # segment_anything version.

        # Explicit checks instead of ``assert``: assertions are stripped under
        # ``python -O``. That would silently ignore ``point_grids`` when both
        # options are given, and would accept an unknown ``output_mode``.
        if (points_per_side is None) == (point_grids is None):
            raise ValueError(
                "Exactly one of points_per_side or point_grids must be provided."
            )

        if points_per_side is not None:
            self.point_grids = build_all_layer_point_grids(
                points_per_side,
                crop_n_layers,
                crop_n_points_downscale_factor,
            )
        else:
            self.point_grids = point_grids

        if output_mode not in ("binary_mask", "uncompressed_rle", "coco_rle"):
            raise ValueError(f"Unknown output_mode {output_mode}.")

        # The optional dependencies below are imported only so that a missing
        # package raises ImportError at construction time; the imported names
        # are not used in this method.
        if output_mode == "coco_rle":
            from pycocotools import mask as mask_utils  # type: ignore # noqa: F401

        if min_mask_region_area > 0:
            import cv2  # type: ignore # noqa: F401

        self.predictor = model
        self.points_per_batch = points_per_batch
        self.pred_iou_thresh = pred_iou_thresh
        self.stability_score_thresh = stability_score_thresh
        self.stability_score_offset = stability_score_offset
        self.box_nms_thresh = box_nms_thresh
        self.crop_n_layers = crop_n_layers
        self.crop_nms_thresh = crop_nms_thresh
        self.crop_overlap_ratio = crop_overlap_ratio
        self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
        self.min_mask_region_area = min_mask_region_area
        self.output_mode = output_mode
|
|
|