玙珲 committed
Commit e0ca852 · 1 Parent(s): 169061d

add thinking budget

Files changed (1):
  1. app.py +78 -30

app.py CHANGED
@@ -27,6 +27,25 @@ streamer = None
# This should point to the directory containing your SVG file.
CUR_DIR = os.path.dirname(os.path.abspath(__file__))

+
+ class MyTextIteratorStreamer(TextIteratorStreamer):
+     def manual_end(self):
+         """Flushes any remaining cache and prints a newline to stdout."""
+         # Flush the cache, if it exists
+         if len(self.token_cache) > 0:
+             text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
+             printable_text = text[self.print_len :]
+             self.token_cache = []
+             self.print_len = 0
+         else:
+             printable_text = ""
+
+         self.next_tokens_are_prompt = True
+         self.on_finalized_text(printable_text, stream_end=True)
+
+     def end(self):
+         pass
+
def submit_chat(chatbot, text_input):
    response = ''
    chatbot.append([text_input, response])
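The end() override above is the key to the streaming behavior: generate() calls the streamer's end() when a generation pass finishes, and with a thinking budget the model may run two passes (thinking, then answering) over the same streamer. Making end() a no-op keeps the first pass from closing the iterator early; the caller closes the stream exactly once via manual_end(). A minimal sketch of the intended driving pattern, with an illustrative worker function that is not part of this commit:

    # Illustrative sketch, not part of the commit.
    from threading import Thread

    def _generation_worker():
        with torch.inference_mode():
            model.generate(**gen_kwargs)  # argument plumbing elided
        streamer.manual_end()  # safe to close: all generation passes are done

    thread = Thread(target=_generation_worker)
    thread.start()
    response = ""
    for new_text in streamer:  # iteration ends once manual_end() runs
        response += new_text
    thread.join()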
@@ -114,6 +133,8 @@ def run_inference(
    do_sample: bool,
    max_new_tokens: int,
    enable_thinking: bool,
+     enable_thinking_budget: bool,  # NEWLY ADDED
+     thinking_budget: int,  # NEWLY ADDED
):
    """
    Runs a single turn of inference and yields the output stream for a gr.Chatbot.
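Gradio passes the inputs list to a callback positionally, so the two new parameters must sit in the signature in the same order as in the run_inputs list updated later in this diff:

    # Position-for-position pairing (illustrative, matching this commit's lists):
    run_inputs = [chatbot, image_input, video_input,
                  do_sample, max_new_tokens, enable_thinking,
                  enable_thinking_budget, thinking_budget]
    # run_inference(chatbot, image_input, video_input, do_sample,
    #               max_new_tokens, enable_thinking,
    #               enable_thinking_budget, thinking_budget)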
@@ -122,14 +143,11 @@ def run_inference(
    prompt = chatbot[-1][0]
    if (not image_input and not video_input and not prompt) or not prompt:
        gr.Warning("A text prompt is required for generation.")
+         chatbot.pop(-1)
        # MODIFICATION: Yield the current state and return to avoid errors
        yield chatbot
        return

-     # MODIFICATION: Append the new prompt to the existing history
-     # chatbot.append([prompt, ""])
-     # yield chatbot, "" # Yield the updated chat to show the user's prompt immediately
-
    content = []
    if image_input:
        content.append({"type": "image", "image": image_input})
@@ -139,7 +157,7 @@ def run_inference(
            content.append({"type": "video", "video": frames})
        else:
            gr.Warning("Failed to process the video file.")
-             chatbot[-1][1] = "Error: Could not process the video file."
+             chatbot.pop(-1)
            yield chatbot
            return

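In both guards above, the failure path now pops the [prompt, ""] turn that submit_chat appended before run_inference started, so an invalid request no longer leaves an empty exchange, or an error string posing as a model answer, in the history; the error surfaces as a transient gr.Warning instead.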
@@ -154,7 +172,8 @@ def run_inference(
        else:
            input_ids, pixel_values, grid_thws = model.preprocess_inputs(messages=messages, add_generation_prompt=True, enable_thinking=enable_thinking)
    except Exception as e:
-         chatbot[-1][1] = f"Error during input preprocessing: {e}"
+         gr.Warning(f"Error during input preprocessing: {e}")
+         chatbot.pop(-1)
        yield chatbot
        return

@@ -170,7 +189,10 @@ def run_inference(
        "eos_token_id": model.text_tokenizer.eos_token_id,
        "pad_token_id": model.text_tokenizer.pad_token_id,
        "streamer": streamer,
-         "use_cache": True
+         "use_cache": True,
+         "enable_thinking": enable_thinking,
+         "enable_thinking_budget": enable_thinking_budget,
+         "thinking_budget": thinking_budget
    }

    with torch.inference_mode():
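The three new keys are forwarded to model.generate(), which implements the budget internally. The usual mechanism caps the hidden reasoning phase: generate at most thinking_budget reasoning tokens, force-close the thinking block if the model has not closed it itself, then spend the rest of max_new_tokens on the visible answer. A rough sketch of that idea, with hypothetical helpers (the real logic lives inside the model's generate):

    # Illustrative two-phase budget loop; the helper functions are hypothetical.
    def generate_with_budget(model, input_ids, thinking_budget, max_new_tokens, **kw):
        # Phase 1: reasoning, capped at thinking_budget tokens.
        seq = model.generate(input_ids, max_new_tokens=thinking_budget, **kw)
        if not thinking_closed(seq):  # hypothetical: did the model emit </think>?
            seq = append_text(seq, "\n</think>\n\n")  # hypothetical: force the block shut
        # Phase 2: the visible answer gets whatever budget remains.
        remaining = max_new_tokens - new_token_count(seq)  # hypothetical helper
        return model.generate(seq, max_new_tokens=remaining, **kw)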
@@ -197,16 +219,11 @@ def run_inference(
    chatbot[-1][1] = formatted_response
    yield chatbot  # Yield the final, formatted response

-     logger.info("[OVIS_CONV_START]")
-     [print(f'Q{i}:\n {request}\nA{i}:\n {answer}') for i, (request, answer) in enumerate(chatbot, 1)]
-     # print('New_Q:\n', text_input)
-     # print('New_A:\n', response)
+     logger.info("\n[OVIS_CONV_START]")
+     [print(f'Q{i}:\n {request}\nA{i}:\n {answer}\n') for i, (request, answer) in enumerate(chatbot, 1)]
    logger.info("[OVIS_CONV_END]")


- def clear_chat():
-     return [], None, ""
-
# --- UI Helper Functions ---
def toggle_media_input(choice: str) -> Tuple:
    """Switches visibility between Image/Video inputs and their corresponding examples."""
@@ -217,7 +234,6 @@ def toggle_media_input(choice: str) -> Tuple:


# --- Build Gradio Application ---
- # @spaces.GPU
def build_demo(model_path: str):
    """Builds the Gradio user interface for the model."""
    global model, streamer
@@ -231,7 +247,7 @@ def build_demo(model_path: str):
    ).to(device).eval()

    text_tokenizer = model.text_tokenizer
-     streamer = TextIteratorStreamer(text_tokenizer, skip_prompt=True, skip_special_tokens=True)
+     streamer = MyTextIteratorStreamer(text_tokenizer, skip_prompt=True, skip_special_tokens=True)

    print("Model loaded successfully.")

@@ -257,10 +273,22 @@ def build_demo(model_path: str):
    <center><font size=3><b>Ovis</b> has been open-sourced on <a href='https://huggingface.co/{model_path}'>😊 Huggingface</a> and <a href='https://github.com/AIDC-AI/Ovis'>🌟 GitHub</a>. If you find Ovis useful, a like❤️ or a star🌟 would be appreciated.</font></center>
    """

+     # --- START: Slider synchronization logic functions ---
+     def adjust_max_tokens(thinking_budget_val: int, max_new_tokens_val: int) -> gr.Slider:
+         """Adjusts max_new_tokens to be at least thinking_budget + 128."""
+         new_max_tokens = max(max_new_tokens_val, thinking_budget_val + 128)
+         return gr.update(value=new_max_tokens)
+
+     def adjust_thinking_budget(max_new_tokens_val: int, thinking_budget_val: int) -> gr.Slider:
+         """Adjusts thinking_budget to be at most max_new_tokens - 128."""
+         new_thinking_budget = min(thinking_budget_val, max_new_tokens_val - 128)
+         return gr.update(value=new_thinking_budget)
+     # --- END: Slider synchronization logic functions ---
+
    prompt_input = gr.Textbox(label="Prompt", placeholder="Enter your text here and press ENTER", lines=1, container=False)
    with gr.Blocks(theme=gr.themes.Ocean()) as demo:
        gr.HTML(html_header)
-         gr.Markdown("Note: you might have to increase \"Max New Tokens\" and wait longer to obtain answer when Deep Thinking is enabled.")
+         gr.Markdown("Note: The Thinking Budget mechanism is enabled only when `Deep Thinking` and `Thinking Budget` are both checked. You can lower `Thinking Budget` for faster generation in `Deep Thinking` mode.")

        with gr.Row():
            with gr.Column(scale=4):
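Together the two callbacks maintain the invariant thinking_budget + 128 <= max_new_tokens, so at least 128 tokens are always reserved for the visible answer; this matches the slider ranges added below (budget maximum 3968 = 4096 - 128). A quick trace, assuming gr.update() returns its usual update dict:

    # Illustrative trace, not part of the commit.
    adjust_max_tokens(2048, 1024)       # -> update(value=2176): max_new_tokens follows the budget up
    adjust_thinking_budget(1024, 2048)  # -> update(value=896): the budget is clamped back down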
@@ -270,10 +298,10 @@ def build_demo(model_path: str):

                with gr.Accordion("Generation Settings", open=True):
                    do_sample = gr.Checkbox(label="Enable Sampling (Do Sample)", value=True)
-                     max_new_tokens = gr.Slider(minimum=32, maximum=4096, value=2048, step=32, label="Max New Tokens")
-                     enable_thinking = gr.Checkbox(label="Enable Deep Thinking", value=False)
-
-
+                     enable_thinking = gr.Checkbox(label="Enable Deep Thinking", value=True)
+                     enable_thinking_budget = gr.Checkbox(label="Enable Thinking Budget", value=True)
+                     max_new_tokens = gr.Slider(minimum=256, maximum=4096, value=2048, step=32, label="Max New Tokens")
+                     thinking_budget = gr.Slider(minimum=128, maximum=3968, value=1024, step=32, label="Thinking Budget")

            with gr.Column(visible=True) as image_examples_col:
                gr.Examples(
@@ -297,30 +325,50 @@ def build_demo(model_path: str):
            generate_btn = gr.Button("Send", variant="primary")
            clear_btn = gr.Button("Clear", variant="secondary")

+         # --- START: Event Handlers for UI Elements ---
+
        input_type_radio.change(
            fn=toggle_media_input,
            inputs=input_type_radio,
            outputs=[image_input, video_input, image_examples_col, video_examples_col]
        )

-         # MODIFICATION: Update event handlers to use the new function and manage state
-         run_inputs = [chatbot, image_input, video_input, do_sample, max_new_tokens, enable_thinking]
-         # run_outputs = [image_input, prompt_input]
+         # Event handlers for coupled sliders
+         thinking_budget.release(
+             fn=adjust_max_tokens,
+             inputs=[thinking_budget, max_new_tokens],
+             outputs=[max_new_tokens]
+         )
+         max_new_tokens.release(
+             fn=adjust_thinking_budget,
+             inputs=[max_new_tokens, thinking_budget],
+             outputs=[thinking_budget]
+         )
+
+         # MODIFICATION: Update run_inputs to include new controls
+         run_inputs = [chatbot, image_input, video_input, do_sample, max_new_tokens, enable_thinking, enable_thinking_budget, thinking_budget]

        generat_click_event = generate_btn.click(submit_chat, [chatbot, prompt_input], [chatbot, prompt_input]).then(run_inference, run_inputs, chatbot)
        submit_event = prompt_input.submit(submit_chat, [chatbot, prompt_input], [chatbot, prompt_input]).then(run_inference, run_inputs, chatbot)

+         # MODIFICATION: Update clear button to reset new controls
+         # clear_btn.click(
+         #     fn=lambda: ([], None, None, "", "Image", True, 2048, True, True, 1024),
+         #     outputs=[chatbot, image_input, video_input, prompt_input, input_type_radio, do_sample, max_new_tokens, enable_thinking, enable_thinking_budget, thinking_budget]
+         # ).then(
+         #     fn=toggle_media_input,
+         #     inputs=input_type_radio,
+         #     outputs=[image_input, video_input, image_examples_col, video_examples_col]
+         # )
        clear_btn.click(
-             fn=lambda: ([], None, None, "", "Image", True, 1024, False),
-             outputs=[chatbot, image_input, video_input, prompt_input, input_type_radio, do_sample, max_new_tokens, enable_thinking]
-         ).then(
-             fn=toggle_media_input,
-             inputs=input_type_radio,
-             outputs=[image_input, video_input, image_examples_col, video_examples_col]
+             fn=lambda: (list(), None, None, ""),
+             outputs=[chatbot, image_input, video_input, prompt_input]
        )
+         # --- END: Event Handlers for UI Elements ---

        return demo

+
# --- Main Execution Block ---
# def parse_args():
#     parser = argparse.ArgumentParser(description="Gradio interface for a single Multimodal Large Language Model.")
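The main execution block stays commented out in this commit; a minimal launch sketch under the usual Gradio conventions (the model path is a placeholder, not taken from this commit):

    # Illustrative entry point; the model path is a placeholder.
    if __name__ == "__main__":
        demo = build_demo(model_path="<your-ovis-model-path>")
        demo.queue().launch()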