Yw22 commited on
Commit
15d0e68
·
1 Parent(s): ff48bea

fix some bugs

Browse files
Files changed (3) hide show
  1. app/run_app.sh +5 -0
  2. app/src/brushedit_app.py +53 -42
  3. app/src/vlm_pipeline.py +24 -18
app/run_app.sh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ export PYTHONPATH=.:$PYTHONPATH
2
+
3
+ export CUDA_VISIBLE_DEVICES=0
4
+
5
+ python app/src/brushedit_app.py
app/src/brushedit_app.py CHANGED
@@ -337,7 +337,7 @@ vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_M
337
  if vlm_processor != "" and vlm_model != "":
338
  vlm_model.to(device)
339
  else:
340
- gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.")
341
 
342
 
343
  ## init base model
@@ -504,7 +504,7 @@ def random_mask_func(mask, dilation_type='square', dilation_size=20):
504
  dilated_mask = np.zeros_like(binary_mask, dtype=bool)
505
  dilated_mask[ellipse_mask] = True
506
  else:
507
- raise ValueError("dilation_type must be 'square' or 'ellipse'")
508
 
509
  # use binary dilation
510
  dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
@@ -637,7 +637,8 @@ def process(input_image,
637
  image_pil = input_image["background"].convert("RGB")
638
  original_image = np.array(image_pil)
639
  if prompt is None or prompt == "":
640
- raise gr.Error("Please input your instructions, e.g., remove the xxx")
 
641
 
642
  alpha_mask = input_image["layers"][0].split()[3]
643
  input_mask = np.asarray(alpha_mask)
@@ -687,17 +688,23 @@ def process(input_image,
687
  original_mask = input_mask
688
 
689
 
690
-
691
  if category is not None:
692
- pass
 
 
693
  else:
694
- category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
695
-
 
 
696
 
 
697
  if original_mask is not None:
698
  original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
699
  else:
700
- object_wait_for_edit = vlm_response_object_wait_for_edit(
 
701
  vlm_processor,
702
  vlm_model,
703
  original_image,
@@ -705,30 +712,37 @@ def process(input_image,
705
  prompt,
706
  device)
707
 
708
- original_mask = vlm_response_mask(vlm_processor,
709
- vlm_model,
710
- category,
711
- original_image,
712
- prompt,
713
- object_wait_for_edit,
714
- sam,
715
- sam_predictor,
716
- sam_automask_generator,
717
- groundingdino_model,
718
- device)
 
 
 
719
  if original_mask.ndim == 2:
720
  original_mask = original_mask[:,:,None]
721
 
722
 
723
- if len(target_prompt) <= 1:
724
- prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
 
 
 
 
725
  vlm_processor,
726
  vlm_model,
727
  original_image,
728
  prompt,
729
  device)
730
- else:
731
- prompt_after_apply_instruction = target_prompt
732
 
733
  generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
734
 
@@ -758,7 +772,8 @@ def process(input_image,
758
  # image[3].save(f"outputs/image_edit_{uuid}_3.png")
759
  # mask_image.save(f"outputs/mask_{uuid}.png")
760
  # masked_image.save(f"outputs/masked_image_{uuid}.png")
761
- return image, [mask_image], [masked_image], prompt, '', prompt_after_apply_instruction, False
 
762
 
763
 
764
  def generate_target_prompt(input_image,
@@ -774,7 +789,7 @@ def generate_target_prompt(input_image,
774
  original_image,
775
  prompt,
776
  device)
777
- return prompt_after_apply_instruction, prompt_after_apply_instruction
778
 
779
 
780
  def process_mask(input_image,
@@ -1415,7 +1430,7 @@ def init_img(base,
1415
  original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
1416
  return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "", "Custom resolution", False, False, example_change_times
1417
  else:
1418
- return base, original_image, None, "", None, None, None, "", "", "", aspect_ratio, True, False, 0
1419
 
1420
 
1421
  def reset_func(input_image,
@@ -1423,7 +1438,7 @@ def reset_func(input_image,
1423
  original_mask,
1424
  prompt,
1425
  target_prompt,
1426
- target_prompt_output):
1427
  input_image = None
1428
  original_image = None
1429
  original_mask = None
@@ -1432,10 +1447,9 @@ def reset_func(input_image,
1432
  masked_gallery = []
1433
  result_gallery = []
1434
  target_prompt = ''
1435
- target_prompt_output = ''
1436
  if torch.cuda.is_available():
1437
  torch.cuda.empty_cache()
1438
- return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, target_prompt_output, True, False
1439
 
1440
 
1441
  def update_example(example_type,
@@ -1458,7 +1472,8 @@ def update_example(example_type,
1458
  original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
1459
  aspect_ratio = "Custom resolution"
1460
  example_change_times += 1
1461
- return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", "", False, example_change_times
 
1462
 
1463
  block = gr.Blocks(
1464
  theme=gr.themes.Soft(
@@ -1498,6 +1513,8 @@ with block as demo:
1498
  sources=["upload"],
1499
  )
1500
 
 
 
1501
 
1502
  vlm_model_dropdown = gr.Dropdown(label="VLM model", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
1503
  with gr.Group():
@@ -1510,12 +1527,6 @@ with block as demo:
1510
  aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
1511
  resize_default = gr.Checkbox(label="Short edge resize to 640px", value=True)
1512
 
1513
-
1514
- prompt = gr.Textbox(label="⌨️ Instruction", placeholder="Please input your instruction.", value="",lines=1)
1515
-
1516
- run_button = gr.Button("💫 Run")
1517
-
1518
-
1519
  with gr.Row():
1520
  mask_button = gr.Button("Generate Mask")
1521
  random_mask_button = gr.Button("Square/Circle Mask ")
@@ -1603,7 +1614,7 @@ with block as demo:
1603
  with gr.Tab(elem_classes="feedback", label="Output"):
1604
  result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", preview=True, height=400)
1605
 
1606
- target_prompt_output = gr.Text(label="Output Target Prompt", value="", lines=1, interactive=False)
1607
 
1608
  reset_button = gr.Button("Reset")
1609
 
@@ -1634,9 +1645,9 @@ with block as demo:
1634
  input_image.upload(
1635
  init_img,
1636
  [input_image, init_type, prompt, aspect_ratio, example_change_times],
1637
- [input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, target_prompt_output, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times]
1638
  )
1639
- example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, target_prompt_output, invert_mask_state, example_change_times])
1640
 
1641
  ## vlm and base model dropdown
1642
  vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
@@ -1666,7 +1677,7 @@ with block as demo:
1666
  invert_mask_state]
1667
 
1668
  ## run brushedit
1669
- run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, target_prompt_output, invert_mask_state])
1670
 
1671
  ## mask func
1672
  mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
@@ -1681,10 +1692,10 @@ with block as demo:
1681
  move_down_button.click(fn=move_mask_down, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1682
 
1683
  ## prompt func
1684
- generate_target_prompt_button.click(fn=generate_target_prompt, inputs=[input_image, original_image, prompt], outputs=[target_prompt, target_prompt_output])
1685
 
1686
  ## reset func
1687
- reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt, target_prompt_output], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, target_prompt_output, resize_default, invert_mask_state])
1688
 
1689
 
1690
  demo.launch()
 
337
  if vlm_processor != "" and vlm_model != "":
338
  vlm_model.to(device)
339
  else:
340
+ raise gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.")
341
 
342
 
343
  ## init base model
 
504
  dilated_mask = np.zeros_like(binary_mask, dtype=bool)
505
  dilated_mask[ellipse_mask] = True
506
  else:
507
+ ValueError("dilation_type must be 'square' or 'ellipse'")
508
 
509
  # use binary dilation
510
  dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
 
637
  image_pil = input_image["background"].convert("RGB")
638
  original_image = np.array(image_pil)
639
  if prompt is None or prompt == "":
640
+ if target_prompt is None or target_prompt == "":
641
+ raise gr.Error("Please input your instructions, e.g., remove the xxx")
642
 
643
  alpha_mask = input_image["layers"][0].split()[3]
644
  input_mask = np.asarray(alpha_mask)
 
688
  original_mask = input_mask
689
 
690
 
691
+ ## inpainting directly if target_prompt is not None
692
  if category is not None:
693
+ pass
694
+ elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None:
695
+ pass
696
  else:
697
+ try:
698
+ category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
699
+ except Exception as e:
700
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
701
 
702
+
703
  if original_mask is not None:
704
  original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
705
  else:
706
+ try:
707
+ object_wait_for_edit = vlm_response_object_wait_for_edit(
708
  vlm_processor,
709
  vlm_model,
710
  original_image,
 
712
  prompt,
713
  device)
714
 
715
+ original_mask = vlm_response_mask(vlm_processor,
716
+ vlm_model,
717
+ category,
718
+ original_image,
719
+ prompt,
720
+ object_wait_for_edit,
721
+ sam,
722
+ sam_predictor,
723
+ sam_automask_generator,
724
+ groundingdino_model,
725
+ device)
726
+ except Exception as e:
727
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
728
+
729
  if original_mask.ndim == 2:
730
  original_mask = original_mask[:,:,None]
731
 
732
 
733
+ if target_prompt is not None and len(target_prompt) >= 1:
734
+ prompt_after_apply_instruction = target_prompt
735
+
736
+ else:
737
+ try:
738
+ prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
739
  vlm_processor,
740
  vlm_model,
741
  original_image,
742
  prompt,
743
  device)
744
+ except Exception as e:
745
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
746
 
747
  generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
748
 
 
772
  # image[3].save(f"outputs/image_edit_{uuid}_3.png")
773
  # mask_image.save(f"outputs/mask_{uuid}.png")
774
  # masked_image.save(f"outputs/masked_image_{uuid}.png")
775
+ # gr.Info(f"Target Prompt: {prompt_after_apply_instruction}", duration=16)
776
+ return image, [mask_image], [masked_image], prompt, '', False
777
 
778
 
779
  def generate_target_prompt(input_image,
 
789
  original_image,
790
  prompt,
791
  device)
792
+ return prompt_after_apply_instruction
793
 
794
 
795
  def process_mask(input_image,
 
1430
  original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
1431
  return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "", "Custom resolution", False, False, example_change_times
1432
  else:
1433
+ return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0
1434
 
1435
 
1436
  def reset_func(input_image,
 
1438
  original_mask,
1439
  prompt,
1440
  target_prompt,
1441
+ ):
1442
  input_image = None
1443
  original_image = None
1444
  original_mask = None
 
1447
  masked_gallery = []
1448
  result_gallery = []
1449
  target_prompt = ''
 
1450
  if torch.cuda.is_available():
1451
  torch.cuda.empty_cache()
1452
+ return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, True, False
1453
 
1454
 
1455
  def update_example(example_type,
 
1472
  original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
1473
  aspect_ratio = "Custom resolution"
1474
  example_change_times += 1
1475
+ return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", False, example_change_times
1476
+
1477
 
1478
  block = gr.Blocks(
1479
  theme=gr.themes.Soft(
 
1513
  sources=["upload"],
1514
  )
1515
 
1516
+ prompt = gr.Textbox(label="⌨️ Instruction", placeholder="Please input your instruction.", value="",lines=1)
1517
+ run_button = gr.Button("💫 Run")
1518
 
1519
  vlm_model_dropdown = gr.Dropdown(label="VLM model", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
1520
  with gr.Group():
 
1527
  aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
1528
  resize_default = gr.Checkbox(label="Short edge resize to 640px", value=True)
1529
 
 
 
 
 
 
 
1530
  with gr.Row():
1531
  mask_button = gr.Button("Generate Mask")
1532
  random_mask_button = gr.Button("Square/Circle Mask ")
 
1614
  with gr.Tab(elem_classes="feedback", label="Output"):
1615
  result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", preview=True, height=400)
1616
 
1617
+ # target_prompt_output = gr.Text(label="Output Target Prompt", value="", lines=1, interactive=False)
1618
 
1619
  reset_button = gr.Button("Reset")
1620
 
 
1645
  input_image.upload(
1646
  init_img,
1647
  [input_image, init_type, prompt, aspect_ratio, example_change_times],
1648
+ [input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times]
1649
  )
1650
+ example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times])
1651
 
1652
  ## vlm and base model dropdown
1653
  vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
 
1677
  invert_mask_state]
1678
 
1679
  ## run brushedit
1680
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state])
1681
 
1682
  ## mask func
1683
  mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
 
1692
  move_down_button.click(fn=move_mask_down, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1693
 
1694
  ## prompt func
1695
+ generate_target_prompt_button.click(fn=generate_target_prompt, inputs=[input_image, original_image, prompt], outputs=[target_prompt])
1696
 
1697
  ## reset func
1698
+ reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, resize_default, invert_mask_state])
1699
 
1700
 
1701
  demo.launch()
app/src/vlm_pipeline.py CHANGED
@@ -98,10 +98,12 @@ def vlm_response_editing_type(vlm_processor,
98
  messages = create_editing_category_messages_qwen2(editing_prompt)
99
  response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device=device)
100
 
101
- for category_name in ["Addition","Remove","Local","Global","Background"]:
102
- if category_name.lower() in response_str.lower():
103
- return category_name
104
- raise gr.Error("Please input correct commands, including add, delete, and modify commands. If it still does not work, please switch to a more powerful VLM.")
 
 
105
 
106
 
107
  ### response object to be edited
@@ -206,17 +208,21 @@ def vlm_response_prompt_after_apply_instruction(vlm_processor,
206
  image,
207
  editing_prompt,
208
  device):
209
- if isinstance(vlm_model, OpenAI):
210
- base64_image = encode_image(image)
211
- messages = create_apply_editing_messages_gpt4o(editing_prompt, base64_image)
212
- response_str = run_gpt4o_vl_inference(vlm_model, messages)
213
- elif isinstance(vlm_model, LlavaNextForConditionalGeneration):
214
- messages = create_apply_editing_messages_llava(editing_prompt)
215
- response_str = run_llava_next_inference(vlm_processor, vlm_model, messages, image, device)
216
- elif isinstance(vlm_model, Qwen2VLForConditionalGeneration):
217
- base64_image = encode_image(image)
218
- messages = create_apply_editing_messages_qwen2(editing_prompt, base64_image)
219
- response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device)
220
- else:
221
- raise gr.Error("Please select the correct VLM model!")
222
- return response_str
 
 
 
 
 
98
  messages = create_editing_category_messages_qwen2(editing_prompt)
99
  response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device=device)
100
 
101
+ try:
102
+ for category_name in ["Addition","Remove","Local","Global","Background"]:
103
+ if category_name.lower() in response_str.lower():
104
+ return category_name
105
+ except Exception as e:
106
+ raise gr.Error("Please input OpenAI API Key. Or please input correct commands, including add, delete, and modify commands. If it still does not work, please switch to a more powerful VLM.")
107
 
108
 
109
  ### response object to be edited
 
208
  image,
209
  editing_prompt,
210
  device):
211
+
212
+ try:
213
+ if isinstance(vlm_model, OpenAI):
214
+ base64_image = encode_image(image)
215
+ messages = create_apply_editing_messages_gpt4o(editing_prompt, base64_image)
216
+ response_str = run_gpt4o_vl_inference(vlm_model, messages)
217
+ elif isinstance(vlm_model, LlavaNextForConditionalGeneration):
218
+ messages = create_apply_editing_messages_llava(editing_prompt)
219
+ response_str = run_llava_next_inference(vlm_processor, vlm_model, messages, image, device)
220
+ elif isinstance(vlm_model, Qwen2VLForConditionalGeneration):
221
+ base64_image = encode_image(image)
222
+ messages = create_apply_editing_messages_qwen2(editing_prompt, base64_image)
223
+ response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device)
224
+ else:
225
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
226
+ except Exception as e:
227
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
228
+ return response_str