File size: 3,594 Bytes
538b6a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from typing import Tuple, List

import ldm_patched.modules.model_management as model_management
from ldm_patched.modules.model_patcher import ModelPatcher
from modules.config import path_inpaint
from modules.model_loader import load_file_from_url

import numpy as np
import supervision as sv
import torch
from groundingdino.util.inference import Model
from groundingdino.util.inference import load_model, preprocess_caption, get_phrases_from_posmap


class GroundingDinoModel(Model):
    def __init__(self):
        self.config_file = 'extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py'
        self.model = None
        self.load_device = torch.device('cpu')
        self.offload_device = torch.device('cpu')

    @torch.no_grad()
    @torch.inference_mode()
    def predict_with_caption(
            self,
            image: np.ndarray,
            caption: str,
            box_threshold: float = 0.35,
            text_threshold: float = 0.25
    ) -> Tuple[sv.Detections, torch.Tensor, torch.Tensor, List[str]]:
        if self.model is None:
            filename = load_file_from_url(
                url="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
                file_name='groundingdino_swint_ogc.pth',
                model_dir=path_inpaint)
            model = load_model(model_config_path=self.config_file, model_checkpoint_path=filename)

            self.load_device = model_management.text_encoder_device()
            self.offload_device = model_management.text_encoder_offload_device()

            model.to(self.offload_device)

            self.model = ModelPatcher(model, load_device=self.load_device, offload_device=self.offload_device)

        model_management.load_model_gpu(self.model)

        processed_image = GroundingDinoModel.preprocess_image(image_bgr=image).to(self.load_device)
        boxes, logits, phrases = predict(
            model=self.model,
            image=processed_image,
            caption=caption,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            device=self.load_device)
        source_h, source_w, _ = image.shape
        detections = GroundingDinoModel.post_process_result(
            source_h=source_h,
            source_w=source_w,
            boxes=boxes,
            logits=logits)
        return detections, boxes, logits, phrases


def predict(
        model,
        image: torch.Tensor,
        caption: str,
        box_threshold: float,
        text_threshold: float,
        device: str = "cuda"
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
    caption = preprocess_caption(caption=caption)

    # override to use model wrapped by patcher
    model = model.model.to(device)
    image = image.to(device)

    with torch.no_grad():
        outputs = model(image[None], captions=[caption])

    prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0]  # prediction_logits.shape = (nq, 256)
    prediction_boxes = outputs["pred_boxes"].cpu()[0]  # prediction_boxes.shape = (nq, 4)

    mask = prediction_logits.max(dim=1)[0] > box_threshold
    logits = prediction_logits[mask]  # logits.shape = (n, 256)
    boxes = prediction_boxes[mask]  # boxes.shape = (n, 4)

    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)

    phrases = [
        get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
        for logit
        in logits
    ]

    return boxes, logits.max(dim=1)[0], phrases


default_groundingdino = GroundingDinoModel().predict_with_caption