# GroundingDINO detection pipeline for the SAM extension package.
import os
import functools
import PIL
from PIL.Image import Image
import numpy as np
from typing import List, Union
import supervision as sv
import torch
import torchvision
from huggingface_hub import hf_hub_download
from sam_extension.pipeline import Pipeline
from groundingdino.util.inference import Model
# Default locations of the GroundingDINO model config and checkpoint.
GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT_PATH = "groundingdino_swint_ogc.pth"
# Hugging Face Hub repo the checkpoint is fetched from, and the local cache dir.
SAM_REPO_ID = 'YouLiXiya/YL-SAM'
LOCAL_DIR = "weights/groundingdino"
# Downloader pre-bound to the repo/dir above; callers supply only `filename`.
hf_sam_download = functools.partial(hf_hub_download, repo_id=SAM_REPO_ID, local_dir=LOCAL_DIR, local_dir_use_symlinks=True)
class GroundingDinoPipeline(Pipeline):
    """Open-vocabulary object detection pipeline backed by GroundingDINO.

    Wraps a ``groundingdino`` ``Model`` so that a list of free-text class
    names can be turned into ``supervision.Detections``, optionally rendered
    onto the input image.
    """

    def __init__(self,
                 grounding_dino_config_path,
                 grounfing_dino_ckpt_path,
                 grounding_dino_model,
                 device,
                 *args,
                 **kwargs):
        """Store the loaded model and its provenance.

        Args:
            grounding_dino_config_path: Path to the model config file.
            grounfing_dino_ckpt_path: Path to the checkpoint file.
                NOTE(review): name is misspelled ("grounfing") but kept —
                callers may pass it by keyword, so renaming would break them.
            grounding_dino_model: An already-constructed ``Model`` instance.
            device: Torch device string the model runs on (e.g. ``'cuda'``).
        """
        super(GroundingDinoPipeline, self).__init__(*args, **kwargs)
        self.grounding_dino_config_path = grounding_dino_config_path
        self.grounfing_dino_ckpt_path = grounfing_dino_ckpt_path
        self.grounding_dino_model = grounding_dino_model
        self.device = device

    @classmethod
    def from_pretrained(cls, grounding_dino_config_path, grounfing_dino_ckpt_path, device='cuda', *args, **kwargs):
        """Build a pipeline, downloading the checkpoint from the Hub if absent.

        Args:
            grounding_dino_config_path: Path to the model config file.
            grounfing_dino_ckpt_path: Local checkpoint path; if it does not
                exist, a file of the same basename is fetched via
                ``hf_sam_download``.
            device: Torch device to load the model on (default ``'cuda'``).

        Returns:
            A ready-to-use ``GroundingDinoPipeline`` instance.
        """
        if not os.path.exists(grounfing_dino_ckpt_path):
            hf_sam_download(filename=os.path.basename(grounfing_dino_ckpt_path))
        grounding_dino_model = Model(model_config_path=grounding_dino_config_path,
                                     model_checkpoint_path=grounfing_dino_ckpt_path,
                                     device=device)
        return cls(grounding_dino_config_path,
                   grounfing_dino_ckpt_path,
                   grounding_dino_model,
                   device,
                   *args,
                   **kwargs)

    def visualize_results(self,
                          img: Union[Image, np.ndarray],
                          class_list: List[str],
                          box_threshold: float = 0.25,
                          text_threshold: float = 0.25,
                          nms_threshold: float = 0.8,
                          pil: bool = True):
        """Detect, apply NMS, and draw labeled boxes on the image.

        Args:
            img: PIL image (RGB) or ndarray (assumed BGR — TODO confirm with
                callers; ``forward`` passes ndarrays through unconverted).
            class_list: Free-text class names to detect.
            box_threshold: Box confidence threshold.
            text_threshold: Text-grounding confidence threshold.
            nms_threshold: IoU threshold for class-agnostic NMS.
            pil: If True, return a PIL (RGB) image; otherwise the BGR ndarray.

        Returns:
            Tuple of (annotated image, post-NMS ``sv.Detections``).
        """
        # BUGFIX: normalize PIL input to a BGR ndarray up front. The original
        # passed the PIL image itself to BoxAnnotator.annotate (which expects
        # an ndarray) and the final channel flip assumed BGR.
        if isinstance(img, Image):
            img = np.uint8(img)[:, :, ::-1]
        detections = self.forward(img, class_list, box_threshold, text_threshold)
        box_annotator = sv.BoxAnnotator()
        # Class-agnostic NMS to drop duplicate/overlapping boxes.
        nms_idx = torchvision.ops.nms(
            torch.from_numpy(detections.xyxy),
            torch.from_numpy(detections.confidence),
            nms_threshold
        ).numpy().tolist()
        detections.xyxy = detections.xyxy[nms_idx]
        detections.confidence = detections.confidence[nms_idx]
        detections.class_id = detections.class_id[nms_idx]
        # sv.Detections iterates as (xyxy, mask, confidence, class_id, tracker_id).
        labels = [
            f"{class_list[class_id]} {confidence:0.2f}"
            for _, _, confidence, class_id, _
            in detections]
        annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections, labels=labels)
        if pil:
            # annotated_frame is BGR; flip to RGB for PIL.
            return PIL.Image.fromarray(annotated_frame[:, :, ::-1]), detections
        else:
            return annotated_frame, detections

    @torch.no_grad()
    def forward(self,
                img: Union[Image, np.ndarray],
                class_list: List[str],
                box_threshold: float = 0.25,
                text_threshold: float = 0.25
                ) -> sv.Detections:
        """Run GroundingDINO on the image for the given class prompts.

        Args:
            img: PIL image (RGB, converted to BGR here) or ndarray passed
                through as-is.
            class_list: Free-text class names to detect.
            box_threshold: Box confidence threshold.
            text_threshold: Text-grounding confidence threshold.

        Returns:
            Raw (pre-NMS) ``sv.Detections``.
        """
        if isinstance(img, Image):
            # PIL gives RGB; the detector's predict_with_classes expects BGR.
            img = np.uint8(img)[:, :, ::-1]
        detections = self.grounding_dino_model.predict_with_classes(
            image=img,
            classes=class_list,
            box_threshold=box_threshold,
            text_threshold=text_threshold
        )
        return detections