Niki Zhang committed on
Commit
ecd56c8
·
verified ·
1 Parent(s): 711583c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -32
app.py CHANGED
@@ -480,28 +480,28 @@ tools_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.cha
480
  shared_chatbot_tools = build_chatbot_tools(tools_dict)
481
 
482
 
483
- class ImageSketcher(gr.Image):
484
- """
485
- Fix the bug of gradio.Image that cannot upload with tool == 'sketch'.
486
- """
487
-
488
- is_template = True # Magic to make this work with gradio.Block, don't remove unless you know what you're doing.
489
-
490
- def __init__(self, **kwargs):
491
- super().__init__(**kwargs)
492
-
493
- def preprocess(self, x):
494
- if self.tool == 'sketch' and self.source in ["upload", "webcam"]:
495
- assert isinstance(x, dict)
496
- if x['mask'] is None:
497
- decode_image = processing_utils.decode_base64_to_image(x['image'])
498
- width, height = decode_image.size
499
- mask = np.zeros((height, width, 4), dtype=np.uint8)
500
- mask[..., -1] = 255
501
- mask = self.postprocess(mask)
502
- x['mask'] = mask
503
 
504
- return super().preprocess(x)
505
 
506
 
507
  def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, ocr_reader=None, text_refiner=None,
@@ -908,15 +908,13 @@ submit_traj=0
908
 
909
  async def inference_traject(origin_image,sketcher_image, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
910
  original_size, input_size, text_refiner,focus_type,paragraph,openai_api_key,autoplay,trace_type):
911
- image_input, mask = sketcher_image['image'], sketcher_image['mask']
912
 
913
  crop_save_path=""
914
 
915
  prompt = get_sketch_prompt(mask)
916
  boxes = prompt['input_boxes']
917
  boxes = boxes[0]
918
- global submit_traj
919
- submit_traj=1
920
 
921
  controls = {'length': length,
922
  'sentiment': sentiment,
@@ -962,11 +960,7 @@ async def inference_traject(origin_image,sketcher_image, enable_wiki, language,
962
  # image_input = create_bubble_frame(image_input, "", fake_click_index, input_mask)
963
 
964
  prompt=generate_prompt(focus_type, paragraph, length, sentiment, factuality, language)
965
- width, height = sketcher_image['image'].size
966
- sketcher_image['mask'] = np.zeros((height, width, 4), dtype=np.uint8)
967
- sketcher_image['mask'][..., -1] = 255
968
- sketcher_image['image']=image_input
969
-
970
 
971
  # if not args.disable_gpt and text_refiner:
972
  if not args.disable_gpt:
@@ -1345,12 +1339,13 @@ def create_ui():
1345
  with gr.Tab("Trajectory (beta)") as traj_tab:
1346
  # sketcher_input = ImageSketcher(type="pil", interactive=True, brush_radius=10,
1347
  # elem_id="image_sketcher")
1348
- sketcher_input = ImageSketcher(type="pil", interactive=True,
1349
  elem_id="image_sketcher")
1350
  example_image = gr.Image(type="pil", interactive=False, visible=False)
1351
- with gr.Row():
1352
- submit_button_sketcher = gr.Button(value="Submit", interactive=True)
1353
  clear_button_sketcher = gr.Button(value="Clear Sketch", interactive=True)
 
 
1354
  with gr.Row():
1355
  with gr.Row():
1356
  focus_type_sketch = gr.Radio(
 
480
  shared_chatbot_tools = build_chatbot_tools(tools_dict)
481
 
482
 
483
+ # class ImageSketcher(gr.Image):
484
+ # """
485
+ # Fix the bug of gradio.Image that cannot upload with tool == 'sketch'.
486
+ # """
487
+
488
+ # is_template = True # Magic to make this work with gradio.Block, don't remove unless you know what you're doing.
489
+
490
+ # def __init__(self, **kwargs):
491
+ # super().__init__(**kwargs)
492
+
493
+ # def preprocess(self, x):
494
+ # if self.tool == 'sketch' and self.source in ["upload", "webcam"]:
495
+ # assert isinstance(x, dict)
496
+ # if x['mask'] is None:
497
+ # decode_image = processing_utils.decode_base64_to_image(x['image'])
498
+ # width, height = decode_image.size
499
+ # mask = np.zeros((height, width, 4), dtype=np.uint8)
500
+ # mask[..., -1] = 255
501
+ # mask = self.postprocess(mask)
502
+ # x['mask'] = mask
503
 
504
+ # return super().preprocess(x)
505
 
506
 
507
  def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, ocr_reader=None, text_refiner=None,
 
908
 
909
  async def inference_traject(origin_image,sketcher_image, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
910
  original_size, input_size, text_refiner,focus_type,paragraph,openai_api_key,autoplay,trace_type):
911
+ image_input, mask = sketcher_image['background'], sketcher_image['layers'][0]
912
 
913
  crop_save_path=""
914
 
915
  prompt = get_sketch_prompt(mask)
916
  boxes = prompt['input_boxes']
917
  boxes = boxes[0]
 
 
918
 
919
  controls = {'length': length,
920
  'sentiment': sentiment,
 
960
  # image_input = create_bubble_frame(image_input, "", fake_click_index, input_mask)
961
 
962
  prompt=generate_prompt(focus_type, paragraph, length, sentiment, factuality, language)
963
+
 
 
 
 
964
 
965
  # if not args.disable_gpt and text_refiner:
966
  if not args.disable_gpt:
 
1339
  with gr.Tab("Trajectory (beta)") as traj_tab:
1340
  # sketcher_input = ImageSketcher(type="pil", interactive=True, brush_radius=10,
1341
  # elem_id="image_sketcher")
1342
+ sketcher_input = gr.ImageEditor(type="pil", interactive=True,
1343
  elem_id="image_sketcher")
1344
  example_image = gr.Image(type="pil", interactive=False, visible=False)
1345
+ with gr.Row():
 
1346
  clear_button_sketcher = gr.Button(value="Clear Sketch", interactive=True)
1347
+ submit_button_sketcher = gr.Button(value="Submit", interactive=True)
1348
+
1349
  with gr.Row():
1350
  with gr.Row():
1351
  focus_type_sketch = gr.Radio(