# dillonlaird: "fixed name error" (commit 153ee9b)
import numpy.typing as npt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
from torchvision.ops.boxes import batched_nms
from app.mobile_sam import SamPredictor
from app.mobile_sam.utils import batched_mask_to_box
from app.sam.postprocess import clean_mask_torch
def point_selection(mask_sim, topk: int = 1):
    """Pick the ``topk`` highest-valued locations of a 2-D similarity map.

    Args:
        mask_sim: 2-D tensor of scores, indexed ``[row, col]``.
        topk: Number of peak locations to return.

    Returns:
        A pair ``(points, labels)`` of NumPy arrays. ``points`` has shape
        ``(topk, 2)`` and stores each peak as ``(col, row)``, i.e. ``(x, y)``,
        ordered from highest to lowest score. ``labels`` is a length-``topk``
        array of ones.
    """
    _, row_len = mask_sim.shape

    # Indices of the largest values in the row-major flattened map.
    flat_idx = mask_sim.flatten(0).topk(topk)[1]

    # Decompose each flat index into its (row, col) position.
    rows = flat_idx // row_len
    cols = flat_idx % row_len

    points = torch.stack((cols, rows), dim=1).cpu().numpy()
    labels = np.array([1] * topk)
    return points, labels
def mask_nms(
    masks: list[npt.NDArray], scores: list[float], iou_thresh: float = 0.2
) -> tuple[list[npt.NDArray], list[float]]:
    """Suppress redundant masks, keeping the best-scoring one of each group.

    Mask ``i`` is removed when either of these holds:

    * Containment: some other mask has at least 90% of its own pixels inside
      mask ``i`` and a higher score than mask ``i``.
    * Overlap: some other mask has IoU with mask ``i`` strictly greater than
      ``iou_thresh`` and a higher score than mask ``i``.

    Score ties are resolved in favour of the lowest index.

    Empty (all-zero) masks overlap nothing: their IoU with every mask is
    defined as 0 and they never count as contained, so they are neither
    removed nor able to remove other masks.

    Args:
        masks: Equally shaped 2-D masks; any nonzero pixel counts as set.
        scores: One score per mask, in the same order as ``masks``.
        iou_thresh: IoU above which two masks are considered duplicates.

    Returns:
        ``(kept_masks, kept_scores)``: the surviving entries of ``masks`` and
        ``scores`` in their original order.
    """
    num_masks = len(masks)
    if num_masks == 0:
        return [], []

    np_masks = np.array(masks).astype(bool)
    np_scores = np.array(scores)
    # Pixel count per mask; loop-invariant, so computed once up front.
    areas = np_masks.sum(axis=(1, 2))
    remove_indices = set()

    for i in range(num_masks):
        intersection = np.logical_and(np_masks[i], np_masks).sum(axis=(1, 2))
        # |A or B| = |A| + |B| - |A and B|; avoids a second full-mask pass.
        union = areas[i] + areas - intersection
        # An empty union only happens for two empty masks, where the
        # intersection is 0 too; dividing by 1 yields IoU 0 instead of NaN.
        ious_i = intersection / np.where(union == 0, 1, union)

        # Masks that lie (almost) entirely inside mask i. Empty masks are
        # excluded because they would satisfy the test vacuously.
        contained = (intersection >= areas * 0.90) & (areas > 0)
        # Mask i always competes with itself, so each group is non-empty.
        contained[i] = True
        group = np.nonzero(contained)[0]
        if group[np_scores[group].argmax()] != i:
            remove_indices.add(i)

        # Masks that overlap mask i by more than the IoU threshold.
        similar = ious_i > iou_thresh
        similar[i] = True
        group = np.nonzero(similar)[0]
        if group[np_scores[group].argmax()] != i:
            remove_indices.add(i)

    kept = [i for i in range(num_masks) if i not in remove_indices]
    return [masks[i] for i in kept], [scores[i] for i in kept]
class MaskWeights(nn.Module):
    """Module holding a single learnable ``(2, 1)`` weight tensor.

    Every entry of ``weights`` starts at 1/3.
    """

    def __init__(self):
        super().__init__()
        # nn.Parameter produces a leaf tensor that requires grad by default.
        initial = torch.ones(2, 1) / 3
        self.weights = nn.Parameter(initial)
class PerSAM:
    """Callable bundle of a SAM predictor and its inference settings.

    An instance stores everything ``fast_inference`` needs apart from the
    image, so it can be called with an image alone.
    """

    def __init__(
        self,
        sam: SamPredictor,
        target_feat: torch.Tensor,
        max_objects: int,
        score_thresh: float,
        nms_iou_thresh: float,
        mask_weights: torch.Tensor,
    ) -> None:
        """Store the predictor and the parameters forwarded on every call.

        Args:
            sam: Predictor passed to ``fast_inference`` as ``predictor``.
            target_feat: Reference feature passed as ``target_feat``.
            max_objects: Passed as ``max_objects``.
            score_thresh: Passed as ``score_thresh``.
            nms_iou_thresh: Passed as ``nms_iou_thresh``.
            mask_weights: Passed as ``weights``.
        """
        super().__init__()
        self.sam = sam
        self.weights = mask_weights
        self.target_feat = target_feat
        self.max_objects = max_objects
        self.score_thresh = score_thresh
        self.nms_iou_thresh = nms_iou_thresh

    def __call__(
        self, x: npt.NDArray
    ) -> tuple[npt.NDArray | None, npt.NDArray | None, npt.NDArray | None]:
        """Run ``fast_inference`` on image ``x`` with the stored settings.

        Returns:
            Whatever ``fast_inference`` returns: a ``(masks, bboxes, scores)``
            triple, or ``(None, None, None)`` when no mask was accepted.
        """
        return fast_inference(
            predictor=self.sam,
            image=x,
            target_feat=self.target_feat,
            weights=self.weights,
            max_objects=self.max_objects,
            score_thresh=self.score_thresh,
            nms_iou_thresh=self.nms_iou_thresh,
        )
def _mask_to_box(mask: npt.NDArray) -> npt.NDArray | None:
    """Return the tight ``[x_min, y_min, x_max, y_max]`` box around the nonzero pixels of a 2-D mask, or ``None`` if it has none."""
    ys, xs = np.nonzero(mask)
    if xs.size == 0:
        return None
    return np.array([xs.min(), ys.min(), xs.max(), ys.max()])


def fast_inference(
    predictor: SamPredictor,
    image: npt.NDArray,
    target_feat: torch.Tensor,
    weights: torch.Tensor,
    max_objects: int,
    score_thresh: float,
    nms_iou_thresh: float = 0.2,
) -> tuple[npt.NDArray | None, npt.NDArray | None, npt.NDArray | None]:
    """Segment up to ``max_objects`` regions of ``image`` that match ``target_feat``.

    A similarity map between ``target_feat`` and the image features is built
    once. Each iteration then takes the map's peak as a point prompt, predicts
    a mask, refines it twice with box and mask prompts, and zeroes the
    (dilated) mask region in the map so the next iteration finds a new peak.
    The loop ends after ``max_objects`` iterations or as soon as the peak
    similarity falls below ``score_thresh``. The collected masks are finally
    de-duplicated with box NMS.

    If a predicted mask is empty, no box can be derived from it. In that case
    the prompt point is zeroed in the similarity map and the iteration is
    skipped.

    Args:
        predictor: Predictor providing ``set_image``, ``features``,
            ``predict`` and ``model.postprocess_masks``.
        image: Image handed to ``predictor.set_image``.
        target_feat: Reference feature, matrix-multiplied with the
            ``(C, h * w)`` normalised image features. The product is reshaped
            to ``(1, 1, h, w)``, so presumably shape ``(1, C)`` (TODO confirm).
        weights: Per-candidate blending weights multiplied into the mask
            logits before summing over the candidates. Presumably one row per
            candidate, e.g. shape ``(3, 1)`` (TODO confirm against caller).
        max_objects: Maximum number of prompt iterations.
        score_thresh: Minimum peak similarity for a region to be accepted.
        nms_iou_thresh: Box IoU threshold passed to ``batched_nms``.

    Returns:
        ``(masks, bboxes, scores)`` as NumPy arrays restricted to the entries
        kept by NMS, or ``(None, None, None)`` if no mask was accepted.
        ``scores`` are the similarity values at the prompt points.
    """
    # Loop-invariant forms of the blending weights: a torch view for the
    # first-returned logits and a NumPy copy for the third-returned ones.
    # Detached because nothing downstream needs gradients.
    weights_torch = weights.detach().unsqueeze(-1)
    weights_np = weights.detach().cpu().numpy()[..., None]
    dilate_kernel = np.ones((5, 5), np.uint8)
    found_masks = []
    found_scores = []

    # Image feature encoding
    predictor.set_image(image)
    test_feat = predictor.features.squeeze()

    # Cosine similarity
    C, h, w = test_feat.shape
    test_feat = test_feat / test_feat.norm(dim=0, keepdim=True)
    test_feat = test_feat.reshape(C, h * w)
    sim = target_feat @ test_feat
    sim = sim.reshape(1, 1, h, w)
    sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
    sim = predictor.model.postprocess_masks(
        sim, input_size=predictor.input_size, original_size=predictor.original_size
    ).squeeze()

    for _ in range(max_objects):
        # Positive location prior
        topk_xy, topk_label = point_selection(sim, topk=1)
        point_x, point_y = int(topk_xy[0][0]), int(topk_xy[0][1])

        # The score depends only on the similarity map, so check it before
        # spending three decoder passes on a region that would be discarded.
        score = sim[point_y, point_x].item()
        if score < score_thresh:
            break

        # First-step prediction
        logits_high, _, logits = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            multimask_output=True,
            return_logits=True,
            return_numpy=False,
        )

        # Weighted sum three-scale masks
        logit_high = (logits_high * weights_torch).sum(0)
        mask = clean_mask_torch(logit_high > 0).bool()[0, 0, :, :].detach().cpu().numpy()
        logit = (logits.detach().cpu().numpy() * weights_np).sum(0)

        # Cascaded Post-refinement-1
        input_box = _mask_to_box(mask)
        if input_box is None:
            # Retire this prompt point so the next iteration picks another.
            sim[point_y, point_x] = 0
            continue
        masks, scores, logits = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            box=input_box[None, :],
            mask_input=logit[None, :, :],
            multimask_output=True,
        )
        best_idx = np.argmax(scores)

        # Cascaded Post-refinement-2
        input_box = _mask_to_box(masks[best_idx])
        if input_box is None:
            sim[point_y, point_x] = 0
            continue
        masks, scores, logits = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            box=input_box[None, :],
            mask_input=logits[best_idx : best_idx + 1, :, :],
            multimask_output=True,
            return_numpy=False,
        )
        best_idx = np.argmax(scores.detach().cpu().numpy())
        final_mask = masks[best_idx]

        # Zero a slightly grown copy of the mask in the similarity map so the
        # same region is not selected again.
        dilated = cv2.dilate(
            final_mask.detach().cpu().numpy().astype(np.uint8), dilate_kernel, iterations=1
        )
        # Index with a bool tensor: uint8 mask indexing is deprecated in torch.
        sim[torch.from_numpy(dilated.astype(bool)).to(sim.device)] = 0

        found_masks.append(final_mask)
        found_scores.append(score)

    if not found_masks:
        return None, None, None

    stacked_masks = torch.stack(found_masks)
    bboxes = batched_mask_to_box(stacked_masks)
    keep_by_nms = batched_nms(
        bboxes.float(),
        # Scores must live on the same device as the boxes.
        torch.as_tensor(found_scores, device=bboxes.device),
        # Single category: every box competes with every other box.
        torch.zeros_like(bboxes[:, 0]),
        iou_threshold=nms_iou_thresh,
    )

    pred_masks = stacked_masks[keep_by_nms].cpu().numpy()
    pred_scores = np.array(found_scores)[keep_by_nms.cpu().numpy()]
    bboxes = bboxes[keep_by_nms].int().cpu().numpy()
    return pred_masks, bboxes, pred_scores