|
import base64 |
|
from io import BytesIO |
|
from time import perf_counter |
|
from typing import Any, List, Optional, Union |
|
|
|
import numpy as np |
|
import onnxruntime |
|
import rasterio.features |
|
import torch |
|
from segment_anything import SamPredictor, sam_model_registry |
|
from shapely.geometry import Polygon as ShapelyPolygon |
|
|
|
from inference.core.entities.requests.inference import InferenceRequestImage |
|
from inference.core.entities.requests.sam import ( |
|
SamEmbeddingRequest, |
|
SamInferenceRequest, |
|
SamSegmentationRequest, |
|
) |
|
from inference.core.entities.responses.sam import ( |
|
SamEmbeddingResponse, |
|
SamSegmentationResponse, |
|
) |
|
from inference.core.env import SAM_MAX_EMBEDDING_CACHE_SIZE, SAM_VERSION_ID |
|
from inference.core.models.roboflow import RoboflowCoreModel |
|
from inference.core.utils.image_utils import load_image_rgb |
|
from inference.core.utils.postprocess import masks2poly |
|
|
|
|
|
class SegmentAnything(RoboflowCoreModel):
    """SegmentAnything class for handling segmentation tasks.

    Attributes:
        sam: The segmentation model.
        predictor: The predictor for the segmentation model.
        ort_session: ONNX runtime inference session (the SAM decoder).
        embedding_cache: Cache for embeddings, keyed by image_id.
        image_size_cache: Cache for image sizes, keyed by image_id.
        embedding_cache_keys: Keys for the embedding cache, in FIFO eviction order.
        low_res_logits_cache: Cache for low resolution logits, keyed by image_id.
        segmentation_cache_keys: Keys for the segmentation cache, in FIFO eviction order.
    """

    def __init__(self, *args, model_id: str = f"sam/{SAM_VERSION_ID}", **kwargs):
        """Initializes the SegmentAnything.

        Args:
            *args: Variable length argument list.
            model_id: Roboflow model id; the version segment selects the SAM
                checkpoint variant from ``sam_model_registry``.
            **kwargs: Arbitrary keyword arguments.
        """
        super().__init__(*args, model_id=model_id, **kwargs)
        # Encoder runs in PyTorch (GPU when available); decoder runs in ONNX Runtime.
        self.sam = sam_model_registry[self.version_id](
            checkpoint=self.cache_file("encoder.pth")
        )
        self.sam.to(device="cuda" if torch.cuda.is_available() else "cpu")
        self.predictor = SamPredictor(self.sam)
        self.ort_session = onnxruntime.InferenceSession(
            self.cache_file("decoder.onnx"),
            providers=[
                "CUDAExecutionProvider",
                "CPUExecutionProvider",
            ],
        )
        self.embedding_cache = {}
        self.image_size_cache = {}
        self.embedding_cache_keys = []

        self.low_res_logits_cache = {}
        self.segmentation_cache_keys = []
        self.task_type = "unsupervised-segmentation"

    @staticmethod
    def _is_missing(value: Any) -> bool:
        """Return True when *value* should be treated as "not provided".

        numpy arrays are special-cased because ``bool(ndarray)`` raises
        ``ValueError`` for arrays with more than one element; a present array
        is never "missing". Everything else uses ordinary truthiness, which
        matches the previous behavior for ``None``, ``[]`` and ``""``.
        """
        if isinstance(value, np.ndarray):
            return False
        return not value

    def get_infer_bucket_file_list(self) -> List[str]:
        """Gets the list of files required for inference.

        Returns:
            List[str]: List of file names.
        """
        return ["encoder.pth", "decoder.onnx"]

    def embed_image(self, image: Any, image_id: Optional[str] = None, **kwargs):
        """
        Embeds an image and caches the result if an image_id is provided. If the image has been embedded before and cached,
        the cached result will be returned.

        Args:
            image (Any): The image to be embedded. The format should be compatible with the preproc_image method.
            image_id (Optional[str]): An identifier for the image. If provided, the embedding result will be cached
                with this ID. Defaults to None.
            **kwargs: Additional keyword arguments.

        Returns:
            Tuple[np.ndarray, Tuple[int, int]]: A tuple where the first element is the embedding of the image
            and the second element is the shape (height, width) of the processed image.

        Notes:
            - Embeddings and image sizes are cached to improve performance on repeated requests for the same image.
            - The cache has a maximum size defined by SAM_MAX_EMBEDDING_CACHE_SIZE. When the cache exceeds this size,
              the oldest entries are removed.

        Example:
            >>> img_array = ...  # some image array
            >>> embed_image(img_array, image_id="sample123")
            (array([...]), (224, 224))
        """
        if image_id and image_id in self.embedding_cache:
            return (
                self.embedding_cache[image_id],
                self.image_size_cache[image_id],
            )
        img_in = self.preproc_image(image)
        self.predictor.set_image(img_in)
        embedding = self.predictor.get_image_embedding().cpu().numpy()
        if image_id:
            self.embedding_cache[image_id] = embedding
            self.image_size_cache[image_id] = img_in.shape[:2]
            self.embedding_cache_keys.append(image_id)
            # FIFO eviction: drop the oldest cached embedding once over budget.
            if len(self.embedding_cache_keys) > SAM_MAX_EMBEDDING_CACHE_SIZE:
                cache_key = self.embedding_cache_keys.pop(0)
                del self.embedding_cache[cache_key]
                del self.image_size_cache[cache_key]
        return (embedding, img_in.shape[:2])

    def infer_from_request(self, request: SamInferenceRequest):
        """Performs inference based on the request type.

        Args:
            request (SamInferenceRequest): The inference request.

        Returns:
            Union[SamEmbeddingResponse, SamSegmentationResponse, bytes]: The inference
            response (raw bytes for binary-format segmentation requests).

        Raises:
            ValueError: If ``request.format`` is neither "json" nor "binary".
        """
        t1 = perf_counter()
        if isinstance(request, SamEmbeddingRequest):
            embedding, _ = self.embed_image(**request.dict())
            inference_time = perf_counter() - t1
            if request.format == "json":
                return SamEmbeddingResponse(
                    embeddings=embedding.tolist(), time=inference_time
                )
            elif request.format == "binary":
                binary_vector = BytesIO()
                np.save(binary_vector, embedding)
                binary_vector.seek(0)
                return SamEmbeddingResponse(
                    embeddings=binary_vector.getvalue(), time=inference_time
                )
            else:
                # Previously an unknown format fell through to the segmentation
                # response below and crashed with NameError; fail loudly instead.
                raise ValueError(f"Invalid format {request.format}")
        elif isinstance(request, SamSegmentationRequest):
            masks, low_res_masks = self.segment_image(**request.dict())
            if request.format == "json":
                masks = masks > self.predictor.model.mask_threshold
                masks = masks2poly(masks)
                low_res_masks = low_res_masks > self.predictor.model.mask_threshold
                low_res_masks = masks2poly(low_res_masks)
            elif request.format == "binary":
                binary_vector = BytesIO()
                np.savez_compressed(
                    binary_vector, masks=masks, low_res_masks=low_res_masks
                )
                binary_vector.seek(0)
                binary_data = binary_vector.getvalue()
                return binary_data
            else:
                raise ValueError(f"Invalid format {request.format}")

        response = SamSegmentationResponse(
            masks=[m.tolist() for m in masks],
            low_res_masks=[m.tolist() for m in low_res_masks],
            time=perf_counter() - t1,
        )
        return response

    def preproc_image(self, image: InferenceRequestImage):
        """Preprocesses an image.

        Args:
            image (InferenceRequestImage): The image to preprocess.

        Returns:
            np.array: The preprocessed image as an RGB numpy array.
        """
        np_image = load_image_rgb(image)
        return np_image

    def segment_image(
        self,
        image: Any,
        embeddings: Optional[Union[np.ndarray, List[List[float]]]] = None,
        embeddings_format: Optional[str] = "json",
        has_mask_input: Optional[bool] = False,
        image_id: Optional[str] = None,
        mask_input: Optional[Union[np.ndarray, List[List[List[float]]]]] = None,
        mask_input_format: Optional[str] = "json",
        orig_im_size: Optional[List[int]] = None,
        point_coords: Optional[List[List[float]]] = None,
        point_labels: Optional[List[int]] = None,
        use_mask_input_cache: Optional[bool] = True,
        **kwargs,
    ):
        """
        Segments an image based on provided embeddings, points, masks, or cached results.
        If embeddings are not directly provided, the function can derive them from the input image or cache.

        Args:
            image (Any): The image to be segmented.
            embeddings (Optional[Union[np.ndarray, List[List[float]]]]): The embeddings of the image.
                Defaults to None, in which case the image is used to compute embeddings.
            embeddings_format (Optional[str]): Format of the provided embeddings; either 'json' or 'binary'. Defaults to 'json'.
            has_mask_input (Optional[bool]): Specifies whether mask input is provided. Defaults to False.
            image_id (Optional[str]): A cached identifier for the image. Useful for accessing cached embeddings or masks.
            mask_input (Optional[Union[np.ndarray, List[List[List[float]]]]]): Input mask for the image.
            mask_input_format (Optional[str]): Format of the provided mask input; either 'json' or 'binary'. Defaults to 'json'.
            orig_im_size (Optional[List[int]]): Original size of the image when providing embeddings directly.
            point_coords (Optional[List[List[float]]]): Coordinates of points in the image. Defaults to None
                (treated as an empty list; the caller's list is never mutated).
            point_labels (Optional[List[int]]): Labels associated with the provided points. Defaults to None
                (treated as an empty list; the caller's list is never mutated).
            use_mask_input_cache (Optional[bool]): Flag to determine if cached mask input should be used. Defaults to True.
            **kwargs: Additional keyword arguments.

        Returns:
            Tuple[np.ndarray, np.ndarray]: A tuple where the first element is the segmentation masks of the image
            and the second element is the low resolution segmentation masks.

        Raises:
            ValueError: If necessary inputs are missing or inconsistent.

        Notes:
            - Embeddings, segmentations, and low-resolution logits can be cached to improve performance
              on repeated requests for the same image.
            - The cache has a maximum size defined by SAM_MAX_EMBEDDING_CACHE_SIZE. When the cache exceeds this size,
              the oldest entries are removed.
        """
        # _is_missing handles numpy arrays, whose truth value is ambiguous.
        if self._is_missing(embeddings):
            if self._is_missing(image) and not image_id:
                raise ValueError(
                    "Must provide either image, cached image_id, or embeddings"
                )
            elif (
                image_id
                and self._is_missing(image)
                and image_id not in self.embedding_cache
            ):
                raise ValueError(
                    f"Image ID {image_id} not in embedding cache, must provide the image or embeddings"
                )
            embedding, original_image_size = self.embed_image(
                image=image, image_id=image_id
            )
        else:
            if not orig_im_size:
                raise ValueError(
                    "Must provide original image size if providing embeddings"
                )
            original_image_size = orig_im_size
            if embeddings_format == "json":
                embedding = np.array(embeddings)
            elif embeddings_format == "binary":
                embedding = np.load(BytesIO(embeddings))

        # Copy before appending: previously the mutable default list (and any
        # caller-supplied list) was mutated in place, leaking the padding point
        # across calls. The (0, 0) point with label -1 is the padding entry the
        # SAM ONNX decoder expects.
        point_coords = [] if point_coords is None else list(point_coords)
        point_coords.append([0, 0])
        point_coords = np.array(point_coords, dtype=np.float32)
        point_coords = np.expand_dims(point_coords, axis=0)
        point_coords = self.predictor.transform.apply_coords(
            point_coords,
            original_image_size,
        )

        point_labels = [] if point_labels is None else list(point_labels)
        point_labels.append(-1)
        point_labels = np.array(point_labels, dtype=np.float32)
        point_labels = np.expand_dims(point_labels, axis=0)

        if has_mask_input:
            if (
                image_id
                and image_id in self.low_res_logits_cache
                and use_mask_input_cache
            ):
                mask_input = self.low_res_logits_cache[image_id]
            elif self._is_missing(mask_input) and (
                not image_id or image_id not in self.low_res_logits_cache
            ):
                raise ValueError("Must provide either mask_input or cached image_id")
            else:
                if mask_input_format == "json":
                    # JSON masks arrive as polygons; rasterize each onto the
                    # 256x256 grid the decoder consumes.
                    polys = mask_input
                    mask_input = np.zeros((1, len(polys), 256, 256), dtype=np.uint8)
                    for i, poly in enumerate(polys):
                        poly = ShapelyPolygon(poly)
                        raster = rasterio.features.rasterize(
                            [poly], out_shape=(256, 256)
                        )
                        mask_input[0, i, :, :] = raster
                elif mask_input_format == "binary":
                    binary_data = base64.b64decode(mask_input)
                    mask_input = np.load(BytesIO(binary_data))
        else:
            mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)

        ort_inputs = {
            "image_embeddings": embedding.astype(np.float32),
            "point_coords": point_coords.astype(np.float32),
            "point_labels": point_labels,
            "mask_input": mask_input.astype(np.float32),
            "has_mask_input": (
                np.zeros(1, dtype=np.float32)
                if not has_mask_input
                else np.ones(1, dtype=np.float32)
            ),
            "orig_im_size": np.array(original_image_size, dtype=np.float32),
        }
        masks, _, low_res_logits = self.ort_session.run(None, ort_inputs)
        if image_id:
            # Cache logits so follow-up requests can refine with has_mask_input.
            self.low_res_logits_cache[image_id] = low_res_logits
            if image_id not in self.segmentation_cache_keys:
                self.segmentation_cache_keys.append(image_id)
            if len(self.segmentation_cache_keys) > SAM_MAX_EMBEDDING_CACHE_SIZE:
                cache_key = self.segmentation_cache_keys.pop(0)
                del self.low_res_logits_cache[cache_key]
        masks = masks[0]
        low_res_masks = low_res_logits[0]

        return masks, low_res_masks
|
|