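# NOTE(review): the following items look like web-page residue captured together
# with the file (apparently uploader, commit message, commit hash, page links and
# file size); they are kept here as comments so the module remains valid Python.
# 방재호
# init
# 5636c1c
# raw
# history blame
# 7.45 kB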
import os
import gc
import cv2
import copy
import torch
from collections import OrderedDict
from modules import scripts, shared
from modules.devices import device, torch_gc, cpu
import local_groundingdino
# Loaded GroundingDINO models, keyed by the same display name that indexes
# ``dino_model_info``.
dino_model_cache = OrderedDict()

# Base directory returned by ``scripts.basedir()``; model files are looked up
# (and downloaded) under ``models/grounding-dino`` inside it.
sam_extension_dir = scripts.basedir()
dino_model_dir = os.path.join(sam_extension_dir, "models/grounding-dino")

# Per-model metadata: checkpoint file name, config file path and download URL.
dino_model_info = {
    "GroundingDINO_SwinT_OGC (694MB)": {
        "checkpoint": "groundingdino_swint_ogc.pth",
        "config": os.path.join(dino_model_dir, "GroundingDINO_SwinT_OGC.py"),
        "url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth",
    },
    "GroundingDINO_SwinB (938MB)": {
        "checkpoint": "groundingdino_swinb_cogcoor.pth",
        "config": os.path.join(dino_model_dir, "GroundingDINO_SwinB.cfg.py"),
        "url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth",
    },
}

# Display names in declaration order (dicts preserve insertion order).
dino_model_list = list(dino_model_info)

# Shared tail of the messages printed when installing GroundingDINO fails.
dino_install_issue_text = (
    "permanently switch to local groundingdino on Settings/Segment Anything "
    "or submit an issue to "
    "https://github.com/IDEA-Research/Grounded-Segment-Anything/issues."
)
def install_goundingdino():
    """Make sure the pip ``groundingdino`` package is installed and usable.

    The function short-circuits when the ``sam_use_local_groundingdino``
    option is set. Otherwise it verifies an existing pip installation (by
    importing its compiled ``_C`` extension) and, if that is missing or
    broken, installs the package from GitHub and verifies again. Whenever
    verification fails the package is uninstalled from pip.

    Returns:
        bool: ``True`` if ``groundingdino._C`` could be imported; ``False`` if
        the local implementation was requested or installation/verification
        failed, in which case callers are expected to fall back to
        ``local_groundingdino``.
    """
    if shared.opts.data.get("sam_use_local_groundingdino", False):
        print("Using local groundingdino.")
        return False

    requirement_desc = "sd-webui-segment-anything requirement: groundingdino"

    def run_pip_uninstall(command, desc=None):
        """Run ``pip uninstall -y <command>`` through the web UI launcher."""
        from launch import python, run
        default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
        return run(
            f'"{python}" -m pip uninstall -y {command}',
            desc=f"Uninstalling {desc}",
            errdesc=f"Couldn't uninstall {desc}",
            live=default_command_live,
        )

    def verify_dll(install_local=True):
        """Check that the compiled extension imports; uninstall the package if not.

        ``install_local`` only selects which explanation is printed: falling
        back to the local implementation, or retrying from GitHub sources.
        """
        try:
            from groundingdino import _C  # noqa: F401 -- import is the check itself
        except Exception:
            import traceback
            traceback.print_exc()
            if install_local:
                print(f"Failed to build dymanic library. Will uninstall GroundingDINO from pip and fall back to local groundingdino this time. Please {dino_install_issue_text}")
            else:
                print(f"Failed to build dymanic library. Will uninstall GroundingDINO from pip and re-try installing from GitHub source code. Please {dino_install_issue_text}")
            # The helper already supplies "pip uninstall -y"; pass only the
            # package name. (Previously one branch passed
            # "uninstall groundingdino", which also targeted a package that is
            # literally called "uninstall".)
            run_pip_uninstall("groundingdino", requirement_desc)
            return False
        else:
            print("GroundingDINO dynamic library have been successfully built.")
            return True

    import launch
    if launch.is_installed("groundingdino"):
        print("Found GroundingDINO in pip. Verifying if dynamic library build success.")
        if verify_dll(install_local=False):
            return True
    try:
        launch.run_pip(
            "install git+https://github.com/IDEA-Research/GroundingDINO",
            requirement_desc)
        print("GroundingDINO install success. Verifying if dynamic library build success.")
        return verify_dll()
    except Exception:
        import traceback
        traceback.print_exc()
        print(f"GroundingDINO install failed. Will fall back to local groundingdino this time. Please {dino_install_issue_text}")
        return False
def show_boxes(image_np, boxes, color=(255, 0, 0, 255), thickness=2, show_index=False):
    """Return a copy of ``image_np`` with each box outlined.

    Args:
        image_np: Image to draw on; it is deep-copied, never modified.
        boxes: Iterable of 4-value boxes, or ``None``. The first two values
            and the last two values of a box are handed to ``cv2.rectangle``
            as its two corner points.
        color: Colour passed to the OpenCV drawing calls.
        thickness: Line thickness for the rectangle and the index label.
        show_index: When true, the box's position in ``boxes`` is written at
            its first corner.

    Returns:
        ``image_np`` itself when ``boxes`` is ``None``; otherwise the
        annotated copy.
    """
    if boxes is None:
        return image_np

    canvas = copy.deepcopy(image_np)
    for index, (left, top, right, bottom) in enumerate(boxes):
        cv2.rectangle(canvas, (left, top), (right, bottom), color, thickness)
        if not show_index:
            continue
        font_face = cv2.FONT_HERSHEY_SIMPLEX
        label = str(index)
        # getTextSize returns ((width, height), baseline); shift the text
        # origin down by the label height.
        label_height = cv2.getTextSize(label, font_face, 1, 2)[0][1]
        cv2.putText(canvas, label, (left, top + label_height), font_face, 1, color, thickness)
    return canvas
def clear_dino_cache():
    """Empty ``dino_model_cache`` and then trigger garbage collection.

    The cache is cleared first so the models it held are no longer referenced
    from it when Python's collector and the web UI's ``torch_gc`` run.
    """
    dino_model_cache.clear()
    gc.collect()
    torch_gc()
def load_dino_model(dino_checkpoint, dino_install_success):
    """Return the GroundingDINO model for ``dino_checkpoint``, building it if needed.

    Args:
        dino_checkpoint: Key into ``dino_model_info`` / ``dino_model_cache``.
        dino_install_success: When true the pip ``groundingdino`` package is
            used to build the model; otherwise ``local_groundingdino``.

    Returns:
        The model, moved to ``device``. A freshly built model is cached and
        put in eval mode.
    """
    print(f"Initializing GroundingDINO {dino_checkpoint}")

    # Cache hit: reuse the model as-is.
    if dino_checkpoint in dino_model_cache:
        cached_model = dino_model_cache[dino_checkpoint]
        if shared.cmd_opts.lowvram:
            # get_grounding_output moves the model to the CPU in lowvram
            # mode, so bring it back before it is used again.
            cached_model.to(device=device)
        return cached_model

    # Cache miss: drop whatever was cached before building the new model.
    clear_dino_cache()
    if dino_install_success:
        from groundingdino.models import build_model
        from groundingdino.util.slconfig import SLConfig
        from groundingdino.util.utils import clean_state_dict
    else:
        from local_groundingdino.models import build_model
        from local_groundingdino.util.slconfig import SLConfig
        from local_groundingdino.util.utils import clean_state_dict

    model_info = dino_model_info[dino_checkpoint]
    model_config = SLConfig.fromfile(model_info["config"])
    model = build_model(model_config)

    state = torch.hub.load_state_dict_from_url(model_info["url"], dino_model_dir)
    model.load_state_dict(clean_state_dict(state['model']), strict=False)

    model.to(device=device)
    dino_model_cache[dino_checkpoint] = model
    model.eval()
    return model
def load_dino_image(image_pil, dino_install_success):
    """Run ``image_pil`` through the GroundingDINO preprocessing transforms.

    Args:
        image_pil: Input image handed to the transform pipeline.
        dino_install_success: Selects the transforms module: the pip
            ``groundingdino`` package when true, ``local_groundingdino``
            otherwise.

    Returns:
        The transformed image (resized, converted to a tensor, normalised).
    """
    if dino_install_success:
        import groundingdino.datasets.transforms as T
    else:
        from local_groundingdino.datasets import transforms as T

    steps = [
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    pipeline = T.Compose(steps)
    # The pipeline takes (image, target) and returns a pair; there is no
    # target here, so the second element is discarded.
    tensor, _ = pipeline(image_pil, None)  # 3, h, w
    return tensor
def get_grounding_output(model, image, caption, box_threshold):
    """Run GroundingDINO on one image and keep the confident boxes.

    Args:
        model: GroundingDINO model; called as ``model(batch, captions=[...])``.
        image: Single image tensor; a batch dimension is added before the call.
        caption: Text prompt. It is lower-cased, stripped, and terminated with
            a ``"."`` if it does not already end with one.
        box_threshold: A box is kept when its highest sigmoid logit exceeds
            this value.

    Returns:
        CPU tensor holding the ``pred_boxes`` rows that passed the threshold.
    """
    caption = caption.lower().strip()
    if not caption.endswith("."):
        caption += "."

    image = image.to(device)
    with torch.no_grad():
        outputs = model(image[None], captions=[caption])
    if shared.cmd_opts.lowvram:
        # Move the model off the compute device once inference is done.
        model.to(cpu)

    logits = outputs["pred_logits"].sigmoid()[0]  # (nq, 256)
    boxes = outputs["pred_boxes"][0]  # (nq, 4)

    # Boolean-mask indexing already yields a new tensor, so no clone() is
    # needed; the filtered logits were never used and are not computed.
    keep = logits.max(dim=1)[0] > box_threshold
    return boxes[keep].cpu()  # num_filt, 4
def dino_predict_internal(input_image, dino_model_name, text_prompt, box_threshold):
    """Detect boxes for ``text_prompt`` in ``input_image`` with GroundingDINO.

    Args:
        input_image: Image exposing ``convert("RGB")`` and a ``size`` whose
            first element is the width and second the height.
        dino_model_name: Key into ``dino_model_info`` selecting the model.
        text_prompt: Caption describing what to detect.
        box_threshold: Confidence threshold forwarded to
            ``get_grounding_output``.

    Returns:
        tuple: ``(boxes, install_success)``. ``boxes`` holds one row per
        detection, scaled by the image size and converted from
        (centre, size) to (corner, opposite corner) form -- presumably the
        model emits normalised ``cx, cy, w, h``; confirm against the model.
        ``install_success`` is truthy when the pip package is usable or the
        ``sam_use_local_groundingdino`` option is set.
    """
    install_success = install_goundingdino()
    print("Running GroundingDINO Inference")
    dino_image = load_dino_image(input_image.convert("RGB"), install_success)
    dino_model = load_dino_model(dino_model_name, install_success)
    install_success = install_success or shared.opts.data.get("sam_use_local_groundingdino", False)

    boxes_filt = get_grounding_output(
        dino_model, dino_image, text_prompt, box_threshold
    )

    width, height = input_image.size[0], input_image.size[1]
    # Whole-tensor version of the former per-row loop: scale by the image
    # size, shift the first two columns back by half of the last two, then
    # turn the last two columns into the opposite corner.
    scale = boxes_filt.new_tensor([width, height, width, height])
    boxes_filt = boxes_filt * scale
    boxes_filt[:, :2] -= boxes_filt[:, 2:] / 2
    boxes_filt[:, 2:] += boxes_filt[:, :2]

    gc.collect()
    torch_gc()
    return boxes_filt, install_success