Spaces:
No application file
No application file
from typing import Tuple, List | |
import ldm_patched.modules.model_management as model_management | |
from ldm_patched.modules.model_patcher import ModelPatcher | |
from modules.config import path_inpaint | |
from modules.model_loader import load_file_from_url | |
import numpy as np | |
import supervision as sv | |
import torch | |
from groundingdino.util.inference import Model | |
from groundingdino.util.inference import load_model, preprocess_caption, get_phrases_from_posmap | |
class GroundingDinoModel(Model):
    """GroundingDINO detector whose network is created lazily.

    Constructing an instance only records the config path and placeholder
    devices. The checkpoint is downloaded, loaded and wrapped in a
    ``ModelPatcher`` the first time ``predict_with_caption`` is called.
    """

    def __init__(self):
        # NOTE(review): the parent initializer is intentionally not invoked
        # here; presumably to avoid loading weights at construction time --
        # confirm against groundingdino's Model.__init__.
        self.config_file = 'extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py'
        # Filled in with a ModelPatcher on first prediction.
        self.model = None
        # Placeholders; replaced by the devices reported by model_management
        # when the network is created.
        self.load_device = torch.device('cpu')
        self.offload_device = torch.device('cpu')

    def predict_with_caption(
        self,
        image: np.ndarray,
        caption: str,
        box_threshold: float = 0.35,
        text_threshold: float = 0.25,
    ) -> Tuple[sv.Detections, torch.Tensor, torch.Tensor, List[str]]:
        """Detect the objects described by ``caption`` in ``image``.

        Args:
            image: Image array handed to ``preprocess_image`` as ``image_bgr``;
                its shape must unpack into (height, width, channels).
            caption: Text prompt describing what to look for.
            box_threshold: Forwarded to ``predict`` as ``box_threshold``.
            text_threshold: Forwarded to ``predict`` as ``text_threshold``.

        Returns:
            A tuple ``(detections, boxes, logits, phrases)`` where
            ``detections`` is the result of ``post_process_result`` and the
            other three values are exactly what ``predict`` returned.
        """
        if self.model is None:
            # First use: fetch the checkpoint and build the network.
            checkpoint_path = load_file_from_url(
                url="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
                file_name='groundingdino_swint_ogc.pth',
                model_dir=path_inpaint,
            )
            network = load_model(
                model_config_path=self.config_file,
                model_checkpoint_path=checkpoint_path,
            )

            self.load_device = model_management.text_encoder_device()
            self.offload_device = model_management.text_encoder_offload_device()

            # Park the fresh network on the offload device before wrapping it.
            network.to(self.offload_device)
            self.model = ModelPatcher(
                network,
                load_device=self.load_device,
                offload_device=self.offload_device,
            )

        # Hand the patched model to model management before every inference.
        model_management.load_model_gpu(self.model)

        image_tensor = GroundingDinoModel.preprocess_image(image_bgr=image).to(self.load_device)
        boxes, logits, phrases = predict(
            model=self.model,
            image=image_tensor,
            caption=caption,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            device=self.load_device,
        )

        # The original image size is passed to post_process_result along with
        # the raw boxes and scores.
        height, width, _ = image.shape
        detections = GroundingDinoModel.post_process_result(
            source_h=height,
            source_w=width,
            boxes=boxes,
            logits=logits,
        )
        return detections, boxes, logits, phrases
def predict(
    model,
    image: torch.Tensor,
    caption: str,
    box_threshold: float,
    text_threshold: float,
    device: str = "cuda",
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
    """Run one GroundingDINO forward pass and filter its predictions.

    Args:
        model: Wrapper object whose ``.model`` attribute is the callable
            network (the network must also expose ``.tokenizer``).
        image: Preprocessed image tensor without a batch axis; a leading
            axis is added before the forward pass.
        caption: Text prompt; normalised with ``preprocess_caption``.
        box_threshold: A query is kept when its highest token score is
            strictly greater than this value.
        text_threshold: Token scores strictly above this value are passed to
            ``get_phrases_from_posmap`` to build each phrase.
        device: Device the network and the image are moved to.

    Returns:
        ``(boxes, scores, phrases)``: the kept boxes, the highest token score
        of each kept query, and one phrase per kept query with ``.`` removed.
    """
    caption = preprocess_caption(caption=caption)

    # override to use model wrapped by patcher
    network = model.model.to(device)
    image_on_device = image.to(device)

    with torch.no_grad():
        outputs = network(image_on_device[None], captions=[caption])

    # Take the single batch entry and continue on the CPU.
    query_scores = outputs["pred_logits"].cpu().sigmoid()[0]  # shape = (nq, 256)
    query_boxes = outputs["pred_boxes"].cpu()[0]  # shape = (nq, 4)

    keep = query_scores.max(dim=1).values > box_threshold
    kept_scores = query_scores[keep]  # shape = (n, 256)
    kept_boxes = query_boxes[keep]  # shape = (n, 4)

    tokenizer = network.tokenizer
    tokenized = tokenizer(caption)

    phrases = []
    for token_scores in kept_scores:
        phrase = get_phrases_from_posmap(token_scores > text_threshold, tokenized, tokenizer)
        phrases.append(phrase.replace('.', ''))

    return kept_boxes, kept_scores.max(dim=1).values, phrases
# Module-level convenience callable: the bound `predict_with_caption` of a
# single GroundingDinoModel instance created when this module is imported.
# Creating the instance does not load any weights (its `model` starts as
# None); the checkpoint is fetched and loaded on the first call.
default_groundingdino = GroundingDinoModel().predict_with_caption