Spaces:
Running
Running
import os | |
import logging | |
import torch | |
import cv2 | |
import numpy as np | |
from typing import List, Dict, Optional | |
from label_studio_ml.utils import get_image_local_path, InMemoryLRUDictCache | |
# Logger shared by everything in this module.
logger = logging.getLogger(__name__)

# Checkpoint locations, read once from the environment at import time.
# SAM ViT-H weights: required by both the 'SAM' and the 'ONNX' backends.
VITH_CHECKPOINT = os.getenv("VITH_CHECKPOINT")
# ONNX model loaded with onnxruntime by the 'ONNX' backend.
ONNX_CHECKPOINT = os.getenv("ONNX_CHECKPOINT")
# MobileSAM weights for the 'MobileSAM' backend; falls back to a local file name.
MOBILESAM_CHECKPOINT = os.getenv("MOBILESAM_CHECKPOINT", "mobile_sam.pt")

# Label Studio connection settings, forwarded to get_image_local_path()
# when resolving task image URLs to local files.
LABEL_STUDIO_ACCESS_TOKEN = os.getenv("LABEL_STUDIO_ACCESS_TOKEN")
LABEL_STUDIO_HOST = os.getenv("LABEL_STUDIO_HOST")
class SAMPredictor(object):
    """Segment Anything predictor backed by SAM, MobileSAM, or SAM + an ONNX model.

    The backend is chosen by ``model_choice``; checkpoint paths come from the
    module-level ``VITH_CHECKPOINT`` / ``ONNX_CHECKPOINT`` / ``MOBILESAM_CHECKPOINT``
    settings. Images are resolved with ``get_image_local_path`` using the Label
    Studio credentials, and the most recent image's shape (plus, when requested,
    its embedding) is kept in a single-entry in-memory cache.
    """

    def __init__(self, model_choice):
        """Load the model selected by ``model_choice``.

        Args:
            model_choice: ``'ONNX'``, ``'SAM'`` or ``'MobileSAM'``.

        Raises:
            FileNotFoundError: if the checkpoint setting required by the chosen
                backend is unset (or empty, for MobileSAM).
            ValueError: if ``model_choice`` is not one of the supported values.
        """
        self.model_choice = model_choice
        # Embedding cache. It deliberately holds ONE image: a cache hit skips
        # predictor.set_image(), so the cached entry must always be the image the
        # predictor currently holds. Supporting more than one image would require
        # storing the predictor's "active image" state in the cache as well.
        self.cache = InMemoryLRUDictCache(1)
        # Prefer CUDA; the CPU fallback works but is much slower.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.debug("Using device %s", self.device)

        if model_choice == 'ONNX':
            import onnxruntime
            from segment_anything import sam_model_registry, SamPredictor
            self.model_checkpoint = VITH_CHECKPOINT
            if self.model_checkpoint is None:
                raise FileNotFoundError("VITH_CHECKPOINT is not set: please set it to the path to the SAM checkpoint")
            if ONNX_CHECKPOINT is None:
                raise FileNotFoundError("ONNX_CHECKPOINT is not set: please set it to the path to the ONNX checkpoint")
            logger.info("Using ONNX checkpoint %s and SAM checkpoint %s", ONNX_CHECKPOINT, self.model_checkpoint)
            # predict_onnx() runs this session on embeddings computed by the
            # PyTorch predictor built below, so both models are loaded.
            self.ort = onnxruntime.InferenceSession(ONNX_CHECKPOINT)
            reg_key = "vit_h"
        elif model_choice == 'SAM':
            from segment_anything import SamPredictor, sam_model_registry
            self.model_checkpoint = VITH_CHECKPOINT
            if self.model_checkpoint is None:
                raise FileNotFoundError("VITH_CHECKPOINT is not set: please set it to the path to the SAM checkpoint")
            logger.info("Using SAM checkpoint %s", self.model_checkpoint)
            reg_key = "vit_h"
        elif model_choice == 'MobileSAM':
            from mobile_sam import SamPredictor, sam_model_registry
            self.model_checkpoint = MOBILESAM_CHECKPOINT
            if not self.model_checkpoint:
                # The setting is read from MOBILESAM_CHECKPOINT; name it correctly.
                raise FileNotFoundError(
                    "MOBILESAM_CHECKPOINT is not set: please set it to the path to the MobileSAM checkpoint"
                )
            logger.info("Using MobileSAM checkpoint %s", self.model_checkpoint)
            reg_key = 'vit_t'
        else:
            raise ValueError(f"Invalid model choice {model_choice}")

        sam = sam_model_registry[reg_key](checkpoint=self.model_checkpoint)
        sam.to(device=self.device)
        self.predictor = SamPredictor(sam)

    def model_name(self):
        """Return an identifier of the form ``'<choice>:<checkpoint>:<device>'``."""
        return f'{self.model_choice}:{self.model_checkpoint}:{self.device}'

    def set_image(self, img_path, calculate_embeddings=True):
        """Load ``img_path`` into the predictor unless it is already cached.

        Args:
            img_path: Image reference as given by Label Studio; resolved to a local
                file with ``get_image_local_path``.
            calculate_embeddings: Also store the predictor's image embedding (as a
                numpy array) in the payload under ``'image_embedding'``.

        Returns:
            The cached payload dict with ``'image_shape'`` (height, width) and,
            if computed, ``'image_embedding'``.

        Raises:
            ValueError: if OpenCV cannot read the resolved image file.
        """
        payload = self.cache.get(img_path)
        if payload is not None:
            logger.debug("Using embeddings for %s from `IN_MEM_CACHE`", img_path)
            return payload

        logger.debug("Payload not found for %s in `IN_MEM_CACHE`: calculating from scratch", img_path)
        image_path = get_image_local_path(
            img_path,
            label_studio_access_token=LABEL_STUDIO_ACCESS_TOKEN,
            label_studio_host=LABEL_STUDIO_HOST
        )
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError(f"Could not read image {image_path!r} (resolved from {img_path!r})")
        # OpenCV decodes to BGR; the predictor is fed RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        self.predictor.set_image(image)
        payload = {'image_shape': image.shape[:2]}
        logger.debug("Finished set_image(%s) in `IN_MEM_CACHE`: image shape %s", img_path, image.shape[:2])
        if calculate_embeddings:
            image_embedding = self.predictor.get_image_embedding().cpu().numpy()
            payload['image_embedding'] = image_embedding
            logger.debug(
                "Finished storing embeddings for %s in `IN_MEM_CACHE`: embedding shape %s",
                img_path, image_embedding.shape
            )
        self.cache.put(img_path, payload)
        return payload

    @staticmethod
    def _as_float32_array(values):
        """Return ``values`` as a float32 numpy array, or None when None/empty.

        Uses ``len`` instead of truthiness so numpy arrays are accepted as well
        as lists (``bool(array)`` raises for multi-element arrays).
        """
        if values is None or len(values) == 0:
            return None
        return np.asarray(values, dtype=np.float32)

    @classmethod
    def _build_onnx_prompt(cls, point_coords=None, point_labels=None, input_box=None):
        """Assemble the ``(coords, labels)`` prompt arrays for the ONNX model.

        Coordinates are returned in original-image pixel space; the caller still
        has to rescale them with ``predictor.transform.apply_coords``.

        Args:
            point_coords: ``[[x, y], ...]`` keypoints, or None/empty.
            point_labels: One label per keypoint; required when points are given.
            input_box: ``[x1, y1, x2, y2]``, or None/empty.

        Returns:
            ``coords`` with shape (1, K, 2) and ``labels`` with shape (1, K),
            both float32.

        Raises:
            ValueError: if neither points nor a box are given, if points come
                without exactly one label each, or if the box is not 4 numbers.
        """
        points = cls._as_float32_array(point_coords)
        labels = cls._as_float32_array(point_labels)
        box = cls._as_float32_array(input_box)
        if points is None and box is None:
            raise ValueError("At least one point or a box must be provided")

        coord_parts, label_parts = [], []
        if points is not None:
            points = points.reshape(-1, 2)
            if labels is None or labels.size != len(points):
                raise ValueError("point_labels must contain exactly one label per point in point_coords")
            coord_parts.append(points)
            label_parts.append(labels.reshape(-1))
        if box is not None:
            # The two box corners are encoded as points with labels 2 and 3.
            coord_parts.append(box.reshape(2, 2))
            label_parts.append(np.array([2, 3], dtype=np.float32))
        else:
            # Without a box, a padding point labelled -1 is appended (as in SAM's
            # ONNX example); with a box no padding is needed.
            coord_parts.append(np.zeros((1, 2), dtype=np.float32))
            label_parts.append(np.array([-1], dtype=np.float32))

        coords = np.concatenate(coord_parts, axis=0)[None, :, :]
        prompt_labels = np.concatenate(label_parts, axis=0)[None, :]
        return coords, prompt_labels

    def predict_onnx(
        self,
        img_path,
        point_coords: Optional[List[List]] = None,
        point_labels: Optional[List] = None,
        input_box: Optional[List] = None
    ):
        """Predict one mask for ``img_path`` with the ONNX model.

        Points, a box, or both may be given. The prompt is validated before the
        (expensive) image embedding is computed.

        Returns:
            ``{'masks': [mask], 'probs': [score]}`` where ``mask`` is a uint8 (H, W)
            array thresholded at the model's ``mask_threshold`` and ``score`` is
            the float from the session's second output.

        Raises:
            ValueError: for an empty or malformed prompt.
        """
        onnx_coords, onnx_labels = self._build_onnx_prompt(point_coords, point_labels, input_box)

        payload = self.set_image(img_path, calculate_embeddings=True)
        image_shape = payload['image_shape']
        image_embedding = payload['image_embedding']

        # Map pixel coordinates into the predictor's resized input frame.
        onnx_coords = self.predictor.transform.apply_coords(onnx_coords, image_shape).astype(np.float32)
        # TODO: support mask inputs
        onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
        onnx_has_mask_input = np.zeros(1, dtype=np.float32)
        ort_inputs = {
            "image_embeddings": image_embedding,
            "point_coords": onnx_coords,
            "point_labels": onnx_labels,
            "mask_input": onnx_mask_input,
            "has_mask_input": onnx_has_mask_input,
            "orig_im_size": np.array(image_shape, dtype=np.float32)
        }
        masks, prob, _low_res_logits = self.ort.run(None, ort_inputs)
        masks = masks > self.predictor.model.mask_threshold
        mask = masks[0, 0, :, :].astype(np.uint8)  # each mask has shape [H, W]
        prob = float(prob[0][0])
        # TODO: support the real multimask output as in https://github.com/facebookresearch/segment-anything/blob/main/notebooks/predictor_example.ipynb
        return {
            'masks': [mask],
            'probs': [prob]
        }

    def predict_sam(
        self,
        img_path,
        point_coords: Optional[List[List]] = None,
        point_labels: Optional[List] = None,
        input_box: Optional[List] = None
    ):
        """Predict one mask for ``img_path`` with the PyTorch (Mobile)SAM predictor.

        Empty or None prompt arguments are passed to the predictor as None.

        Returns:
            ``{'masks': [mask], 'probs': [score]}`` with the first (and only, since
            ``multimask_output=False``) mask as uint8 and its score as a float.
        """
        self.set_image(img_path, calculate_embeddings=False)
        point_coords = self._as_float32_array(point_coords)
        point_labels = self._as_float32_array(point_labels)
        input_box = self._as_float32_array(input_box)
        masks, probs, _logits = self.predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            box=input_box,
            # TODO: support multimask output
            multimask_output=False
        )
        mask = masks[0, :, :].astype(np.uint8)  # each mask has shape [H, W]
        prob = float(probs[0])
        return {
            'masks': [mask],
            'probs': [prob]
        }

    def predict(
        self, img_path: str,
        point_coords: Optional[List[List]] = None,
        point_labels: Optional[List] = None,
        input_box: Optional[List] = None
    ):
        """Dispatch to ``predict_onnx`` or ``predict_sam`` based on ``model_choice``.

        Raises:
            NotImplementedError: for an unsupported ``model_choice``.
        """
        if self.model_choice == 'ONNX':
            return self.predict_onnx(img_path, point_coords, point_labels, input_box)
        if self.model_choice in ('SAM', 'MobileSAM'):
            return self.predict_sam(img_path, point_coords, point_labels, input_box)
        raise NotImplementedError(f"Model choice {self.model_choice} is not supported yet")