Haozhe committed
Commit cbc410e · 1 Parent(s): 7ba5930
Files changed (1)
  1. app.py +29 -16
app.py CHANGED
@@ -7,12 +7,12 @@ import pickle as pkl
 import re
 from PIL import Image
 import json
-# import spaces
+import spaces
 from serve_constants import html_header, bibtext, learn_more_markdown, tos_markdown


-MODEL_ID = "TIGER-Lab/PixelReasoner-RL-v1"
-example_image = "example_images/1.jpg"
+MODEL_ID = "/home/ma-user/work/haozhe/workspace/lmm-r1/toolckpts/pix17K0506wt-NormalizedPenalizedFixedReweightCont-256-lossvernone-samplevernone-fmtnone-group-n8-ml10000-lr10-sysvcot-8node/global_step24_hf_evalbest"
+example_image = "/home/ma-user/work/haozhe/workspace/vlspaces/example_images/1.jpg"
 # "example_images/document.png"
 example_text = "What kind of restaurant is it?"
 processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True,
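A side note on the MODEL_ID change above: from_pretrained accepts either a Hub repo id or a local checkpoint directory, which is what this commit swaps in. A minimal sketch of both forms (the local path below is a placeholder, not the checkpoint from this commit):

from transformers import AutoProcessor

# Hub repo id: resolved against huggingface.co and cached locally.
processor = AutoProcessor.from_pretrained("TIGER-Lab/PixelReasoner-RL-v1", trust_remote_code=True)

# Local directory: loaded straight from disk; no network access required.
processor = AutoProcessor.from_pretrained("/path/to/local/checkpoint", trust_remote_code=True)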
@@ -117,7 +117,7 @@ def parse_last_tool(output_text):
 tool_end = '</tool_call>'
 tool_start = '<tool_call>'

-# @spaces.GPU
+@spaces.GPU
 def model_inference(input_dict, history):
     text = input_dict["text"]
     files = input_dict["files"]
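The newly enabled @spaces.GPU decorator is the Hugging Face ZeroGPU API: on Spaces with ZeroGPU hardware, a GPU is attached only while the decorated function runs. A minimal sketch of the usual pattern, assuming a ZeroGPU Space (model name illustrative):

import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda")  # ZeroGPU permits this at startup

@spaces.GPU  # a GPU is allocated for the duration of each call
def generate(prompt: str) -> str:
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=32)
    return tokenizer.decode(out[0], skip_special_tokens=True)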
@@ -171,7 +171,8 @@ def model_inference(input_dict, history):
     })

     print(messages)
-    complete_assistant_response_for_gradio = ""
+    # complete_assistant_response_for_gradio = ""
+    complete_assistant_response_for_gradio = []
     while True:
         """
         Generate and stream text
@@ -185,7 +186,7 @@ def model_inference(input_dict, history):
         ).to("cuda")

         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=False)
-        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024, temperature=0.01, top_p=1.0, top_k=1)
+        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024, temperature=0.1, top_p=0.95, top_k=50)
         # import pdb; pdb.set_trace()
         thread = Thread(target=model.generate, kwargs=generation_kwargs)
         thread.start()
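This hunk uses the standard transformers streaming pattern: model.generate runs in a background thread and TextIteratorStreamer hands decoded text chunks to the consuming loop. Note that in transformers, temperature, top_p, and top_k only take effect when do_sample=True, which this hunk does not show being set. A self-contained sketch of the pattern (model name illustrative):

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The restaurant serves", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(
    inputs, streamer=streamer, max_new_tokens=64,
    do_sample=True, temperature=0.1, top_p=0.95, top_k=50,  # sampling knobs need do_sample=True
)

# generate() blocks, so it runs in a worker thread while we consume the stream.
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
text = ""
for chunk in streamer:  # yields decoded text as tokens arrive
    text += chunk
thread.join()
print(text)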
@@ -196,20 +197,26 @@ def model_inference(input_dict, history):
         # yield buffer
         # print(buffer)
         current_model_output_segment = "" # Text generated in this specific model call
+        toolflag = False
         for new_text_chunk in streamer:
             current_model_output_segment += new_text_chunk
             # Yield the sum of previously committed full response parts + current streaming segment
-            yield complete_assistant_response_for_gradio + current_model_output_segment
-            tmp = f"\n<b>Planning Visual Operations ...</b>\n\n"
-            yield complete_assistant_response_for_gradio + current_model_output_segment.split(tool_start)[0] + tmp
+            # yield complete_assistant_response_for_gradio + current_model_output_segment
+            if tool_start in current_model_output_segment:
+                toolflag = True
+                tmp = current_model_output_segment.split(tool_start)[0]
+                yield complete_assistant_response_for_gradio + [tmp+"\n\n<b>Planning Visual Operations ...</b>\n\n"]
+            if not toolflag:
+                yield complete_assistant_response_for_gradio + [current_model_output_segment]
         thread.join()

         # Process the full segment (e.g., remove <|im_end|>)
         processed_segment = current_model_output_segment.split("<|im_end|>", 1)[0] if "<|im_end|>" in current_model_output_segment else current_model_output_segment

         # Append this processed segment to the cumulative display string for Gradio
-        complete_assistant_response_for_gradio += processed_segment + "\n\n"
-        print(f"this one: {complete_assistant_response_for_gradio}")
+        # complete_assistant_response_for_gradio += processed_segment + "\n\n"
+        complete_assistant_response_for_gradio += [processed_segment + "\n\n"]
+        # print(f"this one: {complete_assistant_response_for_gradio}")
         yield complete_assistant_response_for_gradio # Ensure the fully processed segment is yielded to Gradio

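The hunk above switches complete_assistant_response_for_gradio from one accumulated string to a list: each finished model call is committed as its own chunk, and the in-progress call is streamed as a provisional final element. The same commit-then-stream pattern in isolation (a sketch, independent of Gradio and the model):

from typing import Iterable, Iterator, List

def stream_turns(turns: Iterable[Iterable[str]]) -> Iterator[List[str]]:
    committed: List[str] = []            # finished segments, one entry per model call
    for turn in turns:
        current = ""
        for chunk in turn:               # chunks as the model streams them
            current += chunk
            yield committed + [current]  # committed turns + live, still-growing turn
        committed.append(current)        # commit the finished turn
        yield committed

# Example: two "model calls" streaming three chunks each.
for view in stream_turns([["Looking ", "at the ", "sign..."], ["It is ", "a sushi ", "bar."]]):
    print(view)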
@@ -217,28 +224,34 @@ def model_inference(input_dict, history):
         qatext_for_tool_check = processed_segment
         require_tool = tool_end in qatext_for_tool_check and tool_start in qatext_for_tool_check

+        # print(f"Segment from model: \"{qatext_for_tool_check[:200]}...\", Requires tool: {require_tool}")
+
         if require_tool:

             tool_params = parse_last_tool(qatext_for_tool_check)
             tool_name = tool_params['name']
             tool_args = tool_params['arguments']
-            complete_assistant_response_for_gradio += f"\n<b>Executing Visual Operations ...</b> @{tool_name}({tool_args})\n\n"
+            # complete_assistant_response_for_gradio += f"\n<b>Executing Visual Operations ...</b> @{tool_name}({tool_args})\n\n"
+            complete_assistant_response_for_gradio += [f"\n<b>Executing Visual Operations ...</b> @{tool_name}({tool_args})\n\n"]
             yield complete_assistant_response_for_gradio # Update Gradio display
-
             video_flag = False

             raw_result = execute_tool(imagelist, rawimagelist, tool_args, tool_name, is_video=video_flag)
             print(raw_result)
             proc_img = raw_result
             all_images += [proc_img]
+            # complete_assistant_response_for_gradio += [(proc_img, "Visual Operation Result")]
+            # yield complete_assistant_response_for_gradio # Update Gradio display
+
             new_piece = dict(role='user', content=[
                 dict(type='text', text="\nHere is the cropped image (Image Size: {}x{}):".format(proc_img.size[0], proc_img.size[1])),
                 dict(type='image', image=proc_img)
             ]
             )
             messages.append(new_piece)
-
-            complete_assistant_response_for_gradio += f"\n<b>Analyzing Operation Result ...</b> @region(size={proc_img.size[0]}x{proc_img.size[1]})\n\n"
+            # print(messages)
+            # complete_assistant_response_for_gradio += f"\n<b>Analyzing Operation Result ...</b> @region(size={proc_img.size[0]}x{proc_img.size[1]})\n\n"
+            complete_assistant_response_for_gradio += [f"\n<b>Analyzing Operation Result ...</b> @region(size={proc_img.size[0]}x{proc_img.size[1]})\n\n"]
             yield complete_assistant_response_for_gradio # Update Gradio display

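parse_last_tool is defined earlier in app.py and is not shown in this diff; from how it is called above, it evidently returns a dict with 'name' and 'arguments' keys taken from the last <tool_call> block. A hypothetical reconstruction, assuming the block wraps a JSON payload (the "crop" tool below is illustrative):

import json

TOOL_START, TOOL_END = "<tool_call>", "</tool_call>"

def parse_last_tool(output_text: str) -> dict:
    # Hypothetical: keep only the last <tool_call>...</tool_call> span
    # and parse it as JSON, e.g. {"name": ..., "arguments": {...}}.
    payload = output_text.rsplit(TOOL_START, 1)[-1].split(TOOL_END, 1)[0]
    return json.loads(payload)

# Example:
text = 'Let me zoom in. <tool_call>{"name": "crop", "arguments": {"bbox": [0, 0, 100, 100]}}</tool_call>'
print(parse_last_tool(text)["name"])  # -> crop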
@@ -267,4 +280,4 @@ with gr.Blocks() as demo:
     gr.Markdown(learn_more_markdown)
     gr.Markdown(bibtext)

-demo.launch(debug=True)
+demo.launch(debug=True, share=True)
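Finally, the share=True added to demo.launch asks Gradio to open a temporary public *.gradio.live tunnel alongside the local server, while debug=True keeps the process in the foreground and surfaces errors. A minimal sketch of the same launch pattern:

import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("Demo placeholder")  # the real app builds the chat UI here

# share=True creates a temporary public link; debug=True blocks and prints errors.
demo.launch(debug=True, share=True)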