from typing import Optional, Tuple

import numpy as np
import torch

from segment_anything import SamPredictor
from segment_anything.modeling import Sam


class SamPredictorHQ(SamPredictor):

    def __init__(
        self,
        sam_model: Sam,
        sam_is_hq: bool = False,
    ) -> None:
        """
        Uses SAM to calculate the image embedding for an image, and then
        allows repeated, efficient mask prediction given prompts.

        Arguments:
          sam_model (Sam): The model to use for mask prediction.
          sam_is_hq (bool): Whether sam_model is an HQ-SAM variant whose
            image encoder also returns the intermediate embeddings used by
            the HQ mask decoder.
        """
        super().__init__(sam_model=sam_model)
        self.is_hq = sam_is_hq
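
    # Usage sketch (not from this file): constructing the predictor from a
    # loaded model. The registry key and checkpoint path below are
    # placeholders; sam_model_registry is the standard segment_anything
    # entry point.
    #
    #   from segment_anything import sam_model_registry
    #   sam = sam_model_registry["vit_h"](checkpoint="sam_hq_vit_h.pth")  # placeholder path
    #   predictor = SamPredictorHQ(sam, sam_is_hq=True)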

    @torch.no_grad()
    def set_torch_image(
        self,
        transformed_image: torch.Tensor,
        original_image_size: Tuple[int, ...],
    ) -> None:
        """
        Calculates the image embeddings for the provided image, allowing
        masks to be predicted with the 'predict' method. Expects the input
        image to be already transformed to the format expected by the model.

        Arguments:
          transformed_image (torch.Tensor): The input image, with shape
            1x3xHxW, which has been transformed with ResizeLongestSide.
          original_image_size (tuple(int, int)): The size of the image
            before transformation, in (H, W) format.
        """
        assert (
            len(transformed_image.shape) == 4
            and transformed_image.shape[1] == 3
            and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
        ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
        self.reset_image()

        self.original_size = original_image_size
        self.input_size = tuple(transformed_image.shape[-2:])
        input_image = self.model.preprocess(transformed_image)
        if self.is_hq:
            # The HQ image encoder also returns intermediate (early-layer)
            # embeddings, which the HQ mask decoder consumes.
            self.features, self.interm_features = self.model.image_encoder(input_image)
        else:
            self.features = self.model.image_encoder(input_image)
        self.is_image_set = True
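
    # Example (sketch): preparing an image for set_torch_image with
    # ResizeLongestSide, mirroring what SamPredictor.set_image does
    # internally. An HxWx3 uint8 RGB array `image` is assumed.
    #
    #   from segment_anything.utils.transforms import ResizeLongestSide
    #   transform = ResizeLongestSide(predictor.model.image_encoder.img_size)
    #   input_image = transform.apply_image(image)  # resize so the long side matches
    #   input_torch = torch.as_tensor(input_image, device=predictor.device)
    #   input_torch = input_torch.permute(2, 0, 1).contiguous()[None, :, :, :]  # HxWx3 -> 1x3xHxW
    #   predictor.set_torch_image(input_torch, image.shape[:2])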

    @torch.no_grad()
    def predict_torch(
        self,
        point_coords: Optional[torch.Tensor],
        point_labels: Optional[torch.Tensor],
        boxes: Optional[torch.Tensor] = None,
        mask_input: Optional[torch.Tensor] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Predict masks for the given input prompts, using the currently set image.
        Input prompts are batched torch tensors and are expected to already be
        transformed to the input frame using ResizeLongestSide.

        Arguments:
          point_coords (torch.Tensor or None): A BxNx2 array of point prompts to
            the model. Each point is in (X,Y) in pixels.
          point_labels (torch.Tensor or None): A BxN array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          boxes (torch.Tensor or None): A Bx4 array giving a box prompt to the
            model, in XYXY format.
          mask_input (torch.Tensor or None): A low resolution mask input to the
            model, typically coming from a previous prediction iteration. Has
            form Bx1xHxW, where for SAM, H=W=256. Masks returned by a previous
            iteration of the predict method do not need further transformation.
          multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
          return_logits (bool): If true, returns un-thresholded mask logits
            instead of a binary mask.

        Returns:
          (torch.Tensor): The output masks in BxCxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (torch.Tensor): An array of shape BxC containing the model's
            predictions for the quality of each mask.
          (torch.Tensor): An array of shape BxCxHxW, where C is the number
            of masks and H=W=256. These low res logits can be passed to
            a subsequent iteration as mask input.
        """
        if not self.is_image_set:
            raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")

        if point_coords is not None:
            points = (point_coords, point_labels)
        else:
            points = None

        # Embed prompts
        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
            points=points,
            boxes=boxes,
            masks=mask_input,
        )

        # Predict masks. The HQ decoder additionally takes the encoder's
        # intermediate embeddings; hq_token_only=False fuses the HQ output
        # with the standard SAM masks rather than replacing them.
        if self.is_hq:
            low_res_masks, iou_predictions = self.model.mask_decoder(
                image_embeddings=self.features,
                image_pe=self.model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
                hq_token_only=False,
                interm_embeddings=self.interm_features,
            )
        else:
            low_res_masks, iou_predictions = self.model.mask_decoder(
                image_embeddings=self.features,
                image_pe=self.model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )

        # Upscale the low resolution masks to the original image size
        masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)

        if not return_logits:
            masks = masks > self.model.mask_threshold

        return masks, iou_predictions, low_res_masks
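
    # Example (sketch): a single foreground click on an already-set image.
    # Pixel coordinates must first be mapped into the model's input frame;
    # `transform` is the ResizeLongestSide instance from the example above,
    # and the click location (500, 375) is illustrative.
    #
    #   point = torch.tensor([[[500.0, 375.0]]], device=predictor.device)  # Bx1x2
    #   point = transform.apply_coords_torch(point, predictor.original_size)
    #   label = torch.tensor([[1]], device=predictor.device)  # 1 = foreground
    #   masks, scores, low_res = predictor.predict_torch(point, label, multimask_output=True)
    #   best = masks[0, scores[0].argmax()]  # pick the highest-quality mask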