prithivMLmods committed
Commit 0e5cdf9 · verified · Parent: e817668

Update app.py

Files changed (1): app.py (+75, -33)
app.py CHANGED
@@ -62,7 +62,6 @@ model_x = AutoModelForVision2Seq.from_pretrained(
     torch_dtype=torch.float16
 ).to(device).eval()
 
-
 # Preprocessing functions for SmolDocling-256M
 def add_random_padding(image, min_percent=0.1, max_percent=0.10):
     """Add random padding to an image based on its size."""
@@ -110,18 +109,30 @@ def downsample_video(video_path):
     return frames
 
 # Dolphin-specific functions
-def model_chat(prompt, image):
-    """Use Dolphin model for inference."""
+def model_chat(prompt, image, is_batch=False):
+    """Use Dolphin model for inference, supporting both single and batch processing."""
     processor = processor_k
     model = model_k
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    inputs = processor(image, return_tensors="pt").to(device)
+
+    if not is_batch:
+        images = [image]
+        prompts = [prompt]
+    else:
+        images = image
+        prompts = prompt if isinstance(prompt, list) else [prompt] * len(images)
+
+    inputs = processor(images, return_tensors="pt", padding=True).to(device)
     pixel_values = inputs.pixel_values.half()
+
+    prompts = [f"<s>{p} <Answer/>" for p in prompts]
     prompt_inputs = processor.tokenizer(
-        f"<s>{prompt} <Answer/>",
+        prompts,
         add_special_tokens=False,
-        return_tensors="pt"
+        return_tensors="pt",
+        padding=True
     ).to(device)
+
     outputs = model.generate(
         pixel_values=pixel_values,
         decoder_input_ids=prompt_inputs.input_ids,
@@ -137,20 +148,48 @@ def model_chat(prompt, image):
         num_beams=1,
         repetition_penalty=1.1
     )
-    sequence = processor.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)[0]
-    cleaned = sequence.replace(f"<s>{prompt} <Answer/>", "").replace("<pad>", "").replace("</s>", "").strip()
-    return cleaned
+    sequences = processor.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)
+
+    results = []
+    for i, sequence in enumerate(sequences):
+        cleaned = sequence.replace(prompts[i], "").replace("<pad>", "").replace("</s>", "").strip()
+        results.append(cleaned)
+
+    return results[0] if not is_batch else results
+
+def process_element_batch(elements, prompt, max_batch_size=16):
+    """Process a batch of elements with the same prompt."""
+    results = []
+    batch_size = min(len(elements), max_batch_size)
+
+    for i in range(0, len(elements), batch_size):
+        batch_elements = elements[i:i + batch_size]
+        crops_list = [elem["crop"] for elem in batch_elements]
+        prompts_list = [prompt] * len(crops_list)
+
+        batch_results = model_chat(prompts_list, crops_list, is_batch=True)
+
+        for j, result in enumerate(batch_results):
+            elem = batch_elements[j]
+            results.append({
+                "label": elem["label"],
+                "bbox": elem["bbox"],
+                "text": result.strip(),
+                "reading_order": elem["reading_order"],
+            })
+
+    return results
 
 def process_elements(layout_results, image):
     """Parse layout results and extract elements from the image."""
-    # Placeholder parsing logic based on expected Dolphin output
-    # Assuming layout_results is a string like "[(x1,y1,x2,y2,label), ...]"
     try:
         elements = ast.literal_eval(layout_results)
     except:
-        elements = []  # Fallback if parsing fails
+        elements = []
 
-    recognition_results = []
+    text_elements = []
+    table_elements = []
+    figure_results = []
     reading_order = 0
 
     for bbox, label in elements:
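
For orientation, here is a minimal usage sketch (not part of the commit) of the reworked `model_chat` in both modes. It assumes the Dolphin processor and model (`processor_k` / `model_k`) are loaded as elsewhere in app.py; the image path and crop boxes are illustrative.

```python
# Hypothetical usage sketch for the batched model_chat added above.
# Assumes processor_k / model_k are loaded as in app.py.
from PIL import Image

page = Image.open("sample_page.png")  # illustrative input

# Single-image call: same external behaviour as before, returns one string.
text = model_chat("Read text in the image.", page)

# Batched call: a list of crops in, a list of strings out, one per crop.
crops = [page.crop((0, 0, 300, 60)), page.crop((0, 60, 300, 120))]
texts = model_chat("Read text in the image.", crops, is_batch=True)
```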
@@ -158,27 +197,21 @@ def process_elements(layout_results, image):
             x1, y1, x2, y2 = map(int, bbox)
             cropped = image.crop((x1, y1, x2, y2))
             if cropped.size[0] > 0 and cropped.size[1] > 0:
+                element_info = {
+                    "crop": cropped,
+                    "label": label,
+                    "bbox": [x1, y1, x2, y2],
+                    "reading_order": reading_order,
+                }
                 if label == "text":
-                    text = model_chat("Read text in the image.", cropped)
-                    recognition_results.append({
-                        "label": label,
-                        "bbox": [x1, y1, x2, y2],
-                        "text": text.strip(),
-                        "reading_order": reading_order
-                    })
+                    text_elements.append(element_info)
                 elif label == "table":
-                    table_text = model_chat("Parse the table in the image.", cropped)
-                    recognition_results.append({
-                        "label": label,
-                        "bbox": [x1, y1, x2, y2],
-                        "text": table_text.strip(),
-                        "reading_order": reading_order
-                    })
+                    table_elements.append(element_info)
                 elif label == "figure":
-                    recognition_results.append({
+                    figure_results.append({
                         "label": label,
                         "bbox": [x1, y1, x2, y2],
-                        "text": "[Figure]",  # Placeholder for figure content
+                        "text": "[Figure]",
                         "reading_order": reading_order
                     })
                 reading_order += 1
@@ -186,12 +219,23 @@ def process_elements(layout_results, image):
             print(f"Error processing element: {e}")
             continue
 
+    recognition_results = figure_results.copy()
+
+    if text_elements:
+        text_results = process_element_batch(text_elements, "Read text in the image.")
+        recognition_results.extend(text_results)
+
+    if table_elements:
+        table_results = process_element_batch(table_elements, "Parse the table in the image.")
+        recognition_results.extend(table_results)
+
+    recognition_results.sort(key=lambda x: x["reading_order"])
     return recognition_results
 
 def generate_markdown(recognition_results):
     """Generate markdown from extracted elements."""
     markdown = ""
-    for element in sorted(recognition_results, key=lambda x: x["reading_order"]):
+    for element in recognition_results:
         if element["label"] == "text":
             markdown += f"{element['text']}\n\n"
         elif element["label"] == "table":
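
The bucketing above groups crops by label so each group shares one prompt per `generate` call, then merges the results with the figure placeholders and re-sorts by `reading_order` (which is why `generate_markdown` no longer needs to sort). A hypothetical `layout_results` string of the shape `ast.literal_eval` is expected to parse, with illustrative coordinates:

```python
import ast

# Illustrative layout string: each element unpacks as (bbox, label),
# matching the `for bbox, label in elements` loop in process_elements.
layout_results = "[((10, 12, 300, 48), 'text'), ((10, 60, 300, 200), 'table'), ((10, 210, 300, 400), 'figure')]"

elements = ast.literal_eval(layout_results)
for bbox, label in elements:
    x1, y1, x2, y2 = map(int, bbox)  # crop box passed to image.crop(...)
```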
@@ -222,7 +266,6 @@ def generate_image(model_name: str, text: str, image: Image.Image,
         markdown_content = process_image_with_dolphin(image)
         yield markdown_content
     else:
-        # Existing logic for other models
         if model_name == "olmOCR-7B-0225-preview":
             processor = processor_m
             model = model_m
@@ -309,7 +352,6 @@ def generate_video(model_name: str, text: str, video_path: str,
         combined_markdown = "\n\n".join(markdown_contents)
         yield combined_markdown
     else:
-        # Existing logic for other models
         if model_name == "olmOCR-7B-0225-preview":
             processor = processor_m
             model = model_m
@@ -401,7 +443,7 @@ css = """
 
 # Create the Gradio Interface
 with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
-    gr.Markdown("# **[Core OCR](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
+    gr.Markdown("# **[Docling-VLMs](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
     with gr.Row():
         with gr.Column():
             with gr.Tabs():