File size: 5,537 Bytes
bf9dee2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import tempfile
from typing import Dict, List, Union
import numpy as np
from dds_cloudapi_sdk import (
    DetectionTask,
    Client,
    Config,
    TextPrompt,
    DetectionModel,
    DetectionTarget,
)
from PIL import Image
import concurrent.futures

class GroundingDINOAPIWrapper:
    """API wrapper for Grounding DINO 1.5.

    Args:
        token (str): The token for the Grounding DINO 1.5 API. Free API access
            is currently offered; educators, students, and researchers can
            request a token with extended usage at:
            https://deepdataspace.com/request_api
    """

    def __init__(self, token: str):
        self.client = Client(Config(token=token))

    def inference(self, prompt: Dict, return_mask: bool = False):
        """Run Grounding DINO 1.5 detection on a single image.

        Batch inference is not supported; one call handles one image.

        Args:
            prompt (dict): Input with the following keys:
                - "image" (str | np.ndarray): Path to the image, e.g.
                  "test1.jpg", or an image array (see ``get_image_url``).
                - "prompt" (str): Text prompt separated by '.',
                  e.g. 'cate1 . cate2 . cate3'.
            return_mask (bool): Whether to also request masks.
                Defaults to False.

        Returns:
            (Dict): Detection results as produced by ``postprocess``:
                - "boxes": A list of boxes, one per detected object.
                - "categorys": A list of categories, one per detected object.
                - "scores": A list of scores, one per detected object.
                - "masks": A list of decoded masks; empty when
                  ``return_mask`` is False.
        """
        image_url = self.get_image_url(prompt["image"])

        targets = [DetectionTarget.BBox]
        if return_mask:
            targets = [DetectionTarget.Mask, DetectionTarget.BBox]

        task = DetectionTask(
            image_url=image_url,
            prompts=[TextPrompt(text=prompt["prompt"])],
            targets=targets,
            model=DetectionModel.GDino1_5_Pro,
        )
        self.client.run_task(task)
        return self.postprocess(task.result, task, return_mask)

    def postprocess(self, result, task, return_mask):
        """Convert the API task result into plain lists.

        Args:
            result (TaskResult): Task result exposing ``objects``; each object
                provides ``bbox``, ``score``, ``category`` and, when masks
                were requested, ``mask``.
            task (DetectionTask): The task object; its ``rle2rgba`` method is
                used to decode each object's mask.
            return_mask (bool): Whether to decode and return masks.

        Returns:
            (Dict): Dict with keys:
                - "boxes": ``bbox`` of every object.
                - "categorys": ``category`` of every object.
                - "scores": ``score`` of every object.
                - "masks": Output of ``task.rle2rgba`` for every object whose
                  decoded mask is not None; empty when ``return_mask`` is
                  False.
            Entries follow the order of ``result.objects``.
        """
        detections = list(result.objects)
        boxes = [det.bbox for det in detections]
        scores = [det.score for det in detections]
        categorys = [det.category for det in detections]

        masks = []
        if return_mask and detections:
            # Mask decoding is the only non-trivial per-object work, so it is
            # the only part fanned out to threads. Executor.map yields results
            # in input order, keeping the output deterministic.
            with concurrent.futures.ThreadPoolExecutor() as executor:
                decoded = executor.map(
                    lambda det: task.rle2rgba(det.mask), detections
                )
                masks = [mask for mask in decoded if mask is not None]

        return dict(boxes=boxes, categorys=categorys, scores=scores, masks=masks)

    def get_image_url(self, image: Union[str, np.ndarray]):
        """Upload an image to the server and return its url.

        Args:
            image (Union[str, np.ndarray]): The image to upload. Can be a file
                path or np.ndarray. An array is first saved as PNG to a
                temporary file, which is removed after the upload.

        Returns:
            The value returned by ``client.upload_file`` for the image.
        """
        if isinstance(image, str):
            return self.client.upload_file(image)

        with tempfile.NamedTemporaryFile(delete=True, suffix=".png") as tmp_file:
            # image is in numpy format, convert to PIL Image
            Image.fromarray(image).save(tmp_file, format="PNG")
            # The uploader re-opens the file by path, so buffered bytes must
            # reach the disk first; otherwise it may read a truncated file.
            tmp_file.flush()
            return self.client.upload_file(tmp_file.name)