import torch import numpy as np from torchvision.transforms import ToTensor GPU_EFFICIENT_SAM_CHECKPOINT = "efficient_sam_s_gpu.jit" CPU_EFFICIENT_SAM_CHECKPOINT = "efficient_sam_s_cpu.jit" def load(device: torch.device) -> torch.jit.ScriptModule: if device.type == "cuda": model = torch.jit.load(GPU_EFFICIENT_SAM_CHECKPOINT) else: model = torch.jit.load(CPU_EFFICIENT_SAM_CHECKPOINT) model.eval() return model def inference_with_box( image: np.ndarray, box: np.ndarray, model: torch.jit.ScriptModule, device: torch.device ) -> np.ndarray: bbox = torch.reshape(torch.tensor(box), [1, 1, 2, 2]) bbox_labels = torch.reshape(torch.tensor([2, 3]), [1, 1, 2]) img_tensor = ToTensor()(image) predicted_logits, predicted_iou = model( img_tensor[None, ...].to(device), bbox.to(device), bbox_labels.to(device), ) predicted_logits = predicted_logits.cpu() all_masks = torch.ge(torch.sigmoid(predicted_logits[0, 0, :, :, :]), 0.5).numpy() predicted_iou = predicted_iou[0, 0, ...].cpu().detach().numpy() max_predicted_iou = -1 selected_mask_using_predicted_iou = None for m in range(all_masks.shape[0]): curr_predicted_iou = predicted_iou[m] if ( curr_predicted_iou > max_predicted_iou or selected_mask_using_predicted_iou is None ): max_predicted_iou = curr_predicted_iou selected_mask_using_predicted_iou = all_masks[m] return selected_mask_using_predicted_iou def inference_with_point( image: np.ndarray, point: np.ndarray, model: torch.jit.ScriptModule, device: torch.device ) -> np.ndarray: pts_sampled = torch.reshape(torch.tensor(point), [1, 1, -1, 2]) max_num_pts = pts_sampled.shape[2] pts_labels = torch.ones(1, 1, max_num_pts) img_tensor = ToTensor()(image) predicted_logits, predicted_iou = model( img_tensor[None, ...].to(device), pts_sampled.to(device), pts_labels.to(device), ) predicted_logits = predicted_logits.cpu() all_masks = torch.ge(torch.sigmoid(predicted_logits[0, 0, :, :, :]), 0.5).numpy() predicted_iou = predicted_iou[0, 0, ...].cpu().detach().numpy() max_predicted_iou = -1 selected_mask_using_predicted_iou = None for m in range(all_masks.shape[0]): curr_predicted_iou = predicted_iou[m] if ( curr_predicted_iou > max_predicted_iou or selected_mask_using_predicted_iou is None ): max_predicted_iou = curr_predicted_iou selected_mask_using_predicted_iou = all_masks[m] return selected_mask_using_predicted_iou