prithivMLmods committed on
Commit
8c1f8ea
·
verified ·
1 Parent(s): a8067dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -14
app.py CHANGED
@@ -15,11 +15,10 @@ from transformers import (
15
  from transformers import Qwen2_5_VLForConditionalGeneration
16
 
17
  # Helper Functions
18
-
19
  def progress_bar_html(label: str, primary_color: str = "#4B0082", secondary_color: str = "#9370DB") -> str:
20
  """
21
  Returns an HTML snippet for a thin animated progress bar with a label.
22
- Colors can be customized; default colors are used for Qwen2VL/AyaVision.
23
  """
24
  return f'''
25
  <div style="display: flex; align-items: center;">
@@ -36,7 +35,6 @@ def progress_bar_html(label: str, primary_color: str = "#4B0082", secondary_colo
36
  </style>
37
  '''
38
 
39
-
40
  def downsample_video(video_path):
41
  """
42
  Downsamples a video file by extracting 25 evenly spaced frames.
@@ -81,7 +79,7 @@ rolmocr_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
81
  # Main Inference Function
82
  @spaces.GPU
83
  def model_inference(input_dict, history, use_rolmocr=False):
84
- text = input_dict.get("text", "").strip()
85
  files = input_dict.get("files", [])
86
 
87
  if not text and not files:
@@ -121,6 +119,7 @@ def model_inference(input_dict, history, use_rolmocr=False):
121
  model = rolmocr_model if use_rolmocr else qwen_model
122
  model_name = "RolmOCR" if use_rolmocr else "Qwen2VL OCR"
123
 
 
124
  prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
125
  all_images = [item["image"] for item in content if item["type"] == "image"]
126
  inputs = processor(
@@ -130,31 +129,33 @@ def model_inference(input_dict, history, use_rolmocr=False):
130
  padding=True,
131
  ).to("cuda")
132
 
 
133
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
134
  generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
135
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
136
  thread.start()
137
 
138
  buffer = ""
139
- # Send initial progress bar
140
  yield progress_bar_html(f"Processing with {model_name}")
141
 
142
- # Stream generation
143
  for new_text in streamer:
144
  buffer += new_text
145
  buffer = buffer.replace("<|im_end|>", "")
146
  time.sleep(0.01)
147
  yield buffer
148
 
149
- # Ensure generation is complete
150
- thread.join()
151
-
152
- # Save the full response to response.txt
153
  try:
154
  with open("response.txt", "w", encoding="utf-8") as f:
155
- f.write(buffer)
156
  except Exception as e:
157
- yield f"Error saving response: {e}"
 
 
 
 
158
 
159
  # Gradio Interface
160
  examples = [
@@ -163,7 +164,6 @@ examples = [
163
  [{"text": "Extract as JSON table from the table", "files": ["examples/4.jpg"]}],
164
  ]
165
 
166
-
167
  demo = gr.ChatInterface(
168
  fn=model_inference,
169
  description="# **Multimodal OCR `@RolmOCR and Default Qwen2VL OCR`**",
@@ -180,4 +180,4 @@ demo = gr.ChatInterface(
180
  additional_inputs=[gr.Checkbox(label="Use RolmOCR", value=False, info="Check to use RolmOCR, uncheck to use Qwen2VL OCR")],
181
  )
182
 
183
- demo.launch(debug=True)
 
15
  from transformers import Qwen2_5_VLForConditionalGeneration
16
 
17
  # Helper Functions
 
18
  def progress_bar_html(label: str, primary_color: str = "#4B0082", secondary_color: str = "#9370DB") -> str:
19
  """
20
  Returns an HTML snippet for a thin animated progress bar with a label.
21
+ Colors can be customized; default colors are used for Qwen2VL/Aya-Vision.
22
  """
23
  return f'''
24
  <div style="display: flex; align-items: center;">
 
35
  </style>
36
  '''
37
 
 
38
  def downsample_video(video_path):
39
  """
40
  Downsamples a video file by extracting 25 evenly spaced frames.
 
79
  # Main Inference Function
80
  @spaces.GPU
81
  def model_inference(input_dict, history, use_rolmocr=False):
82
+ text = input_dict["text"].strip()
83
  files = input_dict.get("files", [])
84
 
85
  if not text and not files:
 
119
  model = rolmocr_model if use_rolmocr else qwen_model
120
  model_name = "RolmOCR" if use_rolmocr else "Qwen2VL OCR"
121
 
122
+ # Prepare prompt and inputs
123
  prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
124
  all_images = [item["image"] for item in content if item["type"] == "image"]
125
  inputs = processor(
 
129
  padding=True,
130
  ).to("cuda")
131
 
132
+ # Set up streaming
133
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
134
  generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
135
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
136
  thread.start()
137
 
138
  buffer = ""
 
139
  yield progress_bar_html(f"Processing with {model_name}")
140
 
141
+ # Stream tokens
142
  for new_text in streamer:
143
  buffer += new_text
144
  buffer = buffer.replace("<|im_end|>", "")
145
  time.sleep(0.01)
146
  yield buffer
147
 
148
+ # Once streaming is done, save to response.txt and yield final result
149
+ results = buffer.strip()
 
 
150
  try:
151
  with open("response.txt", "w", encoding="utf-8") as f:
152
+ f.write(results)
153
  except Exception as e:
154
+ yield f"Error writing to response.txt: {e}"
155
+ return
156
+
157
+ yield results
158
+ return
159
 
160
  # Gradio Interface
161
  examples = [
 
164
  [{"text": "Extract as JSON table from the table", "files": ["examples/4.jpg"]}],
165
  ]
166
 
 
167
  demo = gr.ChatInterface(
168
  fn=model_inference,
169
  description="# **Multimodal OCR `@RolmOCR and Default Qwen2VL OCR`**",
 
180
  additional_inputs=[gr.Checkbox(label="Use RolmOCR", value=False, info="Check to use RolmOCR, uncheck to use Qwen2VL OCR")],
181
  )
182
 
183
+ demo.launch(debug=True)