Spaces:
Build error
Build error
import numpy.typing as npt | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import cv2 | |
from torchvision.ops.boxes import batched_nms | |
from app.mobile_sam import SamPredictor | |
from app.mobile_sam.utils import batched_mask_to_box | |
from app.sam.postprocess import clean_mask_torch | |
def point_selection(mask_sim, topk: int = 1):
    """Pick the ``topk`` highest-valued locations of a 2-D similarity map.

    Args:
        mask_sim: 2-D tensor of similarity values (indexed ``[row, col]``).
        topk: Number of locations to return, highest value first.

    Returns:
        A ``(coords, labels)`` pair of NumPy arrays: ``coords`` has shape
        ``(topk, 2)`` with each row being ``(col, row)`` -- i.e. ``(x, y)`` --
        and ``labels`` holds ``topk`` ones (every point is a positive prompt).
    """
    _n_rows, n_cols = mask_sim.shape

    # Indices of the best values in the flattened (row-major) map.
    flat_idx = mask_sim.flatten(0).topk(topk).indices

    # Undo the row-major flattening.
    rows = flat_idx // n_cols
    cols = flat_idx % n_cols

    # One (x, y) pair per row of the result.
    coords = torch.stack((cols, rows), dim=0).transpose(0, 1)
    labels = np.array([1] * topk)
    return coords.cpu().numpy(), labels
def mask_nms(
    masks: list[npt.NDArray], scores: list[float], iou_thresh: float = 0.2
) -> tuple[list[npt.NDArray], list[float]]:
    """Suppress overlapping masks, keeping the best-scoring one of each group.

    A mask ``i`` is dropped when either rule finds a better-scoring rival:

    * containment -- among the masks that lie at least 90% inside mask ``i``
      (mask ``i`` itself included), some other mask scores highest;
    * IoU -- among the masks whose IoU with mask ``i`` exceeds ``iou_thresh``
      (mask ``i`` itself included), some other mask scores highest.

    Ties go to the mask with the lower index. Empty (all-False) masks are
    never treated as contained in another mask and have IoU 0 with everything,
    so they neither suppress other masks nor get suppressed.

    Args:
        masks: Same-shaped 2-D masks; any non-zero value counts as foreground.
        scores: One score per mask.
        iou_thresh: IoU above which two masks are considered duplicates.

    Returns:
        ``(kept_masks, kept_scores)`` in the original input order.
    """
    n_masks = len(masks)
    if n_masks == 0:
        return [], []

    np_masks = np.array(masks).astype(bool)
    np_scores = np.array(scores)
    # Loop-invariant: per-mask foreground pixel counts.
    areas = np_masks.sum(axis=(1, 2))

    remove_indices = set()
    for i in range(n_masks):
        mask_i = np_masks[i]
        intersection = np.logical_and(mask_i, np_masks).sum(axis=(1, 2))
        union = np.logical_or(mask_i, np_masks).sum(axis=(1, 2))

        # Masks that are (almost) entirely inside mask i. A zero-area mask
        # would satisfy the inequality trivially, so it is excluded; mask i
        # always competes against itself.
        contained = (intersection >= areas * 0.90) & (areas > 0)
        contained[i] = True
        rivals = np.flatnonzero(contained)
        if rivals[np_scores[rivals].argmax()] != i:
            remove_indices.add(i)
            continue

        # IoU with every mask; an empty union yields 0 instead of 0/0 = NaN.
        ious_i = np.divide(
            intersection, union, out=np.zeros(n_masks), where=union > 0
        )
        similar = ious_i > iou_thresh
        # Keep mask i among its own candidates so the argmax below is never
        # taken over an empty set (empty mask, or iou_thresh >= 1).
        similar[i] = True
        rivals = np.flatnonzero(similar)
        if rivals[np_scores[rivals].argmax()] != i:
            remove_indices.add(i)

    kept = [i for i in range(n_masks) if i not in remove_indices]
    return [masks[i] for i in kept], [scores[i] for i in kept]
class MaskWeights(nn.Module):
    """Module holding a single learnable ``(2, 1)`` weight tensor.

    Both entries start at 1/3 and are registered as a trainable parameter
    named ``weights``.
    """

    def __init__(self):
        super().__init__()
        initial = torch.full((2, 1), 1.0 / 3.0)
        self.weights = nn.Parameter(initial, requires_grad=True)
class PerSAM:
    """Callable that runs ``fast_inference`` with a fixed predictor and settings.

    The constructor only stores its arguments; calling the instance on an
    image forwards them, together with the image, to ``fast_inference``.
    """

    def __init__(
        self,
        sam: SamPredictor,
        target_feat: torch.Tensor,
        max_objects: int,
        score_thresh: float,
        nms_iou_thresh: float,
        mask_weights: torch.Tensor,
    ) -> None:
        """Store the predictor and inference settings.

        Args:
            sam: Predictor passed to ``fast_inference`` as ``predictor``.
            target_feat: Reference feature passed as ``target_feat``.
            max_objects: Upper bound on the number of objects searched for.
            score_thresh: Similarity score below which the search stops.
            nms_iou_thresh: IoU threshold for the final box NMS.
            mask_weights: Tensor passed to ``fast_inference`` as ``weights``.
        """
        super().__init__()
        self.sam = sam
        self.weights = mask_weights
        self.target_feat = target_feat
        self.max_objects = max_objects
        self.score_thresh = score_thresh
        self.nms_iou_thresh = nms_iou_thresh

    def __call__(
        self, x: npt.NDArray
    ) -> tuple[npt.NDArray | None, npt.NDArray | None, npt.NDArray | None]:
        """Run ``fast_inference`` on image ``x``.

        Returns:
            Whatever ``fast_inference`` returns: ``(masks, boxes, scores)``,
            or ``(None, None, None)`` when it finds no object.
        """
        return fast_inference(
            predictor=self.sam,
            image=x,
            target_feat=self.target_feat,
            weights=self.weights,
            max_objects=self.max_objects,
            score_thresh=self.score_thresh,
            nms_iou_thresh=self.nms_iou_thresh,
        )
def _mask_to_box(mask: npt.NDArray) -> npt.NDArray | None:
    """Return the tight ``[x_min, y_min, x_max, y_max]`` box of a 2-D mask, or None if it is empty."""
    ys, xs = np.nonzero(mask)
    if ys.size == 0:
        return None
    return np.array([xs.min(), ys.min(), xs.max(), ys.max()])


def fast_inference(
    predictor: SamPredictor,
    image: npt.NDArray,
    target_feat: torch.Tensor,
    weights: torch.Tensor,
    max_objects: int,
    score_thresh: float,
    nms_iou_thresh: float = 0.2,
) -> tuple[npt.NDArray | None, npt.NDArray | None, npt.NDArray | None]:
    """Segment up to ``max_objects`` regions of ``image`` that resemble ``target_feat``.

    A similarity map between ``target_feat`` and the image features is built
    once. Then, repeatedly, the most similar location is used as a positive
    point prompt, the resulting mask is refined twice with box + mask prompts,
    and the (dilated) final mask is zeroed out of the similarity map so the
    next iteration looks elsewhere. The search stops as soon as the best
    remaining similarity falls below ``score_thresh``. Surviving masks are
    de-duplicated with box NMS.

    Args:
        predictor: Predictor used to encode the image and decode the prompts.
        image: Image handed to ``predictor.set_image``.
        target_feat: Reference feature; it is matrix-multiplied with the
            ``(C, h*w)`` image features (only the image features are
            normalised here -- presumably the caller normalises this one;
            TODO confirm).
        weights: Per-scale weights multiplied into the multi-mask logits
            before summing over the first axis (presumably one weight per
            returned mask scale; TODO confirm shape against the caller).
        max_objects: Maximum number of seed points to try.
        score_thresh: Minimum similarity at a seed point for it to be used.
        nms_iou_thresh: IoU threshold for the final ``batched_nms``.

    Returns:
        ``(masks, boxes, scores)`` as NumPy arrays restricted to the
        detections kept by NMS, or ``(None, None, None)`` when no seed passed
        the score threshold.
    """
    weights_np = weights.detach().cpu().numpy()
    dilate_kernel = np.ones((5, 5), np.uint8)
    pred_masks = []
    pred_scores = []

    # Image feature encoding
    predictor.set_image(image)
    test_feat = predictor.features.squeeze()

    # Similarity between the reference feature and every feature location
    C, h, w = test_feat.shape
    test_feat = test_feat / test_feat.norm(dim=0, keepdim=True)
    test_feat = test_feat.reshape(C, h * w)
    sim = target_feat @ test_feat
    sim = sim.reshape(1, 1, h, w)
    sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
    sim = predictor.model.postprocess_masks(
        sim, input_size=predictor.input_size, original_size=predictor.original_size
    ).squeeze()

    for _ in range(max_objects):
        # Positive location prior: the single most similar pixel, as (x, y).
        topk_xy, topk_label = point_selection(sim, topk=1)
        seed_x, seed_y = (int(v) for v in topk_xy[0])

        # The score depends only on the similarity map, so check the
        # threshold before spending three decoder passes on this seed.
        score = sim[seed_y, seed_x].item()
        if score < score_thresh:
            break

        # First-step prediction
        logits_high, scores, logits = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            multimask_output=True,
            return_logits=True,
            return_numpy=False,
        )
        logits = logits.detach().cpu().numpy()

        # Weighted sum of the multi-scale masks
        logit_high = (logits_high * weights.unsqueeze(-1)).sum(0)
        mask = (
            clean_mask_torch(logit_high > 0).bool()[0, 0, :, :].detach().cpu().numpy()
        )
        logit = (logits * weights_np[..., None]).sum(0)

        # Cascaded Post-refinement-1
        input_box = _mask_to_box(mask)
        if input_box is None:
            # Nothing to refine (min()/max() of an empty array would raise).
            # Zero the seed so the next iteration picks a different location.
            sim[seed_y, seed_x] = 0
            continue
        masks, scores, logits = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            box=input_box[None, :],
            mask_input=logit[None, :, :],
            multimask_output=True,
        )
        best_idx = np.argmax(scores)

        # Cascaded Post-refinement-2
        input_box = _mask_to_box(masks[best_idx])
        if input_box is None:
            sim[seed_y, seed_x] = 0
            continue
        masks, scores, logits = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            box=input_box[None, :],
            mask_input=logits[best_idx : best_idx + 1, :, :],
            multimask_output=True,
            return_numpy=False,
        )
        best_idx = np.argmax(scores.detach().cpu().numpy())
        final_mask = masks[best_idx]

        # Suppress the found object (plus a small margin) in the similarity
        # map. An explicit bool tensor on sim's device is used rather than
        # indexing with a uint8 ndarray.
        dilated = cv2.dilate(
            final_mask.detach().cpu().numpy().astype(np.uint8),
            dilate_kernel,
            iterations=1,
        )
        suppress = torch.from_numpy(dilated.astype(bool)).to(sim.device)
        sim[suppress] = 0

        pred_masks.append(final_mask)
        pred_scores.append(score)

    if not pred_masks:
        return None, None, None

    pred_masks = torch.stack(pred_masks)
    bboxes = batched_mask_to_box(pred_masks)
    keep_by_nms = batched_nms(
        bboxes.float(),
        # Scores must live on the same device (and dtype) as the boxes.
        torch.as_tensor(pred_scores, dtype=torch.float32, device=bboxes.device),
        torch.zeros_like(bboxes[:, 0]),  # single category for all boxes
        iou_threshold=nms_iou_thresh,
    )
    pred_masks = pred_masks[keep_by_nms].cpu().numpy()
    pred_scores = np.array(pred_scores)[keep_by_nms.cpu().numpy()]
    bboxes = bboxes[keep_by_nms].int().cpu().numpy()
    return pred_masks, bboxes, pred_scores