prithivMLmods committed
Commit 914bd4d · verified · 1 Parent(s): 6e73997

Update app.py

Files changed (1)
  1. app.py +265 -338
app.py CHANGED
@@ -5,12 +5,6 @@ import json
  import time
  import asyncio
  from threading import Thread
- import io
- import base64
- import re
- import ast
- import html
- from collections import namedtuple

  import gradio as gr
  import spaces
@@ -31,9 +25,13 @@ from transformers.image_utils import load_image

  from docling_core.types.doc import DoclingDocument, DocTagsDocument

+ import re
+ import ast
+ import html
+
  # Constants for text generation
- MAX_MAX_NEW_TOKENS = 4096
- DEFAULT_MAX_NEW_TOKENS = 2048
+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -47,6 +45,15 @@ model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
  torch_dtype=torch.float16
  ).to(device).eval()

+ # Load ByteDance's Dolphin
+ MODEL_ID_K = "ByteDance/Dolphin"
+ processor_k = AutoProcessor.from_pretrained(MODEL_ID_K, trust_remote_code=True)
+ model_k = VisionEncoderDecoderModel.from_pretrained(
+ MODEL_ID_K,
+ trust_remote_code=True,
+ torch_dtype=torch.float16
+ ).to(device).eval()
+
  # Load SmolDocling-256M-preview
  MODEL_ID_X = "ds4sd/SmolDocling-256M-preview"
  processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
@@ -71,21 +78,6 @@ model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained(
  torch_dtype=torch.float16
  ).to(device).eval()

- #------------------------------------------------#
- # Load ByteDance's Dolphin (with specific implementation)
- print("Loading ByteDance/Dolphin model...")
- MODEL_ID_K = "ByteDance/Dolphin"
- processor_k = AutoProcessor.from_pretrained(MODEL_ID_K)
- model_k = VisionEncoderDecoderModel.from_pretrained(MODEL_ID_K)
- model_k.eval()
- model_k.to(device)
- if torch.cuda.is_available():
- model_k = model_k.half() # Use half-precision on GPU
- tokenizer_k = processor_k.tokenizer
- print("ByteDance/Dolphin model loaded.")
- #------------------------------------------------#
-
-
  # Preprocessing functions for SmolDocling-256M
  def add_random_padding(image, min_percent=0.1, max_percent=0.10):
  """Add random padding to an image based on its size."""
@@ -120,12 +112,7 @@ def downsample_video(video_path):
  total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
  fps = vidcap.get(cv2.CAP_PROP_FPS)
  frames = []
- # Take up to 10 frames
- num_frames_to_sample = min(10, total_frames)
- if num_frames_to_sample == 0:
- return []
- frame_indices = np.linspace(0, total_frames - 1, num_frames_to_sample, dtype=int)
-
+ frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
  for i in frame_indices:
  vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
  success, image = vidcap.read()
@@ -137,194 +124,103 @@ def downsample_video(video_path):
  vidcap.release()
  return frames

- # ------------------- Dolphin Model Specific Helper Functions ------------------- #
-
- ImageDimensions = namedtuple("ImageDimensions", ["width", "height", "new_w", "new_h", "pad_w", "pad_h"])
-
- class MarkdownConverter:
- """Converts structured recognition results to a Markdown string."""
- def convert(self, elements):
- markdown_str = ""
- for elem in elements:
- label = elem["label"]
- text = elem["text"]
- if label == "fig":
- # Embed image as base64
- markdown_str += f"![figure](data:image/png;base64,{text})\n\n"
- elif label == "tab":
- markdown_str += f"### Table\n\n{text}\n\n"
- else: # text, title, head, foot, etc.
- markdown_str += f"{text}\n\n"
- return markdown_str.strip()
-
- def prepare_image_dolphin(pil_image, target_size=1024):
- """Pads a PIL image to a square, returning a cv2 image and dimensions."""
- image = np.array(pil_image.convert('RGB'))
- h, w, _ = image.shape
- if h > w:
- new_h, new_w = target_size, int(w * target_size / h)
- else:
- new_h, new_w = int(h * target_size / w), target_size
-
- resized_image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
-
- pad_w = (target_size - new_w) // 2
- pad_h = (target_size - new_h) // 2
-
- padded_image = np.pad(resized_image, ((pad_h, pad_h), (pad_w, pad_w), (0, 0)), 'constant', constant_values=255)
- dims = ImageDimensions(w, h, new_w, new_h, pad_w, pad_h)
-
- return padded_image, dims
-
- def parse_layout_string_dolphin(layout_string):
- """Parses the model's layout string into a list of (bbox, label) tuples."""
- pattern = r'([a-zA-Z_]+)\(((?:\d+,){3}\d+)\)'
- matches = re.findall(pattern, layout_string)
- results = []
- for label, coords_str in matches:
- coords = tuple(map(int, coords_str.split(',')))
- results.append((coords, label))
- return results
-
- def process_coordinates_dolphin(bbox, padded_image, dims, previous_box):
- """Converts relative bbox coordinates to absolute pixel coordinates for cropping."""
- x1, y1, x2, y2 = bbox
-
- orig_x1 = int(x1 / 1024 * dims.new_w)
- orig_y1 = int(y1 / 1024 * dims.new_h)
- orig_x2 = int(x2 / 1024 * dims.new_w)
- orig_y2 = int(y2 / 1024 * dims.new_h)
-
- x1 = orig_x1 + dims.pad_w
- y1 = orig_y1 + dims.pad_h
- x2 = orig_x2 + dims.pad_w
- y2 = orig_y2 + dims.pad_h
-
- return x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, bbox
-
- @spaces.GPU
- def dolphin_model_chat(model, processor, prompt, image):
- """Core inference function for the Dolphin model, supports batching."""
- is_batch = isinstance(image, list)
-
- images = image if is_batch else [image]
- prompts = prompt if isinstance(prompt, list) else [prompt] * len(images)
-
- batch_inputs = processor(images, return_tensors="pt", padding=True)
- pixel_values = batch_inputs.pixel_values.to(device)
- if torch.cuda.is_available():
- pixel_values = pixel_values.half()
-
- prompts = [f"<s>{p} <Answer/>" for p in prompts]
- prompt_inputs = tokenizer_k(prompts, add_special_tokens=False, return_tensors="pt")
- prompt_ids = prompt_inputs.input_ids.to(device)
- attention_mask = prompt_inputs.attention_mask.to(device)
-
+ # Dolphin-specific functions
+ def model_chat(prompt, image):
+ """Use Dolphin model for inference."""
+ processor = processor_k
+ model = model_k
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ inputs = processor(image, return_tensors="pt").to(device)
+ pixel_values = inputs.pixel_values.half()
+ prompt_inputs = processor.tokenizer(
+ f"<s>{prompt} <Answer/>",
+ add_special_tokens=False,
+ return_tensors="pt"
+ ).to(device)
  outputs = model.generate(
  pixel_values=pixel_values,
- decoder_input_ids=prompt_ids,
- decoder_attention_mask=attention_mask,
+ decoder_input_ids=prompt_inputs.input_ids,
+ decoder_attention_mask=prompt_inputs.attention_mask,
+ min_length=1,
  max_length=4096,
- pad_token_id=tokenizer_k.pad_token_id,
- eos_token_id=tokenizer_k.eos_token_id,
+ pad_token_id=processor.tokenizer.pad_token_id,
+ eos_token_id=processor.tokenizer.eos_token_id,
  use_cache=True,
- bad_words_ids=[[tokenizer_k.unk_token_id]],
+ bad_words_ids=[[processor.tokenizer.unk_token_id]],
  return_dict_in_generate=True,
+ do_sample=False,
+ num_beams=1,
+ repetition_penalty=1.1
  )
+ sequence = processor.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)[0]
+ cleaned = sequence.replace(f"<s>{prompt} <Answer/>", "").replace("<pad>", "").replace("</s>", "").strip()
+ return cleaned
+
+ def process_elements(layout_results, image):
+ """Parse layout results and extract elements from the image."""
+ # Placeholder parsing logic based on expected Dolphin output
+ # Assuming layout_results is a string like "[(x1,y1,x2,y2,label), ...]"
+ try:
+ elements = ast.literal_eval(layout_results)
+ except:
+ elements = [] # Fallback if parsing fails

- sequences = tokenizer_k.batch_decode(outputs.sequences, skip_special_tokens=False)
+ recognition_results = []
+ reading_order = 0

- results = []
- for i, seq in enumerate(sequences):
- cleaned = seq.replace(prompts[i], "").replace("<pad>", "").replace("</s>", "").strip()
- results.append(cleaned)
-
- return results[0] if not is_batch else results
-
- @spaces.GPU
- def process_element_batch_dolphin(elements, prompt, model, processor, max_batch_size=16):
- """Processes a batch of cropped image elements with the same prompt."""
- results = []
- for i in range(0, len(elements), max_batch_size):
- batch_elements = elements[i:i+max_batch_size]
- crops_list = [elem["crop"] for elem in batch_elements]
- prompts_list = [prompt] * len(crops_list)
-
- batch_results = dolphin_model_chat(model, processor, prompts_list, crops_list)
-
- for j, result in enumerate(batch_results):
- elem = batch_elements[j]
- results.append({
- "label": elem["label"],
- "bbox": elem["bbox"],
- "text": result.strip(),
- "reading_order": elem["reading_order"],
- })
- return results
-
- @spaces.GPU
- def run_dolphin_image_pipeline(pil_image, model, processor):
- """Runs the full two-stage pipeline for a single image."""
- try:
- # Stage 1: Layout Analysis
- print("Dolphin: Running layout analysis...")
- layout_output = dolphin_model_chat(model, processor, "Parse the reading order of this document.", pil_image)
-
- # Stage 2: Element Recognition
- print("Dolphin: Parsing layout and processing elements...")
- padded_image, dims = prepare_image_dolphin(pil_image)
- layout_results = parse_layout_string_dolphin(layout_output)
-
- text_elements, table_elements, figure_results = [], [], []
- previous_box = None
- reading_order = 0
-
- for bbox, label in layout_results:
- try:
- x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = process_coordinates_dolphin(
- bbox, padded_image, dims, previous_box
- )
- cropped = padded_image[y1:y2, x1:x2]
-
- if cropped.size > 0:
- pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
- element_info = {"crop": pil_crop, "label": label, "bbox": [orig_x1, orig_y1, orig_x2, orig_y2], "reading_order": reading_order}
-
- if label == "fig":
- buffered = io.BytesIO()
- pil_crop.save(buffered, format="PNG")
- img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
- figure_results.append({"label": label, "bbox": element_info["bbox"], "text": img_base64, "reading_order": reading_order})
- elif label == "tab":
- table_elements.append(element_info)
- else:
- text_elements.append(element_info)
- reading_order += 1
- except Exception as e:
- print(f"Dolphin: Error processing element with label {label}: {e}")
- continue
-
- recognition_results = figure_results.copy()
- if text_elements:
- print(f"Dolphin: Recognizing {len(text_elements)} text element(s)...")
- recognition_results.extend(process_element_batch_dolphin(text_elements, "Read text in the image.", model, processor))
- if table_elements:
- print(f"Dolphin: Parsing {len(table_elements)} table(s)...")
- recognition_results.extend(process_element_batch_dolphin(table_elements, "Parse the table in the image.", model, processor))
-
- recognition_results.sort(key=lambda x: x.get("reading_order", 0))
-
- # Stage 3: Generate Markdown
- print("Dolphin: Generating final Markdown output...")
- converter = MarkdownConverter()
- markdown_output = converter.convert(recognition_results)
- return f"**Markdown Output (from Dolphin):**\n\n{markdown_output}"
- except Exception as e:
- print(f"Error during Dolphin pipeline: {e}")
- return f"An error occurred during the Dolphin processing pipeline: {e}"
-
- # ------------------- End of Dolphin Specific Functions ------------------- #
-
+ for bbox, label in elements:
+ try:
+ x1, y1, x2, y2 = map(int, bbox)
+ cropped = image.crop((x1, y1, x2, y2))
+ if cropped.size[0] > 0 and cropped.size[1] > 0:
+ if label == "text":
+ text = model_chat("Read text in the image.", cropped)
+ recognition_results.append({
+ "label": label,
+ "bbox": [x1, y1, x2, y2],
+ "text": text.strip(),
+ "reading_order": reading_order
+ })
+ elif label == "table":
+ table_text = model_chat("Parse the table in the image.", cropped)
+ recognition_results.append({
+ "label": label,
+ "bbox": [x1, y1, x2, y2],
+ "text": table_text.strip(),
+ "reading_order": reading_order
+ })
+ elif label == "figure":
+ recognition_results.append({
+ "label": label,
+ "bbox": [x1, y1, x2, y2],
+ "text": "[Figure]", # Placeholder for figure content
+ "reading_order": reading_order
+ })
+ reading_order += 1
+ except Exception as e:
+ print(f"Error processing element: {e}")
+ continue
+
+ return recognition_results
+
+ def generate_markdown(recognition_results):
+ """Generate markdown from extracted elements."""
+ markdown = ""
+ for element in sorted(recognition_results, key=lambda x: x["reading_order"]):
+ if element["label"] == "text":
+ markdown += f"{element['text']}\n\n"
+ elif element["label"] == "table":
+ markdown += f"**Table:**\n{element['text']}\n\n"
+ elif element["label"] == "figure":
+ markdown += f"{element['text']}\n\n"
+ return markdown.strip()
+
+ def process_image_with_dolphin(image):
+ """Process a single image with Dolphin model."""
+ layout_output = model_chat("Parse the reading order of this document.", image)
+ elements = process_elements(layout_output, image)
+ markdown_content = generate_markdown(elements)
+ return markdown_content

  @spaces.GPU
  def generate_image(model_name: str, text: str, image: Image.Image,
@@ -334,62 +230,82 @@ def generate_image(model_name: str, text: str, image: Image.Image,
  top_k: int = 50,
  repetition_penalty: float = 1.2):
  """Generate responses for image input using the selected model."""
- if image is None:
- yield "Please upload an image."
- return
-
- # --- Dolphin Specific Path (Non-streaming, multi-stage) ---
  if model_name == "ByteDance-s-Dolphin":
- yield run_dolphin_image_pipeline(image, model_k, processor_k)
- return
-
- # --- Generic Path for Other Models (Streaming) ---
- if model_name == "Nanonets-OCR-s":
- processor, model = processor_m, model_m
- elif model_name == "MonkeyOCR-Recognition":
- processor, model = processor_g, model_g
- elif model_name == "SmolDocling-256M-preview":
- processor, model = processor_x, model_x
+ if image is None:
+ yield "Please upload an image."
+ return
+ markdown_content = process_image_with_dolphin(image)
+ yield markdown_content
  else:
- yield "Invalid model selected."
- return
-
- images = [image]
- if model_name == "SmolDocling-256M-preview":
- if "OTSL" in text or "code" in text:
- images = [add_random_padding(img) for img in images]
- if "OCR at text at" in text or "Identify element" in text or "formula" in text:
- text = normalize_values(text, target_max=500)
-
- messages = [{"role": "user", "content": [{"type": "image"}] * len(images) + [{"type": "text", "text": text}]}]
- prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
- inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
-
- streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
- generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k, "repetition_penalty": repetition_penalty}
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
- thread.start()
-
- buffer = ""
- full_output = ""
- for new_text in streamer:
- full_output += new_text
- buffer += new_text.replace("<|im_end|>", "")
- yield buffer
-
- if model_name == "SmolDocling-256M-preview":
- cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
- if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
- if "<chart>" in cleaned_output:
- cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
- cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
- doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
- doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
- markdown_output = doc.export_to_markdown()
- yield f"**MD Output:**\n\n{markdown_output}"
+ # Existing logic for other models
+ if model_name == "Nanonets-OCR-s":
+ processor = processor_m
+ model = model_m
+ elif model_name == "MonkeyOCR-Recognition":
+ processor = processor_g
+ model = model_g
+ elif model_name == "SmolDocling-256M-preview":
+ processor = processor_x
+ model = model_x
  else:
- yield cleaned_output
+ yield "Invalid model selected."
+ return

+ if image is None:
+ yield "Please upload an image."
+ return
+
+ images = [image]
+
+ if model_name == "SmolDocling-256M-preview":
+ if "OTSL" in text or "code" in text:
+ images = [add_random_padding(img) for img in images]
+ if "OCR at text at" in text or "Identify element" in text or "formula" in text:
+ text = normalize_values(text, target_max=500)
+
+ messages = [
+ {
+ "role": "user",
+ "content": [{"type": "image"} for _ in images] + [
+ {"type": "text", "text": text}
+ ]
+ }
+ ]
+ prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+ inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
+
+ streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+ generation_kwargs = {
+ **inputs,
+ "streamer": streamer,
+ "max_new_tokens": max_new_tokens,
+ "temperature": temperature,
+ "top_p": top_p,
+ "top_k": top_k,
+ "repetition_penalty": repetition_penalty,
+ }
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
+ thread.start()
+
+ buffer = ""
+ full_output = ""
+ for new_text in streamer:
+ full_output += new_text
+ buffer += new_text.replace("<|im_end|>", "")
+ yield buffer
+
+ if model_name == "SmolDocling-256M-preview":
+ cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
+ if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
+ if "<chart>" in cleaned_output:
+ cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
+ cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
+ doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
+ doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
+ markdown_output = doc.export_to_markdown()
+ yield f"**MD Output:**\n\n{markdown_output}"
+ else:
+ yield cleaned_output

  @spaces.GPU
  def generate_video(model_name: str, text: str, video_path: str,
@@ -399,76 +315,88 @@ def generate_video(model_name: str, text: str, video_path: str,
  top_k: int = 50,
  repetition_penalty: float = 1.2):
  """Generate responses for video input using the selected model."""
- if video_path is None:
- yield "Please upload a video."
- return
-
- frames_with_ts = downsample_video(video_path)
- if not frames_with_ts:
- yield "Could not extract frames from the video."
- return
- images = [frame for frame, _ in frames_with_ts]
- timestamps = [ts for _, ts in frames_with_ts]
-
- # --- Dolphin Specific Path (Batch processing frames) ---
  if model_name == "ByteDance-s-Dolphin":
- if not text:
- yield "Please provide a query for the video analysis (e.g., 'Describe what you see')."
+ if video_path is None:
+ yield "Please upload a video."
  return
- prompts = [text] * len(images)
- yield "Analyzing video frames with Dolphin... (this may take a moment)"
- results = dolphin_model_chat(model_k, processor_k, prompts, images)
- full_output = "### Dolphin Video Analysis (per-frame)\n\n"
- for i, res in enumerate(results):
- full_output += f"**Frame at {timestamps[i]:.2f}s:**\n{res.strip()}\n\n---\n"
- yield full_output
- return
-
- # --- Generic Path for Other Models (Streaming) ---
- if model_name == "Nanonets-OCR-s":
- processor, model = processor_m, model_m
- elif model_name == "MonkeyOCR-Recognition":
- processor, model = processor_g, model_g
- elif model_name == "SmolDocling-256M-preview":
- processor, model = processor_x, model_x
+ frames = downsample_video(video_path)
+ markdown_contents = []
+ for frame, _ in frames:
+ markdown_content = process_image_with_dolphin(frame)
+ markdown_contents.append(markdown_content)
+ combined_markdown = "\n\n".join(markdown_contents)
+ yield combined_markdown
  else:
- yield "Invalid model selected."
- return
-
- if model_name == "SmolDocling-256M-preview":
- if "OTSL" in text or "code" in text:
- images = [add_random_padding(img) for img in images]
- if "OCR at text at" in text or "Identify element" in text or "formula" in text:
- text = normalize_values(text, target_max=500)
-
- messages = [{"role": "user", "content": [{"type": "image"}] * len(images) + [{"type": "text", "text": text}]}]
- prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
- inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
-
- streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
- generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k, "repetition_penalty": repetition_penalty}
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
- thread.start()
-
- buffer = ""
- full_output = ""
- for new_text in streamer:
- full_output += new_text
- buffer += new_text.replace("<|im_end|>", "")
- yield buffer
-
- if model_name == "SmolDocling-256M-preview":
- cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
- if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
- if "<chart>" in cleaned_output:
- cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
- cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
- doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
- doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
- markdown_output = doc.export_to_markdown()
- yield f"**MD Output:**\n\n{markdown_output}"
+ # Existing logic for other models
+ if model_name == "Nanonets-OCR-s":
+ processor = processor_m
+ model = model_m
+ elif model_name == "MonkeyOCR-Recognition":
+ processor = processor_g
+ model = model_g
+ elif model_name == "SmolDocling-256M-preview":
+ processor = processor_x
+ model = model_x
  else:
- yield cleaned_output
+ yield "Invalid model selected."
+ return
+
+ if video_path is None:
+ yield "Please upload a video."
+ return
+
+ frames = downsample_video(video_path)
+ images = [frame for frame, _ in frames]
+
+ if model_name == "SmolDocling-256M-preview":
+ if "OTSL" in text or "code" in text:
+ images = [add_random_padding(img) for img in images]
+ if "OCR at text at" in text or "Identify element" in text or "formula" in text:
+ text = normalize_values(text, target_max=500)
+
+ messages = [
+ {
+ "role": "user",
+ "content": [{"type": "image"} for _ in images] + [
+ {"type": "text", "text": text}
+ ]
+ }
+ ]
+ prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+ inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
+
+ streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+ generation_kwargs = {
+ **inputs,
+ "streamer": streamer,
+ "max_new_tokens": max_new_tokens,
+ "temperature": temperature,
+ "top_p": top_p,
+ "top_k": top_k,
+ "repetition_penalty": repetition_penalty,
+ }
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
+ thread.start()
+
+ buffer = ""
+ full_output = ""
+ for new_text in streamer:
+ full_output += new_text
+ buffer += new_text.replace("<|im_end|>", "")
+ yield buffer
+
+ if model_name == "SmolDocling-256M-preview":
+ cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
+ if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
+ if "<chart>" in cleaned_output:
+ cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
+ cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
+ doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
+ doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
+ markdown_output = doc.export_to_markdown()
+ yield f"**MD Output:**\n\n{markdown_output}"
+ else:
+ yield cleaned_output

  # Define examples for image and video inference
  image_examples = [
@@ -495,17 +423,11 @@ css = """
  # Create the Gradio Interface
  with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
  gr.Markdown("# **[Core OCR](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
- gr.Markdown("A multi-model OCR and Document AI interface. Select 'ByteDance-s-Dolphin' for advanced, two-stage document layout analysis on images.")
  with gr.Row():
  with gr.Column():
- model_choice = gr.Radio(
- choices=["Nanonets-OCR-s", "SmolDocling-256M-preview", "MonkeyOCR-Recognition", "ByteDance-s-Dolphin"],
- label="Select Model",
- value="Nanonets-OCR-s"
- )
  with gr.Tabs():
  with gr.TabItem("Image Inference"):
- image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here... For Dolphin, leave blank to run full document analysis or ask a question about the image.")
+ image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
  image_upload = gr.Image(type="pil", label="Image")
  image_submit = gr.Button("Submit", elem_classes="submit-btn")
  gr.Examples(
@@ -520,15 +442,20 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
  examples=video_examples,
  inputs=[video_query, video_upload]
  )
- with gr.Accordion("Advanced options (for streaming models)", open=False):
+ with gr.Accordion("Advanced options", open=False):
  max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
  temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
  top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
  top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
  repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
  with gr.Column():
- output = gr.Markdown(label="Output", interactive=False)
-
+ output = gr.Textbox(label="Output", interactive=False, lines=3, scale=2)
+ model_choice = gr.Radio(
+ choices=["Nanonets-OCR-s", "SmolDocling-256M-preview", "MonkeyOCR-Recognition", "ByteDance-s-Dolphin"],
+ label="Select Model",
+ value="Nanonets-OCR-s"
+ )
+
  image_submit.click(
  fn=generate_image,
  inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
@@ -541,4 +468,4 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
  )

  if __name__ == "__main__":
- demo.queue(max_size=30).launch(share=True)
+ demo.queue(max_size=30).launch(share=True, mcp_server=True, ssr_mode=False, show_error=True)