praeclarumjj3 committed on
Commit
a7e7927
1 Parent(s): 8684306

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -33
app.py CHANGED
@@ -48,11 +48,7 @@ function() {
48
  def load_demo_refresh_model_list(request: gr.Request):
49
  logger.info(f"load_demo. ip: {request.client.host}")
50
  state = default_conversation.copy()
51
- dropdown_update = gr.Dropdown(
52
- choices=models,
53
- value=models[0]+"-4bit" if len(models) > 0 else ""
54
- )
55
- return state, dropdown_update
56
 
57
 
58
  def vote_last_response(state, vote_type, model_selector, request: gr.Request):
@@ -80,48 +76,51 @@ def flag_last_response(state, model_selector, request: gr.Request):
80
  vote_last_response(state, "flag", model_selector, request)
81
  return ("",) + (disable_btn,) * 3
82
 
83
- def regenerate(state, image_process_mode, seg_process_mode):
84
  state.messages[-1][-1] = None
85
  prev_human_msg = state.messages[-2]
86
  if type(prev_human_msg[1]) in (tuple, list):
87
- prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode, prev_human_msg[1][3], seg_process_mode, None, None)
88
  state.skip_next = False
89
- return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
90
 
91
 
92
  def clear_history(request: gr.Request):
93
  state = default_conversation.copy()
94
- return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
95
 
96
 
97
- def add_text(state, text, image, image_process_mode, seg, seg_process_mode, request: gr.Request):
98
  logger.info(f"add_text. len: {len(text)}")
99
  if len(text) <= 0 and image is None:
100
  state.skip_next = True
101
- return (state, state.to_gradio_chatbot(), "", None, None) + (no_change_btn,) * 5
102
  if args.moderate:
103
  flagged = violates_moderation(text)
104
  if flagged:
105
  state.skip_next = True
106
- return (state, state.to_gradio_chatbot(), moderation_msg, None, None) + (
107
  no_change_btn,) * 5
108
 
109
- text = text[:1576] # Hard cut-off
110
  if image is not None:
111
- text = text[:1200] # Hard cut-off for images
112
  if '<image>' not in text:
113
  text = '<image>\n' + text
114
  if seg is not None:
115
  if '<seg>' not in text:
116
  text = '<seg>\n' + text
 
 
 
117
 
118
- text = (text, image, image_process_mode, seg, seg_process_mode, None, None)
119
  if len(state.get_images(return_pil=True)) > 0:
120
  state = default_conversation.copy()
121
  state.append_message(state.roles[0], text)
122
  state.append_message(state.roles[1], None)
123
  state.skip_next = False
124
- return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
125
 
126
 
127
  def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
@@ -155,11 +154,13 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request:
155
  "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
156
  "images": f'List of {len(state.get_images())}',
157
  "segs": f'List of {len(state.get_segs())}',
 
158
  }
159
  logger.info(f"==== request ====\n{pload}")
160
 
161
  pload['images'] = state.get_images()
162
  pload['segs'] = state.get_segs()
 
163
 
164
  state.messages[-1][-1] = "▌"
165
  yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
@@ -189,8 +190,6 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request:
189
 
190
  state.messages[-1][-1] = state.messages[-1][-1][:-1]
191
  yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
192
-
193
- finish_tstamp = time.time()
194
  logger.info(f"{output}")
195
 
196
 
@@ -235,7 +234,7 @@ def build_demo(embed_mode):
235
  with gr.Row(elem_id="model_selector_row"):
236
  model_selector = gr.Dropdown(
237
  choices=models,
238
- value=models[0] if len(models) > 0 else "",
239
  interactive=True,
240
  show_label=False,
241
  container=False)
@@ -252,6 +251,12 @@ def build_demo(embed_mode):
252
  ["Crop", "Resize", "Pad", "Default"],
253
  value="Default",
254
  label="Preprocess for non-square Seg Map", visible=False)
 
 
 
 
 
 
255
 
256
  with gr.Accordion("Parameters", open=False) as parameter_row:
257
  temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.8, step=0.1, interactive=True, label="Temperature",)
@@ -275,13 +280,8 @@ def build_demo(embed_mode):
275
 
276
  cur_dir = os.path.dirname(os.path.abspath(__file__))
277
  gr.Examples(examples=[
278
- [f"{cur_dir}/examples/people.jpg", f"{cur_dir}/examples/people_pan.png", "What objects can be seen in the image?", "0.9", "1.0"],
279
- [f"{cur_dir}/examples/corgi.jpg", f"{cur_dir}/examples/corgi_pan.png", "What objects can be seen in the image?", "0.6", "0.7"],
280
- [f"{cur_dir}/examples/friends.jpg", f"{cur_dir}/examples/friends_pan.png", "Can you count the number of people in the image?", "0.8", "0.9"],
281
- [f"{cur_dir}/examples/friends.jpg", f"{cur_dir}/examples/friends_pan.png", "What is happening in the image?", "0.8", "0.9"],
282
- [f"{cur_dir}/examples/suits.jpg", f"{cur_dir}/examples/suits_pan.png", "What objects can be seen in the image?", "0.5", "0.5"],
283
- [f"{cur_dir}/examples/suits.jpg", f"{cur_dir}/examples/suits_ins.png", "What objects can be seen in the image?", "0.5", "0.5"],
284
- ], inputs=[imagebox, segbox, textbox, temperature, top_p])
285
 
286
  if not embed_mode:
287
  gr.Markdown(tos_markdown)
@@ -295,20 +295,20 @@ def build_demo(embed_mode):
295
  [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
296
  flag_btn.click(flag_last_response,
297
  [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
298
- regenerate_btn.click(regenerate, [state, image_process_mode, seg_process_mode],
299
- [state, chatbot, textbox, imagebox, segbox] + btn_list).then(
300
  http_bot, [state, model_selector, temperature, top_p, max_output_tokens],
301
  [state, chatbot] + btn_list)
302
- clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox, segbox] + btn_list)
303
 
304
- textbox.submit(add_text, [state, textbox, imagebox, image_process_mode, segbox, seg_process_mode], [state, chatbot, textbox, imagebox, segbox] + btn_list
305
  ).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens],
306
  [state, chatbot] + btn_list)
307
- submit_btn.click(add_text, [state, textbox, imagebox, image_process_mode, segbox, seg_process_mode], [state, chatbot, textbox, imagebox, segbox] + btn_list
308
  ).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens],
309
  [state, chatbot] + btn_list)
310
 
311
- demo.load(load_demo_refresh_model_list, None, [state, model_selector])
312
 
313
  return demo
314
 
@@ -354,7 +354,10 @@ if __name__ == "__main__":
354
 
355
  logger.info(args)
356
  demo = build_demo(args.embed)
357
- demo.queue().launch(
 
 
 
358
  server_name=args.host,
359
  server_port=args.port,
360
  share=args.share
 
48
  def load_demo_refresh_model_list(request: gr.Request):
49
  logger.info(f"load_demo. ip: {request.client.host}")
50
  state = default_conversation.copy()
51
+ return state
 
 
 
 
52
 
53
 
54
  def vote_last_response(state, vote_type, model_selector, request: gr.Request):
 
76
  vote_last_response(state, "flag", model_selector, request)
77
  return ("",) + (disable_btn,) * 3
78
 
79
+ def regenerate(state, image_process_mode, seg_process_mode, depth_process_mode):
80
  state.messages[-1][-1] = None
81
  prev_human_msg = state.messages[-2]
82
  if type(prev_human_msg[1]) in (tuple, list):
83
+ prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode, prev_human_msg[1][3], seg_process_mode, prev_human_msg[1][5], depth_process_mode)
84
  state.skip_next = False
85
+ return (state, state.to_gradio_chatbot(), "", None, None, None, None) + (disable_btn,) * 5
86
 
87
 
88
  def clear_history(request: gr.Request):
89
  state = default_conversation.copy()
90
+ return (state, state.to_gradio_chatbot(), "", None, None, None, None) + (disable_btn,) * 5
91
 
92
 
93
+ def add_text(state, text, image, image_process_mode, seg, seg_process_mode, depth, depth_process_mode, request: gr.Request):
94
  logger.info(f"add_text. len: {len(text)}")
95
  if len(text) <= 0 and image is None:
96
  state.skip_next = True
97
+ return (state, state.to_gradio_chatbot(), "", None, None, None, None) + (no_change_btn,) * 5
98
  if args.moderate:
99
  flagged = violates_moderation(text)
100
  if flagged:
101
  state.skip_next = True
102
+ return (state, state.to_gradio_chatbot(), moderation_msg, None, None, None, None) + (
103
  no_change_btn,) * 5
104
 
105
+ text = text[:1200] # Hard cut-off
106
  if image is not None:
107
+ text = text[:864] # Hard cut-off for images
108
  if '<image>' not in text:
109
  text = '<image>\n' + text
110
  if seg is not None:
111
  if '<seg>' not in text:
112
  text = '<seg>\n' + text
113
+ if depth is not None:
114
+ if '<depth>' not in text:
115
+ text = '<depth>\n' + text
116
 
117
+ text = (text, image, image_process_mode, seg, seg_process_mode, depth, depth_process_mode)
118
  if len(state.get_images(return_pil=True)) > 0:
119
  state = default_conversation.copy()
120
  state.append_message(state.roles[0], text)
121
  state.append_message(state.roles[1], None)
122
  state.skip_next = False
123
+ return (state, state.to_gradio_chatbot(), "", None, None, None, None) + (disable_btn,) * 5
124
 
125
 
126
  def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
 
154
  "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
155
  "images": f'List of {len(state.get_images())}',
156
  "segs": f'List of {len(state.get_segs())}',
157
+ "depths": f'List of {len(state.get_depths())}',
158
  }
159
  logger.info(f"==== request ====\n{pload}")
160
 
161
  pload['images'] = state.get_images()
162
  pload['segs'] = state.get_segs()
163
+ pload['depths'] = state.get_depths()
164
 
165
  state.messages[-1][-1] = "▌"
166
  yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
 
190
 
191
  state.messages[-1][-1] = state.messages[-1][-1][:-1]
192
  yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
 
 
193
  logger.info(f"{output}")
194
 
195
 
 
234
  with gr.Row(elem_id="model_selector_row"):
235
  model_selector = gr.Dropdown(
236
  choices=models,
237
+ value=models[0]+"-4bit" if len(models) > 0 else "",
238
  interactive=True,
239
  show_label=False,
240
  container=False)
 
251
  ["Crop", "Resize", "Pad", "Default"],
252
  value="Default",
253
  label="Preprocess for non-square Seg Map", visible=False)
254
+
255
+ depthbox = gr.Image(type="pil", label="Depth Map")
256
+ depth_process_mode = gr.Radio(
257
+ ["Crop", "Resize", "Pad", "Default"],
258
+ value="Default",
259
+ label="Preprocess for non-square Depth Map", visible=False)
260
 
261
  with gr.Accordion("Parameters", open=False) as parameter_row:
262
  temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.8, step=0.1, interactive=True, label="Temperature",)
 
280
 
281
  cur_dir = os.path.dirname(os.path.abspath(__file__))
282
  gr.Examples(examples=[
283
+ [f"{cur_dir}/examples/suits.jpg", f"{cur_dir}/examples/suits_pan.png", f"{cur_dir}/examples/suits_depth.jpeg", "Can you describe the depth order of the objects in this image, from closest to farthest?", "0.5", "0.5"],
284
+ ], inputs=[imagebox, segbox, depthbox, textbox, temperature, top_p])
 
 
 
 
 
285
 
286
  if not embed_mode:
287
  gr.Markdown(tos_markdown)
 
295
  [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
296
  flag_btn.click(flag_last_response,
297
  [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
298
+ regenerate_btn.click(regenerate, [state, image_process_mode, seg_process_mode, depth_process_mode],
299
+ [state, chatbot, textbox, imagebox, segbox, depthbox] + btn_list).then(
300
  http_bot, [state, model_selector, temperature, top_p, max_output_tokens],
301
  [state, chatbot] + btn_list)
302
+ clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox, segbox, depthbox] + btn_list)
303
 
304
+ textbox.submit(add_text, [state, textbox, imagebox, image_process_mode, segbox, seg_process_mode, depthbox, depth_process_mode], [state, chatbot, textbox, imagebox, segbox, depthbox] + btn_list
305
  ).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens],
306
  [state, chatbot] + btn_list)
307
+ submit_btn.click(add_text, [state, textbox, imagebox, image_process_mode, segbox, seg_process_mode, depthbox, depth_process_mode], [state, chatbot, textbox, imagebox, segbox, depthbox] + btn_list
308
  ).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens],
309
  [state, chatbot] + btn_list)
310
 
311
+ demo.load(load_demo_refresh_model_list, None, [state])
312
 
313
  return demo
314
 
 
354
 
355
  logger.info(args)
356
  demo = build_demo(args.embed)
357
+ demo.queue(
358
+ concurrency_count=args.concurrency_count,
359
+ api_open=False
360
+ ).launch(
361
  server_name=args.host,
362
  server_port=args.port,
363
  share=args.share