import os
import functools
from typing import List, Union

import numpy as np
import PIL.Image
from PIL.Image import Image
import supervision as sv

import torch
import torchvision

from huggingface_hub import hf_hub_download
from sam_extension.pipeline import Pipeline
from groundingdino.util.inference import Model

GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT_PATH = "groundingdino_swint_ogc.pth"
SAM_REPO_ID = 'YouLiXiya/YL-SAM'
LOCAL_DIR = "weights/groundingdino"
# Download helper bound to the Hub repo that hosts the weights; files land in LOCAL_DIR.
hf_sam_download = functools.partial(hf_hub_download,
                                    repo_id=SAM_REPO_ID,
                                    local_dir=LOCAL_DIR,
                                    local_dir_use_symlinks=True)


class GroundingDinoPipeline(Pipeline):
    """Pipeline wrapper around GroundingDINO for text-prompted object detection."""

    def __init__(self,
                 grounding_dino_config_path,
                 grounding_dino_ckpt_path,
                 grounding_dino_model,
                 device,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.grounding_dino_config_path = grounding_dino_config_path
        self.grounding_dino_ckpt_path = grounding_dino_ckpt_path
        self.grounding_dino_model = grounding_dino_model
        self.device = device

    @classmethod
    def from_pretrained(cls, grounding_dino_config_path, grounding_dino_ckpt_path, device='cuda', *args, **kwargs):
        # Fetch the checkpoint from the Hugging Face Hub if it is missing locally;
        # hf_hub_download returns the path of the downloaded file, so use that path.
        if not os.path.exists(grounding_dino_ckpt_path):
            grounding_dino_ckpt_path = hf_sam_download(filename=os.path.basename(grounding_dino_ckpt_path))
        grounding_dino_model = Model(model_config_path=grounding_dino_config_path,
                                     model_checkpoint_path=grounding_dino_ckpt_path,
                                     device=device)
        return cls(grounding_dino_config_path,
                   grounding_dino_ckpt_path,
                   grounding_dino_model,
                   device,
                   *args,
                   **kwargs)

    def visualize_results(self,
                          img: Union[Image, np.ndarray],
                          class_list: List[str],
                          box_threshold: float = 0.25,
                          text_threshold: float = 0.25,
                          nms_threshold: float = 0.8,
                          pil: bool = True):
        detections = self.forward(img, class_list, box_threshold, text_threshold)
        if isinstance(img, Image):
            # Annotation expects a BGR ndarray, so convert PIL (RGB) input first.
            img = np.uint8(img)[:, :, ::-1]
        box_annotator = sv.BoxAnnotator()
        # Class-agnostic non-maximum suppression to drop overlapping boxes.
        nms_idx = torchvision.ops.nms(
            torch.from_numpy(detections.xyxy),
            torch.from_numpy(detections.confidence),
            nms_threshold
        ).numpy().tolist()
        detections.xyxy = detections.xyxy[nms_idx]
        detections.confidence = detections.confidence[nms_idx]
        detections.class_id = detections.class_id[nms_idx]
        labels = [
            f"{class_list[class_id]} {confidence:0.2f}"
            for _, _, confidence, class_id, _
            in detections
        ]
        annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections, labels=labels)
        if pil:
            # The annotated frame is BGR; flip channels back to RGB for PIL.
            return PIL.Image.fromarray(annotated_frame[:, :, ::-1]), detections
        else:
            return annotated_frame, detections

    @torch.no_grad()
    def forward(self,
                img: Union[Image, np.ndarray],
                class_list: List[str],
                box_threshold: float = 0.25,
                text_threshold: float = 0.25
                ) -> sv.Detections:
        if isinstance(img, Image):
            # GroundingDINO's Model expects a BGR ndarray (OpenCV convention).
            img = np.uint8(img)[:, :, ::-1]
        detections = self.grounding_dino_model.predict_with_classes(
            image=img,
            classes=class_list,
            box_threshold=box_threshold,
            text_threshold=text_threshold
        )
        return detections
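

# Minimal usage sketch, assuming the config/checkpoint paths defined above
# resolve on this machine and a CUDA device is available; "dog.jpg" and the
# class list are placeholder inputs, not part of the pipeline itself.
if __name__ == "__main__":
    pipeline = GroundingDinoPipeline.from_pretrained(
        GROUNDING_DINO_CONFIG_PATH,
        GROUNDING_DINO_CHECKPOINT_PATH,
        device='cuda')
    image = PIL.Image.open("dog.jpg").convert("RGB")
    annotated, detections = pipeline.visualize_results(image, ["dog", "person"])
    annotated.save("annotated_dog.jpg")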