ShilongLiu commited on
Commit
4dc6d69
·
1 Parent(s): 54125c1

update add.py

Browse files
Files changed (1) hide show
  1. app.py +57 -36
app.py CHANGED
@@ -7,8 +7,9 @@ os.system("python -m pip install -e GroundingDINO")
7
  os.system("pip install --upgrade diffusers[torch]")
8
  os.system("pip install opencv-python pycocotools matplotlib onnxruntime onnx ipykernel")
9
  os.system("wget https://github.com/IDEA-Research/Grounded-Segment-Anything/raw/main/assets/demo1.jpg")
10
- os.system("wget https://dl.fbaipublicfiles.com/segment-anything/sam_vit_h_4b8939.pth")
11
  sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))
 
12
  warnings.filterwarnings("ignore")
13
 
14
  import gradio as gr
@@ -39,11 +40,13 @@ from transformers import BlipProcessor, BlipForConditionalGeneration
39
 
40
  def generate_caption(processor, blip_model, raw_image):
41
  # unconditional image captioning
42
- inputs = processor(raw_image, return_tensors="pt").to("cuda", torch.float16)
 
43
  out = blip_model.generate(**inputs)
44
  caption = processor.decode(out[0], skip_special_tokens=True)
45
  return caption
46
 
 
47
  def transform_image(image_pil):
48
 
49
  transform = T.Compose(
@@ -62,7 +65,8 @@ def load_model(model_config_path, model_checkpoint_path, device):
62
  args.device = device
63
  model = build_model(args)
64
  checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
65
- load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
 
66
  print(load_res)
67
  _ = model.eval()
68
  return model
@@ -95,18 +99,22 @@ def get_grounding_output(model, image, caption, box_threshold, text_threshold, w
95
  pred_phrases = []
96
  scores = []
97
  for logit, box in zip(logits_filt, boxes_filt):
98
- pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
 
99
  if with_logits:
100
- pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
 
101
  else:
102
  pred_phrases.append(pred_phrase)
103
  scores.append(logit.max().item())
104
 
105
  return boxes_filt, torch.Tensor(scores), pred_phrases
106
 
 
107
  def draw_mask(mask, draw, random_color=False):
108
  if random_color:
109
- color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 153)
 
110
  else:
111
  color = (30, 144, 255, 153)
112
 
@@ -115,11 +123,13 @@ def draw_mask(mask, draw, random_color=False):
115
  for coord in nonzero_coords:
116
  draw.point(coord[::-1], fill=color)
117
 
 
118
  def draw_box(box, draw, label):
119
  # random color
120
  color = tuple(np.random.randint(0, 255, size=3).tolist())
121
 
122
- draw.rectangle(((box[0], box[1]), (box[2], box[3])), outline=color, width=2)
 
123
 
124
  if label:
125
  font = ImageFont.load_default()
@@ -134,13 +144,12 @@ def draw_box(box, draw, label):
134
  draw.text((box[0], box[1]), label)
135
 
136
 
137
-
138
  config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
139
  ckpt_repo_id = "ShilongLiu/GroundingDINO"
140
  ckpt_filenmae = "groundingdino_swint_ogc.pth"
141
- sam_checkpoint='sam_vit_h_4b8939.pth'
142
- output_dir="outputs"
143
- device="cuda"
144
 
145
 
146
  blip_processor = None
@@ -149,6 +158,7 @@ groundingdino_model = None
149
  sam_predictor = None
150
  inpaint_pipeline = None
151
 
 
152
  def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode):
153
 
154
  global blip_processor, blip_model, groundingdino_model, sam_predictor, inpaint_pipeline
@@ -160,15 +170,18 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
160
  transformed_image = transform_image(image_pil)
161
 
162
  if groundingdino_model is None:
163
- groundingdino_model = load_model(config_file, ckpt_filenmae, device=device)
 
164
 
165
  if task_type == 'automatic':
166
  # generate caption and tags
167
  # use Tag2Text can generate better captions
168
  # https://huggingface.co/spaces/xinyu1205/Tag2Text
169
  # but there are some bugs...
170
- blip_processor = blip_processor or BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
171
- blip_model = blip_model or BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")
 
 
172
  text_prompt = generate_caption(blip_processor, blip_model, image_pil)
173
  print(f"Caption: {text_prompt}")
174
 
@@ -188,7 +201,6 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
188
 
189
  boxes_filt = boxes_filt.cpu()
190
 
191
-
192
  if task_type == 'seg' or task_type == 'inpainting' or task_type == 'automatic':
193
  if sam_predictor is None:
194
  # initialize SAM
@@ -203,19 +215,21 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
203
  if task_type == 'automatic':
204
  # use NMS to handle overlapped boxes
205
  print(f"Before NMS: {boxes_filt.shape[0]} boxes")
206
- nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
 
207
  boxes_filt = boxes_filt[nms_idx]
208
  pred_phrases = [pred_phrases[idx] for idx in nms_idx]
209
  print(f"After NMS: {boxes_filt.shape[0]} boxes")
210
  print(f"Revise caption with number: {text_prompt}")
211
 
212
- transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
 
213
 
214
  masks, _, _ = sam_predictor.predict_torch(
215
- point_coords = None,
216
- point_labels = None,
217
- boxes = transformed_boxes,
218
- multimask_output = False,
219
  )
220
 
221
  # masks: [1, 1, 512, 512]
@@ -227,7 +241,7 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
227
 
228
  return [image_pil]
229
  elif task_type == 'seg' or task_type == 'automatic':
230
-
231
  mask_image = Image.new('RGBA', size, color=(0, 0, 0, 0))
232
 
233
  mask_draw = ImageDraw.Draw(mask_image)
@@ -251,27 +265,32 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
251
  if inpaint_mode == 'merge':
252
  masks = torch.sum(masks, dim=0).unsqueeze(0)
253
  masks = torch.where(masks > 0, True, False)
254
- mask = masks[0][0].cpu().numpy() # simply choose the first mask, which will be refine in the future release
 
255
  mask_pil = Image.fromarray(mask)
256
-
257
  if inpaint_pipeline is None:
258
  inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
259
- "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
260
  )
261
  inpaint_pipeline = inpaint_pipeline.to("cuda")
262
 
263
- image = inpaint_pipeline(prompt=inpaint_prompt, image=image_pil.resize((512, 512)), mask_image=mask_pil.resize((512, 512))).images[0]
 
264
  image = image.resize(size)
265
 
266
  return [image, mask_pil]
267
  else:
268
  print("task_type:{} error!".format(task_type))
269
 
 
270
  if __name__ == "__main__":
271
  parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
272
- parser.add_argument("--debug", action="store_true", help="using debug mode")
 
273
  parser.add_argument("--share", action="store_true", help="share the app")
274
- parser.add_argument('--no-gradio-queue', action="store_true", help='path to the SAM checkpoint')
 
275
  args = parser.parse_args()
276
 
277
  print(args)
@@ -283,10 +302,12 @@ if __name__ == "__main__":
283
  with block:
284
  with gr.Row():
285
  with gr.Column():
286
- input_image = gr.Image(source='upload', type="pil", value="demo1.jpg")
287
- task_type = gr.Dropdown(["det", "seg", "inpainting", "automatic"], value="automatic", label="task_type")
288
- text_prompt = gr.Textbox(label="Text Prompt")
289
- inpaint_prompt = gr.Textbox(label="Inpaint Prompt")
 
 
290
  run_button = gr.Button(label="Run")
291
  with gr.Accordion("Advanced options", open=False):
292
  box_threshold = gr.Slider(
@@ -298,7 +319,8 @@ if __name__ == "__main__":
298
  iou_threshold = gr.Slider(
299
  label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.5, step=0.001
300
  )
301
- inpaint_mode = gr.Dropdown(["merge", "first"], value="merge", label="inpaint_mode")
 
302
 
303
  with gr.Column():
304
  gallery = gr.Gallery(
@@ -306,7 +328,6 @@ if __name__ == "__main__":
306
  ).style(preview=True, grid=2, object_fit="scale-down")
307
 
308
  run_button.click(fn=run_grounded_sam, inputs=[
309
- input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode], outputs=gallery)
310
-
311
 
312
- block.launch(debug=args.debug, share=args.share, show_error=True)
 
7
  os.system("pip install --upgrade diffusers[torch]")
8
  os.system("pip install opencv-python pycocotools matplotlib onnxruntime onnx ipykernel")
9
  os.system("wget https://github.com/IDEA-Research/Grounded-Segment-Anything/raw/main/assets/demo1.jpg")
10
+ os.system("wget https://huggingface.co/spaces/mrtlive/segment-anything-model/resolve/main/sam_vit_h_4b8939.pth")
11
  sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))
12
+ sys.path.append(os.path.join(os.getcwd(), "segment_anything"))
13
  warnings.filterwarnings("ignore")
14
 
15
  import gradio as gr
 
40
 
41
  def generate_caption(processor, blip_model, raw_image):
42
  # unconditional image captioning
43
+ inputs = processor(raw_image, return_tensors="pt").to(
44
+ "cuda", torch.float16)
45
  out = blip_model.generate(**inputs)
46
  caption = processor.decode(out[0], skip_special_tokens=True)
47
  return caption
48
 
49
+
50
  def transform_image(image_pil):
51
 
52
  transform = T.Compose(
 
65
  args.device = device
66
  model = build_model(args)
67
  checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
68
+ load_res = model.load_state_dict(
69
+ clean_state_dict(checkpoint["model"]), strict=False)
70
  print(load_res)
71
  _ = model.eval()
72
  return model
 
99
  pred_phrases = []
100
  scores = []
101
  for logit, box in zip(logits_filt, boxes_filt):
102
+ pred_phrase = get_phrases_from_posmap(
103
+ logit > text_threshold, tokenized, tokenlizer)
104
  if with_logits:
105
+ pred_phrases.append(
106
+ pred_phrase + f"({str(logit.max().item())[:4]})")
107
  else:
108
  pred_phrases.append(pred_phrase)
109
  scores.append(logit.max().item())
110
 
111
  return boxes_filt, torch.Tensor(scores), pred_phrases
112
 
113
+
114
  def draw_mask(mask, draw, random_color=False):
115
  if random_color:
116
+ color = (random.randint(0, 255), random.randint(
117
+ 0, 255), random.randint(0, 255), 153)
118
  else:
119
  color = (30, 144, 255, 153)
120
 
 
123
  for coord in nonzero_coords:
124
  draw.point(coord[::-1], fill=color)
125
 
126
+
127
  def draw_box(box, draw, label):
128
  # random color
129
  color = tuple(np.random.randint(0, 255, size=3).tolist())
130
 
131
+ draw.rectangle(((box[0], box[1]), (box[2], box[3])),
132
+ outline=color, width=2)
133
 
134
  if label:
135
  font = ImageFont.load_default()
 
144
  draw.text((box[0], box[1]), label)
145
 
146
 
 
147
  config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
148
  ckpt_repo_id = "ShilongLiu/GroundingDINO"
149
  ckpt_filenmae = "groundingdino_swint_ogc.pth"
150
+ sam_checkpoint = 'sam_vit_h_4b8939.pth'
151
+ output_dir = "outputs"
152
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
153
 
154
 
155
  blip_processor = None
 
158
  sam_predictor = None
159
  inpaint_pipeline = None
160
 
161
+
162
  def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode):
163
 
164
  global blip_processor, blip_model, groundingdino_model, sam_predictor, inpaint_pipeline
 
170
  transformed_image = transform_image(image_pil)
171
 
172
  if groundingdino_model is None:
173
+ groundingdino_model = load_model(
174
+ config_file, ckpt_filenmae, device=device)
175
 
176
  if task_type == 'automatic':
177
  # generate caption and tags
178
  # use Tag2Text can generate better captions
179
  # https://huggingface.co/spaces/xinyu1205/Tag2Text
180
  # but there are some bugs...
181
+ blip_processor = blip_processor or BlipProcessor.from_pretrained(
182
+ "Salesforce/blip-image-captioning-large")
183
+ blip_model = blip_model or BlipForConditionalGeneration.from_pretrained(
184
+ "Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")
185
  text_prompt = generate_caption(blip_processor, blip_model, image_pil)
186
  print(f"Caption: {text_prompt}")
187
 
 
201
 
202
  boxes_filt = boxes_filt.cpu()
203
 
 
204
  if task_type == 'seg' or task_type == 'inpainting' or task_type == 'automatic':
205
  if sam_predictor is None:
206
  # initialize SAM
 
215
  if task_type == 'automatic':
216
  # use NMS to handle overlapped boxes
217
  print(f"Before NMS: {boxes_filt.shape[0]} boxes")
218
+ nms_idx = torchvision.ops.nms(
219
+ boxes_filt, scores, iou_threshold).numpy().tolist()
220
  boxes_filt = boxes_filt[nms_idx]
221
  pred_phrases = [pred_phrases[idx] for idx in nms_idx]
222
  print(f"After NMS: {boxes_filt.shape[0]} boxes")
223
  print(f"Revise caption with number: {text_prompt}")
224
 
225
+ transformed_boxes = sam_predictor.transform.apply_boxes_torch(
226
+ boxes_filt, image.shape[:2]).to(device)
227
 
228
  masks, _, _ = sam_predictor.predict_torch(
229
+ point_coords=None,
230
+ point_labels=None,
231
+ boxes=transformed_boxes,
232
+ multimask_output=False,
233
  )
234
 
235
  # masks: [1, 1, 512, 512]
 
241
 
242
  return [image_pil]
243
  elif task_type == 'seg' or task_type == 'automatic':
244
+
245
  mask_image = Image.new('RGBA', size, color=(0, 0, 0, 0))
246
 
247
  mask_draw = ImageDraw.Draw(mask_image)
 
265
  if inpaint_mode == 'merge':
266
  masks = torch.sum(masks, dim=0).unsqueeze(0)
267
  masks = torch.where(masks > 0, True, False)
268
+ # simply choose the first mask, which will be refine in the future release
269
+ mask = masks[0][0].cpu().numpy()
270
  mask_pil = Image.fromarray(mask)
271
+
272
  if inpaint_pipeline is None:
273
  inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
274
+ "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
275
  )
276
  inpaint_pipeline = inpaint_pipeline.to("cuda")
277
 
278
+ image = inpaint_pipeline(prompt=inpaint_prompt, image=image_pil.resize(
279
+ (512, 512)), mask_image=mask_pil.resize((512, 512))).images[0]
280
  image = image.resize(size)
281
 
282
  return [image, mask_pil]
283
  else:
284
  print("task_type:{} error!".format(task_type))
285
 
286
+
287
  if __name__ == "__main__":
288
  parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
289
+ parser.add_argument("--debug", action="store_true",
290
+ help="using debug mode")
291
  parser.add_argument("--share", action="store_true", help="share the app")
292
+ parser.add_argument('--no-gradio-queue', action="store_true",
293
+ help='path to the SAM checkpoint')
294
  args = parser.parse_args()
295
 
296
  print(args)
 
302
  with block:
303
  with gr.Row():
304
  with gr.Column():
305
+ input_image = gr.Image(
306
+ source='upload', type="pil", value="demo1.jpg")
307
+ task_type = gr.Dropdown(
308
+ ["det", "seg", "inpainting", "automatic"], value="automatic", label="task_type")
309
+ text_prompt = gr.Textbox(label="Text Prompt", label="categories (separated by .)")
310
+ inpaint_prompt = gr.Textbox(label="Inpaint Prompt", label="The new image should be...")
311
  run_button = gr.Button(label="Run")
312
  with gr.Accordion("Advanced options", open=False):
313
  box_threshold = gr.Slider(
 
319
  iou_threshold = gr.Slider(
320
  label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.5, step=0.001
321
  )
322
+ inpaint_mode = gr.Dropdown(
323
+ ["merge", "first"], value="merge", label="inpaint_mode")
324
 
325
  with gr.Column():
326
  gallery = gr.Gallery(
 
328
  ).style(preview=True, grid=2, object_fit="scale-down")
329
 
330
  run_button.click(fn=run_grounded_sam, inputs=[
331
+ input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode], outputs=gallery)
 
332
 
333
+ block.launch(debug=args.debug, share=args.share, show_error=True)