Spaces:
Runtime error
Runtime error
ShilongLiu
committed on
Commit
·
4dc6d69
1
Parent(s):
54125c1
update add.py
Browse files
app.py
CHANGED
@@ -7,8 +7,9 @@ os.system("python -m pip install -e GroundingDINO")
|
|
7 |
os.system("pip install --upgrade diffusers[torch]")
|
8 |
os.system("pip install opencv-python pycocotools matplotlib onnxruntime onnx ipykernel")
|
9 |
os.system("wget https://github.com/IDEA-Research/Grounded-Segment-Anything/raw/main/assets/demo1.jpg")
|
10 |
-
os.system("wget https://
|
11 |
sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))
|
|
|
12 |
warnings.filterwarnings("ignore")
|
13 |
|
14 |
import gradio as gr
|
@@ -39,11 +40,13 @@ from transformers import BlipProcessor, BlipForConditionalGeneration
|
|
39 |
|
40 |
def generate_caption(processor, blip_model, raw_image):
|
41 |
# unconditional image captioning
|
42 |
-
inputs = processor(raw_image, return_tensors="pt").to(
|
|
|
43 |
out = blip_model.generate(**inputs)
|
44 |
caption = processor.decode(out[0], skip_special_tokens=True)
|
45 |
return caption
|
46 |
|
|
|
47 |
def transform_image(image_pil):
|
48 |
|
49 |
transform = T.Compose(
|
@@ -62,7 +65,8 @@ def load_model(model_config_path, model_checkpoint_path, device):
|
|
62 |
args.device = device
|
63 |
model = build_model(args)
|
64 |
checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
|
65 |
-
load_res = model.load_state_dict(
|
|
|
66 |
print(load_res)
|
67 |
_ = model.eval()
|
68 |
return model
|
@@ -95,18 +99,22 @@ def get_grounding_output(model, image, caption, box_threshold, text_threshold, w
|
|
95 |
pred_phrases = []
|
96 |
scores = []
|
97 |
for logit, box in zip(logits_filt, boxes_filt):
|
98 |
-
pred_phrase = get_phrases_from_posmap(
|
|
|
99 |
if with_logits:
|
100 |
-
pred_phrases.append(
|
|
|
101 |
else:
|
102 |
pred_phrases.append(pred_phrase)
|
103 |
scores.append(logit.max().item())
|
104 |
|
105 |
return boxes_filt, torch.Tensor(scores), pred_phrases
|
106 |
|
|
|
107 |
def draw_mask(mask, draw, random_color=False):
|
108 |
if random_color:
|
109 |
-
color = (random.randint(0, 255), random.randint(
|
|
|
110 |
else:
|
111 |
color = (30, 144, 255, 153)
|
112 |
|
@@ -115,11 +123,13 @@ def draw_mask(mask, draw, random_color=False):
|
|
115 |
for coord in nonzero_coords:
|
116 |
draw.point(coord[::-1], fill=color)
|
117 |
|
|
|
118 |
def draw_box(box, draw, label):
|
119 |
# random color
|
120 |
color = tuple(np.random.randint(0, 255, size=3).tolist())
|
121 |
|
122 |
-
draw.rectangle(((box[0], box[1]), (box[2], box[3])),
|
|
|
123 |
|
124 |
if label:
|
125 |
font = ImageFont.load_default()
|
@@ -134,13 +144,12 @@ def draw_box(box, draw, label):
|
|
134 |
draw.text((box[0], box[1]), label)
|
135 |
|
136 |
|
137 |
-
|
138 |
config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
|
139 |
ckpt_repo_id = "ShilongLiu/GroundingDINO"
|
140 |
ckpt_filenmae = "groundingdino_swint_ogc.pth"
|
141 |
-
sam_checkpoint='sam_vit_h_4b8939.pth'
|
142 |
-
output_dir="outputs"
|
143 |
-
device=
|
144 |
|
145 |
|
146 |
blip_processor = None
|
@@ -149,6 +158,7 @@ groundingdino_model = None
|
|
149 |
sam_predictor = None
|
150 |
inpaint_pipeline = None
|
151 |
|
|
|
152 |
def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode):
|
153 |
|
154 |
global blip_processor, blip_model, groundingdino_model, sam_predictor, inpaint_pipeline
|
@@ -160,15 +170,18 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
|
|
160 |
transformed_image = transform_image(image_pil)
|
161 |
|
162 |
if groundingdino_model is None:
|
163 |
-
groundingdino_model = load_model(
|
|
|
164 |
|
165 |
if task_type == 'automatic':
|
166 |
# generate caption and tags
|
167 |
# use Tag2Text can generate better captions
|
168 |
# https://huggingface.co/spaces/xinyu1205/Tag2Text
|
169 |
# but there are some bugs...
|
170 |
-
blip_processor = blip_processor or BlipProcessor.from_pretrained(
|
171 |
-
|
|
|
|
|
172 |
text_prompt = generate_caption(blip_processor, blip_model, image_pil)
|
173 |
print(f"Caption: {text_prompt}")
|
174 |
|
@@ -188,7 +201,6 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
|
|
188 |
|
189 |
boxes_filt = boxes_filt.cpu()
|
190 |
|
191 |
-
|
192 |
if task_type == 'seg' or task_type == 'inpainting' or task_type == 'automatic':
|
193 |
if sam_predictor is None:
|
194 |
# initialize SAM
|
@@ -203,19 +215,21 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
|
|
203 |
if task_type == 'automatic':
|
204 |
# use NMS to handle overlapped boxes
|
205 |
print(f"Before NMS: {boxes_filt.shape[0]} boxes")
|
206 |
-
nms_idx = torchvision.ops.nms(
|
|
|
207 |
boxes_filt = boxes_filt[nms_idx]
|
208 |
pred_phrases = [pred_phrases[idx] for idx in nms_idx]
|
209 |
print(f"After NMS: {boxes_filt.shape[0]} boxes")
|
210 |
print(f"Revise caption with number: {text_prompt}")
|
211 |
|
212 |
-
transformed_boxes = sam_predictor.transform.apply_boxes_torch(
|
|
|
213 |
|
214 |
masks, _, _ = sam_predictor.predict_torch(
|
215 |
-
point_coords
|
216 |
-
point_labels
|
217 |
-
boxes
|
218 |
-
multimask_output
|
219 |
)
|
220 |
|
221 |
# masks: [1, 1, 512, 512]
|
@@ -227,7 +241,7 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
|
|
227 |
|
228 |
return [image_pil]
|
229 |
elif task_type == 'seg' or task_type == 'automatic':
|
230 |
-
|
231 |
mask_image = Image.new('RGBA', size, color=(0, 0, 0, 0))
|
232 |
|
233 |
mask_draw = ImageDraw.Draw(mask_image)
|
@@ -251,27 +265,32 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
|
|
251 |
if inpaint_mode == 'merge':
|
252 |
masks = torch.sum(masks, dim=0).unsqueeze(0)
|
253 |
masks = torch.where(masks > 0, True, False)
|
254 |
-
|
|
|
255 |
mask_pil = Image.fromarray(mask)
|
256 |
-
|
257 |
if inpaint_pipeline is None:
|
258 |
inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
259 |
-
|
260 |
)
|
261 |
inpaint_pipeline = inpaint_pipeline.to("cuda")
|
262 |
|
263 |
-
image = inpaint_pipeline(prompt=inpaint_prompt, image=image_pil.resize(
|
|
|
264 |
image = image.resize(size)
|
265 |
|
266 |
return [image, mask_pil]
|
267 |
else:
|
268 |
print("task_type:{} error!".format(task_type))
|
269 |
|
|
|
270 |
if __name__ == "__main__":
|
271 |
parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
|
272 |
-
parser.add_argument("--debug", action="store_true",
|
|
|
273 |
parser.add_argument("--share", action="store_true", help="share the app")
|
274 |
-
parser.add_argument('--no-gradio-queue', action="store_true",
|
|
|
275 |
args = parser.parse_args()
|
276 |
|
277 |
print(args)
|
@@ -283,10 +302,12 @@ if __name__ == "__main__":
|
|
283 |
with block:
|
284 |
with gr.Row():
|
285 |
with gr.Column():
|
286 |
-
input_image = gr.Image(
|
287 |
-
|
288 |
-
|
289 |
-
|
|
|
|
|
290 |
run_button = gr.Button(label="Run")
|
291 |
with gr.Accordion("Advanced options", open=False):
|
292 |
box_threshold = gr.Slider(
|
@@ -298,7 +319,8 @@ if __name__ == "__main__":
|
|
298 |
iou_threshold = gr.Slider(
|
299 |
label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.5, step=0.001
|
300 |
)
|
301 |
-
inpaint_mode = gr.Dropdown(
|
|
|
302 |
|
303 |
with gr.Column():
|
304 |
gallery = gr.Gallery(
|
@@ -306,7 +328,6 @@ if __name__ == "__main__":
|
|
306 |
).style(preview=True, grid=2, object_fit="scale-down")
|
307 |
|
308 |
run_button.click(fn=run_grounded_sam, inputs=[
|
309 |
-
|
310 |
-
|
311 |
|
312 |
-
block.launch(debug=args.debug, share=args.share, show_error=True)
|
|
|
7 |
os.system("pip install --upgrade diffusers[torch]")
|
8 |
os.system("pip install opencv-python pycocotools matplotlib onnxruntime onnx ipykernel")
|
9 |
os.system("wget https://github.com/IDEA-Research/Grounded-Segment-Anything/raw/main/assets/demo1.jpg")
|
10 |
+
os.system("wget https://huggingface.co/spaces/mrtlive/segment-anything-model/resolve/main/sam_vit_h_4b8939.pth")
|
11 |
sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))
|
12 |
+
sys.path.append(os.path.join(os.getcwd(), "segment_anything"))
|
13 |
warnings.filterwarnings("ignore")
|
14 |
|
15 |
import gradio as gr
|
|
|
40 |
|
41 |
def generate_caption(processor, blip_model, raw_image):
    """Generate an unconditional caption for ``raw_image``.

    Args:
        processor: callable that turns the image into model inputs (called
            with ``return_tensors="pt"``) and also provides ``decode``.
        blip_model: captioning model exposing ``generate`` plus ``device``
            and ``dtype`` attributes.
        raw_image: the image to caption, passed straight to ``processor``.

    Returns:
        The decoded caption for the first generated sequence, with special
        tokens skipped.
    """
    # Follow the model's own placement and precision instead of hard-coding
    # "cuda"/float16, so the inputs always match wherever the model was loaded.
    inputs = processor(raw_image, return_tensors="pt").to(
        blip_model.device, blip_model.dtype)
    out = blip_model.generate(**inputs)
    caption = processor.decode(out[0], skip_special_tokens=True)
    return caption
|
48 |
|
49 |
+
|
50 |
def transform_image(image_pil):
|
51 |
|
52 |
transform = T.Compose(
|
|
|
65 |
args.device = device
|
66 |
model = build_model(args)
|
67 |
checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
|
68 |
+
load_res = model.load_state_dict(
|
69 |
+
clean_state_dict(checkpoint["model"]), strict=False)
|
70 |
print(load_res)
|
71 |
_ = model.eval()
|
72 |
return model
|
|
|
99 |
pred_phrases = []
|
100 |
scores = []
|
101 |
for logit, box in zip(logits_filt, boxes_filt):
|
102 |
+
pred_phrase = get_phrases_from_posmap(
|
103 |
+
logit > text_threshold, tokenized, tokenlizer)
|
104 |
if with_logits:
|
105 |
+
pred_phrases.append(
|
106 |
+
pred_phrase + f"({str(logit.max().item())[:4]})")
|
107 |
else:
|
108 |
pred_phrases.append(pred_phrase)
|
109 |
scores.append(logit.max().item())
|
110 |
|
111 |
return boxes_filt, torch.Tensor(scores), pred_phrases
|
112 |
|
113 |
+
|
114 |
def draw_mask(mask, draw, random_color=False):
|
115 |
if random_color:
|
116 |
+
color = (random.randint(0, 255), random.randint(
|
117 |
+
0, 255), random.randint(0, 255), 153)
|
118 |
else:
|
119 |
color = (30, 144, 255, 153)
|
120 |
|
|
|
123 |
for coord in nonzero_coords:
|
124 |
draw.point(coord[::-1], fill=color)
|
125 |
|
126 |
+
|
127 |
def draw_box(box, draw, label):
|
128 |
# random color
|
129 |
color = tuple(np.random.randint(0, 255, size=3).tolist())
|
130 |
|
131 |
+
draw.rectangle(((box[0], box[1]), (box[2], box[3])),
|
132 |
+
outline=color, width=2)
|
133 |
|
134 |
if label:
|
135 |
font = ImageFont.load_default()
|
|
|
144 |
draw.text((box[0], box[1]), label)
|
145 |
|
146 |
|
|
|
147 |
config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
|
148 |
ckpt_repo_id = "ShilongLiu/GroundingDINO"
|
149 |
ckpt_filenmae = "groundingdino_swint_ogc.pth"
|
150 |
+
sam_checkpoint = 'sam_vit_h_4b8939.pth'
|
151 |
+
output_dir = "outputs"
|
152 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
153 |
|
154 |
|
155 |
blip_processor = None
|
|
|
158 |
sam_predictor = None
|
159 |
inpaint_pipeline = None
|
160 |
|
161 |
+
|
162 |
def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode):
|
163 |
|
164 |
global blip_processor, blip_model, groundingdino_model, sam_predictor, inpaint_pipeline
|
|
|
170 |
transformed_image = transform_image(image_pil)
|
171 |
|
172 |
if groundingdino_model is None:
|
173 |
+
groundingdino_model = load_model(
|
174 |
+
config_file, ckpt_filenmae, device=device)
|
175 |
|
176 |
if task_type == 'automatic':
|
177 |
# generate caption and tags
|
178 |
# use Tag2Text can generate better captions
|
179 |
# https://huggingface.co/spaces/xinyu1205/Tag2Text
|
180 |
# but there are some bugs...
|
181 |
+
blip_processor = blip_processor or BlipProcessor.from_pretrained(
|
182 |
+
"Salesforce/blip-image-captioning-large")
|
183 |
+
blip_model = blip_model or BlipForConditionalGeneration.from_pretrained(
|
184 |
+
"Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")
|
185 |
text_prompt = generate_caption(blip_processor, blip_model, image_pil)
|
186 |
print(f"Caption: {text_prompt}")
|
187 |
|
|
|
201 |
|
202 |
boxes_filt = boxes_filt.cpu()
|
203 |
|
|
|
204 |
if task_type == 'seg' or task_type == 'inpainting' or task_type == 'automatic':
|
205 |
if sam_predictor is None:
|
206 |
# initialize SAM
|
|
|
215 |
if task_type == 'automatic':
|
216 |
# use NMS to handle overlapped boxes
|
217 |
print(f"Before NMS: {boxes_filt.shape[0]} boxes")
|
218 |
+
nms_idx = torchvision.ops.nms(
|
219 |
+
boxes_filt, scores, iou_threshold).numpy().tolist()
|
220 |
boxes_filt = boxes_filt[nms_idx]
|
221 |
pred_phrases = [pred_phrases[idx] for idx in nms_idx]
|
222 |
print(f"After NMS: {boxes_filt.shape[0]} boxes")
|
223 |
print(f"Revise caption with number: {text_prompt}")
|
224 |
|
225 |
+
transformed_boxes = sam_predictor.transform.apply_boxes_torch(
|
226 |
+
boxes_filt, image.shape[:2]).to(device)
|
227 |
|
228 |
masks, _, _ = sam_predictor.predict_torch(
|
229 |
+
point_coords=None,
|
230 |
+
point_labels=None,
|
231 |
+
boxes=transformed_boxes,
|
232 |
+
multimask_output=False,
|
233 |
)
|
234 |
|
235 |
# masks: [1, 1, 512, 512]
|
|
|
241 |
|
242 |
return [image_pil]
|
243 |
elif task_type == 'seg' or task_type == 'automatic':
|
244 |
+
|
245 |
mask_image = Image.new('RGBA', size, color=(0, 0, 0, 0))
|
246 |
|
247 |
mask_draw = ImageDraw.Draw(mask_image)
|
|
|
265 |
if inpaint_mode == 'merge':
|
266 |
masks = torch.sum(masks, dim=0).unsqueeze(0)
|
267 |
masks = torch.where(masks > 0, True, False)
|
268 |
+
# simply choose the first mask; this will be refined in a future release
|
269 |
+
mask = masks[0][0].cpu().numpy()
|
270 |
mask_pil = Image.fromarray(mask)
|
271 |
+
|
272 |
if inpaint_pipeline is None:
|
273 |
inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
274 |
+
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
|
275 |
)
|
276 |
inpaint_pipeline = inpaint_pipeline.to("cuda")
|
277 |
|
278 |
+
image = inpaint_pipeline(prompt=inpaint_prompt, image=image_pil.resize(
|
279 |
+
(512, 512)), mask_image=mask_pil.resize((512, 512))).images[0]
|
280 |
image = image.resize(size)
|
281 |
|
282 |
return [image, mask_pil]
|
283 |
else:
|
284 |
print("task_type:{} error!".format(task_type))
|
285 |
|
286 |
+
|
287 |
if __name__ == "__main__":
|
288 |
parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
|
289 |
+
parser.add_argument("--debug", action="store_true",
|
290 |
+
help="using debug mode")
|
291 |
parser.add_argument("--share", action="store_true", help="share the app")
|
292 |
+
# NOTE(review): the previous help text ("path to the SAM checkpoint") did not
# describe a store_true flag; wording below is inferred from the flag name —
# confirm against where args.no_gradio_queue is consumed.
parser.add_argument('--no-gradio-queue', action="store_true",
                    help='do not use the Gradio request queue')
|
294 |
args = parser.parse_args()
|
295 |
|
296 |
print(args)
|
|
|
302 |
with block:
|
303 |
with gr.Row():
|
304 |
with gr.Column():
|
305 |
+
input_image = gr.Image(
|
306 |
+
source='upload', type="pil", value="demo1.jpg")
|
307 |
+
task_type = gr.Dropdown(
|
308 |
+
["det", "seg", "inpainting", "automatic"], value="automatic", label="task_type")
|
309 |
+
# The hint goes in `placeholder`; passing `label=` twice is a SyntaxError
# ("keyword argument repeated") that prevents the app from starting.
text_prompt = gr.Textbox(label="Text Prompt", placeholder="categories (separated by .)")
|
310 |
+
# The hint goes in `placeholder`; passing `label=` twice is a SyntaxError
# ("keyword argument repeated") that prevents the app from starting.
inpaint_prompt = gr.Textbox(label="Inpaint Prompt", placeholder="The new image should be...")
|
311 |
run_button = gr.Button(label="Run")
|
312 |
with gr.Accordion("Advanced options", open=False):
|
313 |
box_threshold = gr.Slider(
|
|
|
319 |
iou_threshold = gr.Slider(
|
320 |
label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.5, step=0.001
|
321 |
)
|
322 |
+
inpaint_mode = gr.Dropdown(
|
323 |
+
["merge", "first"], value="merge", label="inpaint_mode")
|
324 |
|
325 |
with gr.Column():
|
326 |
gallery = gr.Gallery(
|
|
|
328 |
).style(preview=True, grid=2, object_fit="scale-down")
|
329 |
|
330 |
run_button.click(fn=run_grounded_sam, inputs=[
|
331 |
+
input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode], outputs=gallery)
|
|
|
332 |
|
333 |
+
block.launch(debug=args.debug, share=args.share, show_error=True)
|