File size: 7,450 Bytes
b5ba7a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
import os
import gc
import cv2
import copy
import torch
from collections import OrderedDict
from modules import scripts, shared
from modules.devices import device, torch_gc, cpu
import local_groundingdino
# Models built by load_dino_model, keyed by their dino_model_info name.
dino_model_cache = OrderedDict()
sam_extension_dir = scripts.basedir()
# Checkpoints are downloaded to, and config files are read from, this directory.
dino_model_dir = os.path.join(sam_extension_dir, "models/grounding-dino")
# Per-model metadata: checkpoint file name, config file path and download URL.
dino_model_info = {
    "GroundingDINO_SwinT_OGC (694MB)": {
        "checkpoint": "groundingdino_swint_ogc.pth",
        "config": os.path.join(dino_model_dir, "GroundingDINO_SwinT_OGC.py"),
        "url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth",
    },
    "GroundingDINO_SwinB (938MB)": {
        "checkpoint": "groundingdino_swinb_cogcoor.pth",
        "config": os.path.join(dino_model_dir, "GroundingDINO_SwinB.cfg.py"),
        "url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth",
    },
}
# Supported model names. Derived from dino_model_info (dicts keep insertion
# order) so the list and the metadata can never disagree.
dino_model_list = list(dino_model_info)
dino_install_issue_text = "permanently switch to local groundingdino on Settings/Segment Anything or submit an issue to https://github.com/IDEA-Research/Grounded-Segment-Anything/issues."
def install_goundingdino():
    """Make sure the pip ``groundingdino`` package is usable, installing it if needed.

    Returns:
        True when the pip-installed ``groundingdino`` package and its compiled
        ``_C`` extension import successfully. False when the caller should
        use the bundled ``local_groundingdino`` package instead: either the
        ``sam_use_local_groundingdino`` option is set, or installing or
        building the package failed.
    """
    if shared.opts.data.get("sam_use_local_groundingdino", False):
        print("Using local groundingdino.")
        return False

    def run_pip_uninstall(command, desc=None):
        # Run `pip uninstall -y <command>` through webui's `launch.run`.
        from launch import python, run
        default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
        return run(f'"{python}" -m pip uninstall -y {command}', desc=f"Uninstalling {desc}", errdesc=f"Couldn't uninstall {desc}", live=default_command_live)

    def verify_dll(install_local=True):
        # Return True when groundingdino's compiled `_C` extension imports.
        # On failure, uninstall the broken pip package and return False.
        # `install_local` only selects which follow-up action is announced.
        try:
            from groundingdino import _C  # noqa: F401 -- the import itself is the check
            print("GroundingDINO dynamic library have been successfully built.")
            return True
        except Exception:
            import traceback
            traceback.print_exc()
        if install_local:
            print(f"Failed to build dymanic library. Will uninstall GroundingDINO from pip and fall back to local groundingdino this time. Please {dino_install_issue_text}")
        else:
            print(f"Failed to build dymanic library. Will uninstall GroundingDINO from pip and re-try installing from GitHub source code. Please {dino_install_issue_text}")
        # Pass only the package name: `pip uninstall -y` is already part of the
        # command, so an extra "uninstall" word would be treated as a package.
        run_pip_uninstall(
            "groundingdino",
            "sd-webui-segment-anything requirement: groundingdino")
        return False

    import launch
    if launch.is_installed("groundingdino"):
        print("Found GroundingDINO in pip. Verifying if dynamic library build success.")
        if verify_dll(install_local=False):
            return True
    try:
        launch.run_pip(
            "install git+https://github.com/IDEA-Research/GroundingDINO",
            "sd-webui-segment-anything requirement: groundingdino")
        print("GroundingDINO install success. Verifying if dynamic library build success.")
        return verify_dll()
    except Exception:
        import traceback
        traceback.print_exc()
        print(f"GroundingDINO install failed. Will fall back to local groundingdino this time. Please {dino_install_issue_text}")
        return False
def show_boxes(image_np, boxes, color=(255, 0, 0, 255), thickness=2, show_index=False):
    """Return a copy of ``image_np`` with every box drawn as a rectangle.

    Args:
        image_np: image array to draw on. It is deep-copied, never modified.
        boxes: iterable of ``(x0, y0, x1, y1)`` corner boxes in pixel
            coordinates (ints or floats), or None.
        color: drawing color passed to OpenCV.
        thickness: line thickness for the rectangles and the index labels.
        show_index: when True, write each box's index near its top-left corner.

    Returns:
        The annotated copy, or ``image_np`` itself when ``boxes`` is None.
    """
    if boxes is None:
        return image_np
    image = copy.deepcopy(image_np)
    for idx, box in enumerate(boxes):
        # The values are two corners, not (x, y, w, h). OpenCV only accepts
        # integer points, so cast (e.g. float tensors from dino_predict_internal).
        x0, y0, x1, y1 = (int(v) for v in box)
        cv2.rectangle(image, (x0, y0), (x1, y1), color, thickness)
        if show_index:
            font = cv2.FONT_HERSHEY_SIMPLEX
            text = str(idx)
            # Measure with the same thickness used to draw the label.
            text_height = cv2.getTextSize(text, font, 1, thickness)[0][1]
            # putText anchors at the text's bottom-left corner; shift down by
            # the text height so the label sits inside the box.
            cv2.putText(image, text, (x0, y0 + text_height), font, 1, color, thickness)
    return image
def clear_dino_cache():
    """Forget every cached GroundingDINO model, then reclaim the freed memory.

    Python's garbage collector runs first, followed by webui's ``torch_gc``
    helper (presumably releases cached device memory; confirm in webui).
    """
    dino_model_cache.clear()
    for collect in (gc.collect, torch_gc):
        collect()
def load_dino_model(dino_checkpoint, dino_install_success):
    """Return the GroundingDINO model named ``dino_checkpoint``, building it if needed.

    At most one model is cached: loading a different checkpoint first clears
    the cache.

    Args:
        dino_checkpoint: key of ``dino_model_info``.
        dino_install_success: build with the pip ``groundingdino`` package when
            True, with the bundled ``local_groundingdino`` package otherwise.

    Returns:
        The model in eval mode, moved to ``device``.
    """
    print(f"Initializing GroundingDINO {dino_checkpoint}")
    if dino_checkpoint in dino_model_cache:
        dino = dino_model_cache[dino_checkpoint]
        # In lowvram mode get_grounding_output moves the model to the CPU after
        # each inference, so bring it back before use.
        if shared.cmd_opts.lowvram:
            dino.to(device=device)
    else:
        clear_dino_cache()
        if dino_install_success:
            from groundingdino.models import build_model
            from groundingdino.util.slconfig import SLConfig
            from groundingdino.util.utils import clean_state_dict
        else:
            from local_groundingdino.models import build_model
            from local_groundingdino.util.slconfig import SLConfig
            from local_groundingdino.util.utils import clean_state_dict
        info = dino_model_info[dino_checkpoint]
        args = SLConfig.fromfile(info["config"])
        dino = build_model(args)
        # Deserialize onto the CPU so loading does not depend on the device the
        # checkpoint tensors were saved from. load_state_dict copies the values
        # into the model's own parameters anyway.
        checkpoint = torch.hub.load_state_dict_from_url(
            info["url"], dino_model_dir, map_location="cpu")
        dino.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
        del checkpoint  # release the host copy of the weights before moving the model
        dino.to(device=device)
        dino_model_cache[dino_checkpoint] = dino
    dino.eval()
    return dino
def load_dino_image(image_pil, dino_install_success):
    """Preprocess a PIL image into the tensor GroundingDINO consumes.

    Args:
        image_pil: the input PIL image (callers in this file pass RGB).
        dino_install_success: use the pip ``groundingdino`` transforms when
            True, the bundled ``local_groundingdino`` transforms otherwise.

    Returns:
        The transformed image tensor, laid out as (3, h, w).
    """
    if not dino_install_success:
        from local_groundingdino.datasets import transforms as T
    else:
        import groundingdino.datasets.transforms as T
    pipeline = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        # Standard ImageNet per-channel mean and std.
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    tensor_image, _unused_target = pipeline(image_pil, None)
    return tensor_image
def get_grounding_output(model, image, caption, box_threshold):
    """Run GroundingDINO on one image and keep the boxes scoring above a threshold.

    Args:
        model: a GroundingDINO model, as returned by ``load_dino_model``.
        image: preprocessed (3, h, w) image tensor from ``load_dino_image``.
        caption: free-text prompt. It is lower-cased, stripped and terminated
            with "." before being fed to the model.
        box_threshold: a query is kept when its highest token score
            (after sigmoid) is strictly greater than this value.

    Returns:
        CPU tensor of shape (num_kept, 4) holding the raw predicted boxes,
        which ``dino_predict_internal`` treats as normalized (cx, cy, w, h).
    """
    caption = caption.lower().strip()
    if not caption.endswith("."):
        caption += "."
    image = image.to(device)
    try:
        with torch.no_grad():
            outputs = model(image[None], captions=[caption])
    finally:
        # lowvram keeps the model on the CPU between calls. Do it even if
        # inference raised, so a failed run does not leave it on the GPU.
        if shared.cmd_opts.lowvram:
            model.to(cpu)
    logits = outputs["pred_logits"].sigmoid()[0]  # (nq, 256)
    boxes = outputs["pred_boxes"][0]  # (nq, 4)
    # Boolean-mask indexing already returns a new tensor, so no clone is needed.
    keep = logits.max(dim=1)[0] > box_threshold
    return boxes[keep].cpu()
def dino_predict_internal(input_image, dino_model_name, text_prompt, box_threshold):
    """Detect the boxes matching ``text_prompt`` in ``input_image`` with GroundingDINO.

    Args:
        input_image: PIL image. It is converted to RGB before preprocessing.
        dino_model_name: key of ``dino_model_info``.
        text_prompt: free-text description of the objects to detect.
        box_threshold: minimum score for a box to be kept.

    Returns:
        ``(boxes, install_success)``. ``boxes`` is a CPU tensor of shape
        (num_boxes, 4) holding pixel-space (x0, y0, x1, y1) corners.
        ``install_success`` is truthy when the pip GroundingDINO package was
        used or the ``sam_use_local_groundingdino`` option is set.
    """
    install_success = install_goundingdino()
    print("Running GroundingDINO Inference")
    dino_image = load_dino_image(input_image.convert("RGB"), install_success)
    dino_model = load_dino_model(dino_model_name, install_success)
    # Deliberately choosing the local implementation is not an install
    # failure (presumably so callers don't warn about it; verify in caller).
    install_success = install_success or shared.opts.data.get("sam_use_local_groundingdino", False)
    boxes_filt = get_grounding_output(
        dino_model, dino_image, text_prompt, box_threshold
    )
    width, height = input_image.size
    # Convert every box at once from normalized (cx, cy, w, h) to pixel
    # (x0, y0, x1, y1). new_tensor keeps the boxes' dtype and device.
    boxes_filt = boxes_filt * boxes_filt.new_tensor([width, height, width, height])
    boxes_filt[:, :2] -= boxes_filt[:, 2:] / 2  # center -> top-left corner
    boxes_filt[:, 2:] += boxes_filt[:, :2]  # size -> bottom-right corner
    gc.collect()
    torch_gc()
    return boxes_filt, install_success
|