import concurrent.futures
import os
import tempfile
from typing import Dict, List, Union

import numpy as np
from dds_cloudapi_sdk import (
    DetectionTask,
    Client,
    Config,
    TextPrompt,
    DetectionModel,
    DetectionTarget,
)
from PIL import Image


class GroundingDINOAPIWrapper:
    """API wrapper for Grounding DINO 1.5.

    Args:
        token (str): The token for the Grounding DINO 1.5 API.

    We are now opening free API access to Grounding DINO 1.5. For educators,
    students, and researchers, we offer an API with extensive usage times to
    support your educational and research endeavors. You can get a free API
    token here: https://deepdataspace.com/request_api
    """

    def __init__(self, token: str):
        # A single SDK client is created per wrapper and reused for uploads and tasks.
        self.client = Client(Config(token=token))

    def inference(self, prompt: Dict, return_mask: bool = False):
        """Run Grounding DINO 1.5 Pro detection on a single image.

        Batch inference is not supported; ``prompt`` describes exactly one image.

        Args:
            prompt (dict): Request description with the following keys:
                - "image" (Union[str, os.PathLike, np.ndarray, PIL.Image.Image]):
                  Path to the image (e.g. "test1.jpg") or the image itself.
                  See ``get_image_url``.
                - "prompt" (str): Text prompt with categories separated by '.',
                  e.g. 'cate1 . cate2 . cate3'.
            return_mask (bool): Whether to also request segmentation masks.
                Defaults to False.

        Returns:
            (Dict): Detection results with keys "boxes", "categorys", "scores"
            and "masks" (see ``postprocess``). Entries are in the order the API
            returned the objects; "masks" is empty when ``return_mask`` is False.
        """
        # Upload the image and reference it by its server-side URL.
        image_url = self.get_image_url(prompt["image"])
        if return_mask:
            targets = [DetectionTarget.Mask, DetectionTarget.BBox]
        else:
            targets = [DetectionTarget.BBox]
        task = DetectionTask(
            image_url=image_url,
            prompts=[TextPrompt(text=prompt["prompt"])],
            targets=targets,
            model=DetectionModel.GDino1_5_Pro,
        )
        # The detections are read from task.result once run_task returns.
        self.client.run_task(task)
        return self.postprocess(task.result, task, return_mask)

    def postprocess(self, result, task, return_mask):
        """Convert the API task result into plain Python lists.

        Args:
            result (TaskResult): Task result with the following keys:
                - objects (List[DetectionObject]): Each DetectionObject has:
                    - bbox (List[float]): Box in xyxy format
                    - category (str): Detection category
                    - score (float): Detection score
                    - mask (DetectionObjectMask): RLE mask, decoded with
                      ``task.rle2rgba``
            task (DetectionTask): The task object (provides ``rle2rgba``).
            return_mask (bool): Whether to decode and return masks.

        Returns:
            (Dict): Dict in the format:
                {
                    "boxes": (List[List[float]]): Box of each object, xyxy
                    "categorys": (List[str]): Category of each object
                    "scores": (List[float]): Score of each object
                    "masks": (List): Decoded mask of each object (whatever
                        ``task.rle2rgba`` returns, presumably a PIL.Image);
                        empty when ``return_mask`` is False
                }
            All lists follow the order of ``result.objects``, so index ``i``
            refers to the same object in every list.
        """
        detections = list(result.objects)
        boxes = [det.bbox for det in detections]
        scores = [det.score for det in detections]
        categorys = [det.category for det in detections]

        masks = []
        if return_mask and detections:

            def decode_mask(det):
                return task.rle2rgba(det.mask)

            # Mask decoding is the only potentially expensive step, so only it is
            # parallelised. Executor.map yields results in submission order (unlike
            # as_completed), keeping masks[i] aligned with boxes[i].
            with concurrent.futures.ThreadPoolExecutor() as executor:
                masks = list(executor.map(decode_mask, detections))

        return dict(boxes=boxes, categorys=categorys, scores=scores, masks=masks)

    def get_image_url(self, image: Union[str, np.ndarray]):
        """Upload an image to the server and return its URL.

        Args:
            image (Union[str, os.PathLike, np.ndarray, PIL.Image.Image]): The
                image to upload. A path is uploaded directly. An array or PIL
                image is first written to a temporary PNG file. The array is
                converted with ``PIL.Image.fromarray``, so channel order is taken
                as-is (presumably RGB; BGR arrays would come out colour-swapped;
                confirm with callers).

        Returns:
            str: The URL of the uploaded image, as returned by
            ``client.upload_file``.
        """
        if isinstance(image, (str, os.PathLike)):
            return self.client.upload_file(os.fspath(image))

        pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
        # Save by path inside a temporary directory instead of writing into an open
        # NamedTemporaryFile. The PNG is then complete on disk before the upload
        # reads it, and reopening by name works on every platform (an open
        # NamedTemporaryFile cannot be reopened by name on Windows).
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_file_path = os.path.join(tmp_dir, "image.png")
            pil_image.save(tmp_file_path, format="PNG")
            return self.client.upload_file(tmp_file_path)