prithivMLmods committed
Commit 6e73997 · verified · 1 Parent(s): 098731a

Update app.py

Files changed (1)
  1. app.py +338 -265
app.py CHANGED
@@ -5,6 +5,12 @@ import json
5
  import time
6
  import asyncio
7
  from threading import Thread
 
8
 
9
  import gradio as gr
10
  import spaces
@@ -25,13 +31,9 @@ from transformers.image_utils import load_image
25
 
26
  from docling_core.types.doc import DoclingDocument, DocTagsDocument
27
 
28
- import re
29
- import ast
30
- import html
31
-
32
  # Constants for text generation
33
- MAX_MAX_NEW_TOKENS = 2048
34
- DEFAULT_MAX_NEW_TOKENS = 1024
35
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
36
 
37
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -45,15 +47,6 @@ model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
45
  torch_dtype=torch.float16
46
  ).to(device).eval()
47
 
48
- # Load ByteDance's Dolphin
49
- MODEL_ID_K = "ByteDance/Dolphin"
50
- processor_k = AutoProcessor.from_pretrained(MODEL_ID_K, trust_remote_code=True)
51
- model_k = VisionEncoderDecoderModel.from_pretrained(
52
- MODEL_ID_K,
53
- trust_remote_code=True,
54
- torch_dtype=torch.float16
55
- ).to(device).eval()
56
-
57
  # Load SmolDocling-256M-preview
58
  MODEL_ID_X = "ds4sd/SmolDocling-256M-preview"
59
  processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
@@ -78,6 +71,21 @@ model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained(
78
  torch_dtype=torch.float16
79
  ).to(device).eval()
80
 
81
  # Preprocessing functions for SmolDocling-256M
82
  def add_random_padding(image, min_percent=0.1, max_percent=0.10):
83
  """Add random padding to an image based on its size."""
@@ -112,7 +120,12 @@ def downsample_video(video_path):
112
  total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
113
  fps = vidcap.get(cv2.CAP_PROP_FPS)
114
  frames = []
115
- frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
 
116
  for i in frame_indices:
117
  vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
118
  success, image = vidcap.read()
@@ -124,103 +137,194 @@ def downsample_video(video_path):
124
  vidcap.release()
125
  return frames
126
 
127
- # Dolphin-specific functions
128
- def model_chat(prompt, image):
129
- """Use Dolphin model for inference."""
130
- processor = processor_k
131
- model = model_k
132
- device = "cuda" if torch.cuda.is_available() else "cpu"
133
- inputs = processor(image, return_tensors="pt").to(device)
134
- pixel_values = inputs.pixel_values.half()
135
- prompt_inputs = processor.tokenizer(
136
- f"<s>{prompt} <Answer/>",
137
- add_special_tokens=False,
138
- return_tensors="pt"
139
- ).to(device)
 
140
  outputs = model.generate(
141
  pixel_values=pixel_values,
142
- decoder_input_ids=prompt_inputs.input_ids,
143
- decoder_attention_mask=prompt_inputs.attention_mask,
144
- min_length=1,
145
  max_length=4096,
146
- pad_token_id=processor.tokenizer.pad_token_id,
147
- eos_token_id=processor.tokenizer.eos_token_id,
148
  use_cache=True,
149
- bad_words_ids=[[processor.tokenizer.unk_token_id]],
150
  return_dict_in_generate=True,
151
- do_sample=False,
152
- num_beams=1,
153
- repetition_penalty=1.1
154
  )
155
- sequence = processor.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)[0]
156
- cleaned = sequence.replace(f"<s>{prompt} <Answer/>", "").replace("<pad>", "").replace("</s>", "").strip()
157
- return cleaned
158
-
159
- def process_elements(layout_results, image):
160
- """Parse layout results and extract elements from the image."""
161
- # Placeholder parsing logic based on expected Dolphin output
162
- # Assuming layout_results is a string like "[(x1,y1,x2,y2,label), ...]"
163
- try:
164
- elements = ast.literal_eval(layout_results)
165
- except:
166
- elements = [] # Fallback if parsing fails
167
 
168
- recognition_results = []
169
- reading_order = 0
170
 
171
- for bbox, label in elements:
172
- try:
173
- x1, y1, x2, y2 = map(int, bbox)
174
- cropped = image.crop((x1, y1, x2, y2))
175
- if cropped.size[0] > 0 and cropped.size[1] > 0:
176
- if label == "text":
177
- text = model_chat("Read text in the image.", cropped)
178
- recognition_results.append({
179
- "label": label,
180
- "bbox": [x1, y1, x2, y2],
181
- "text": text.strip(),
182
- "reading_order": reading_order
183
- })
184
- elif label == "table":
185
- table_text = model_chat("Parse the table in the image.", cropped)
186
- recognition_results.append({
187
- "label": label,
188
- "bbox": [x1, y1, x2, y2],
189
- "text": table_text.strip(),
190
- "reading_order": reading_order
191
- })
192
- elif label == "figure":
193
- recognition_results.append({
194
- "label": label,
195
- "bbox": [x1, y1, x2, y2],
196
- "text": "[Figure]", # Placeholder for figure content
197
- "reading_order": reading_order
198
- })
199
- reading_order += 1
200
- except Exception as e:
201
- print(f"Error processing element: {e}")
202
- continue
203
-
204
- return recognition_results
205
-
206
- def generate_markdown(recognition_results):
207
- """Generate markdown from extracted elements."""
208
- markdown = ""
209
- for element in sorted(recognition_results, key=lambda x: x["reading_order"]):
210
- if element["label"] == "text":
211
- markdown += f"{element['text']}\n\n"
212
- elif element["label"] == "table":
213
- markdown += f"**Table:**\n{element['text']}\n\n"
214
- elif element["label"] == "figure":
215
- markdown += f"{element['text']}\n\n"
216
- return markdown.strip()
217
-
218
- def process_image_with_dolphin(image):
219
- """Process a single image with Dolphin model."""
220
- layout_output = model_chat("Parse the reading order of this document.", image)
221
- elements = process_elements(layout_output, image)
222
- markdown_content = generate_markdown(elements)
223
- return markdown_content
 
224
 
225
  @spaces.GPU
226
  def generate_image(model_name: str, text: str, image: Image.Image,
@@ -230,82 +334,62 @@ def generate_image(model_name: str, text: str, image: Image.Image,
230
  top_k: int = 50,
231
  repetition_penalty: float = 1.2):
232
  """Generate responses for image input using the selected model."""
 
233
  if model_name == "ByteDance-s-Dolphin":
234
- if image is None:
235
- yield "Please upload an image."
236
- return
237
- markdown_content = process_image_with_dolphin(image)
238
- yield markdown_content
 
239
  else:
240
- # Existing logic for other models
241
- if model_name == "Nanonets-OCR-s":
242
- processor = processor_m
243
- model = model_m
244
- elif model_name == "MonkeyOCR-Recognition":
245
- processor = processor_g
246
- model = model_g
247
- elif model_name == "SmolDocling-256M-preview":
248
- processor = processor_x
249
- model = model_x
 
250
  else:
251
- yield "Invalid model selected."
252
- return
253
 
254
- if image is None:
255
- yield "Please upload an image."
256
- return
257
-
258
- images = [image]
259
-
260
- if model_name == "SmolDocling-256M-preview":
261
- if "OTSL" in text or "code" in text:
262
- images = [add_random_padding(img) for img in images]
263
- if "OCR at text at" in text or "Identify element" in text or "formula" in text:
264
- text = normalize_values(text, target_max=500)
265
-
266
- messages = [
267
- {
268
- "role": "user",
269
- "content": [{"type": "image"} for _ in images] + [
270
- {"type": "text", "text": text}
271
- ]
272
- }
273
- ]
274
- prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
275
- inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
276
-
277
- streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
278
- generation_kwargs = {
279
- **inputs,
280
- "streamer": streamer,
281
- "max_new_tokens": max_new_tokens,
282
- "temperature": temperature,
283
- "top_p": top_p,
284
- "top_k": top_k,
285
- "repetition_penalty": repetition_penalty,
286
- }
287
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
288
- thread.start()
289
-
290
- buffer = ""
291
- full_output = ""
292
- for new_text in streamer:
293
- full_output += new_text
294
- buffer += new_text.replace("<|im_end|>", "")
295
- yield buffer
296
-
297
- if model_name == "SmolDocling-256M-preview":
298
- cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
299
- if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
300
- if "<chart>" in cleaned_output:
301
- cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
302
- cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
303
- doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
304
- doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
305
- markdown_output = doc.export_to_markdown()
306
- yield f"**MD Output:**\n\n{markdown_output}"
307
- else:
308
- yield cleaned_output
309
 
310
  @spaces.GPU
311
  def generate_video(model_name: str, text: str, video_path: str,
@@ -315,88 +399,76 @@ def generate_video(model_name: str, text: str, video_path: str,
315
  top_k: int = 50,
316
  repetition_penalty: float = 1.2):
317
  """Generate responses for video input using the selected model."""
 
318
  if model_name == "ByteDance-s-Dolphin":
319
- if video_path is None:
320
- yield "Please upload a video."
321
  return
322
- frames = downsample_video(video_path)
323
- markdown_contents = []
324
- for frame, _ in frames:
325
- markdown_content = process_image_with_dolphin(frame)
326
- markdown_contents.append(markdown_content)
327
- combined_markdown = "\n\n".join(markdown_contents)
328
- yield combined_markdown
 
329
  else:
330
- # Existing logic for other models
331
- if model_name == "Nanonets-OCR-s":
332
- processor = processor_m
333
- model = model_m
334
- elif model_name == "MonkeyOCR-Recognition":
335
- processor = processor_g
336
- model = model_g
337
- elif model_name == "SmolDocling-256M-preview":
338
- processor = processor_x
339
- model = model_x
 
340
  else:
341
- yield "Invalid model selected."
342
- return
343
-
344
- if video_path is None:
345
- yield "Please upload a video."
346
- return
347
-
348
- frames = downsample_video(video_path)
349
- images = [frame for frame, _ in frames]
350
-
351
- if model_name == "SmolDocling-256M-preview":
352
- if "OTSL" in text or "code" in text:
353
- images = [add_random_padding(img) for img in images]
354
- if "OCR at text at" in text or "Identify element" in text or "formula" in text:
355
- text = normalize_values(text, target_max=500)
356
-
357
- messages = [
358
- {
359
- "role": "user",
360
- "content": [{"type": "image"} for _ in images] + [
361
- {"type": "text", "text": text}
362
- ]
363
- }
364
- ]
365
- prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
366
- inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
367
-
368
- streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
369
- generation_kwargs = {
370
- **inputs,
371
- "streamer": streamer,
372
- "max_new_tokens": max_new_tokens,
373
- "temperature": temperature,
374
- "top_p": top_p,
375
- "top_k": top_k,
376
- "repetition_penalty": repetition_penalty,
377
- }
378
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
379
- thread.start()
380
-
381
- buffer = ""
382
- full_output = ""
383
- for new_text in streamer:
384
- full_output += new_text
385
- buffer += new_text.replace("<|im_end|>", "")
386
- yield buffer
387
-
388
- if model_name == "SmolDocling-256M-preview":
389
- cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
390
- if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
391
- if "<chart>" in cleaned_output:
392
- cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
393
- cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
394
- doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
395
- doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
396
- markdown_output = doc.export_to_markdown()
397
- yield f"**MD Output:**\n\n{markdown_output}"
398
- else:
399
- yield cleaned_output
400
 
401
  # Define examples for image and video inference
402
  image_examples = [
@@ -423,11 +495,17 @@ css = """
423
  # Create the Gradio Interface
424
  with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
425
  gr.Markdown("# **[Core OCR](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
 
426
  with gr.Row():
427
  with gr.Column():
 
428
  with gr.Tabs():
429
  with gr.TabItem("Image Inference"):
430
- image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
431
  image_upload = gr.Image(type="pil", label="Image")
432
  image_submit = gr.Button("Submit", elem_classes="submit-btn")
433
  gr.Examples(
@@ -442,20 +520,15 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
442
  examples=video_examples,
443
  inputs=[video_query, video_upload]
444
  )
445
- with gr.Accordion("Advanced options", open=False):
446
  max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
447
  temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
448
  top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
449
  top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
450
  repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
451
  with gr.Column():
452
- output = gr.Textbox(label="Output", interactive=False, lines=3, scale=2)
453
- model_choice = gr.Radio(
454
- choices=["Nanonets-OCR-s", "SmolDocling-256M-preview", "MonkeyOCR-Recognition", "ByteDance-s-Dolphin"],
455
- label="Select Model",
456
- value="Nanonets-OCR-s"
457
- )
458
-
459
  image_submit.click(
460
  fn=generate_image,
461
  inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
@@ -468,4 +541,4 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
468
  )
469
 
470
  if __name__ == "__main__":
471
- demo.queue(max_size=30).launch(share=True, mcp_server=True, ssr_mode=False, show_error=True)
 
5
  import time
6
  import asyncio
7
  from threading import Thread
8
+ import io
9
+ import base64
10
+ import re
11
+ import ast
12
+ import html
13
+ from collections import namedtuple
14
 
15
  import gradio as gr
16
  import spaces
 
31
 
32
  from docling_core.types.doc import DoclingDocument, DocTagsDocument
33
 
34
  # Constants for text generation
35
+ MAX_MAX_NEW_TOKENS = 4096
36
+ DEFAULT_MAX_NEW_TOKENS = 2048
37
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
38
 
39
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
47
  torch_dtype=torch.float16
48
  ).to(device).eval()
49
 
50
  # Load SmolDocling-256M-preview
51
  MODEL_ID_X = "ds4sd/SmolDocling-256M-preview"
52
  processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
 
71
  torch_dtype=torch.float16
72
  ).to(device).eval()
73
 
74
+ #------------------------------------------------#
75
+ # Load ByteDance's Dolphin (with specific implementation)
76
+ print("Loading ByteDance/Dolphin model...")
77
+ MODEL_ID_K = "ByteDance/Dolphin"
78
+ processor_k = AutoProcessor.from_pretrained(MODEL_ID_K)
79
+ model_k = VisionEncoderDecoderModel.from_pretrained(MODEL_ID_K)
80
+ model_k.eval()
81
+ model_k.to(device)
82
+ if torch.cuda.is_available():
83
+ model_k = model_k.half() # Use half-precision on GPU
84
+ tokenizer_k = processor_k.tokenizer
85
+ print("ByteDance/Dolphin model loaded.")
86
+ #------------------------------------------------#
87
+
88
+
89
  # Preprocessing functions for SmolDocling-256M
90
  def add_random_padding(image, min_percent=0.1, max_percent=0.10):
91
  """Add random padding to an image based on its size."""
 
120
  total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
121
  fps = vidcap.get(cv2.CAP_PROP_FPS)
122
  frames = []
123
+ # Take up to 10 frames
124
+ num_frames_to_sample = min(10, total_frames)
125
+ if num_frames_to_sample == 0:
126
+ return []
127
+ frame_indices = np.linspace(0, total_frames - 1, num_frames_to_sample, dtype=int)
128
+
129
  for i in frame_indices:
130
  vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
131
  success, image = vidcap.read()
 
137
  vidcap.release()
138
  return frames
139
 
140
+ # ------------------- Dolphin Model Specific Helper Functions ------------------- #
141
+
142
+ ImageDimensions = namedtuple("ImageDimensions", ["width", "height", "new_w", "new_h", "pad_w", "pad_h"])
143
+
144
+ class MarkdownConverter:
145
+ """Converts structured recognition results to a Markdown string."""
146
+ def convert(self, elements):
147
+ markdown_str = ""
148
+ for elem in elements:
149
+ label = elem["label"]
150
+ text = elem["text"]
151
+ if label == "fig":
152
+ # Embed image as base64
153
+ markdown_str += f"![figure](data:image/png;base64,{text})\n\n"
154
+ elif label == "tab":
155
+ markdown_str += f"### Table\n\n{text}\n\n"
156
+ else: # text, title, head, foot, etc.
157
+ markdown_str += f"{text}\n\n"
158
+ return markdown_str.strip()
159
+
160
+ def prepare_image_dolphin(pil_image, target_size=1024):
161
+ """Pads a PIL image to a square, returning a cv2 image and dimensions."""
162
+ image = np.array(pil_image.convert('RGB'))
163
+ h, w, _ = image.shape
164
+ if h > w:
165
+ new_h, new_w = target_size, int(w * target_size / h)
166
+ else:
167
+ new_h, new_w = int(h * target_size / w), target_size
168
+
169
+ resized_image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
170
+
171
+ pad_w = (target_size - new_w) // 2
172
+ pad_h = (target_size - new_h) // 2
173
+
174
+ padded_image = np.pad(resized_image, ((pad_h, pad_h), (pad_w, pad_w), (0, 0)), 'constant', constant_values=255)
175
+ dims = ImageDimensions(w, h, new_w, new_h, pad_w, pad_h)
176
+
177
+ return padded_image, dims
178
+
179
+ def parse_layout_string_dolphin(layout_string):
180
+ """Parses the model's layout string into a list of (bbox, label) tuples."""
181
+ pattern = r'([a-zA-Z_]+)\(((?:\d+,){3}\d+)\)'
182
+ matches = re.findall(pattern, layout_string)
183
+ results = []
184
+ for label, coords_str in matches:
185
+ coords = tuple(map(int, coords_str.split(',')))
186
+ results.append((coords, label))
187
+ return results
188
+
189
+ def process_coordinates_dolphin(bbox, padded_image, dims, previous_box):
190
+ """Converts relative bbox coordinates to absolute pixel coordinates for cropping."""
191
+ x1, y1, x2, y2 = bbox
192
+
193
+ orig_x1 = int(x1 / 1024 * dims.new_w)
194
+ orig_y1 = int(y1 / 1024 * dims.new_h)
195
+ orig_x2 = int(x2 / 1024 * dims.new_w)
196
+ orig_y2 = int(y2 / 1024 * dims.new_h)
197
+
198
+ x1 = orig_x1 + dims.pad_w
199
+ y1 = orig_y1 + dims.pad_h
200
+ x2 = orig_x2 + dims.pad_w
201
+ y2 = orig_y2 + dims.pad_h
202
+
203
+ return x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, bbox
204
+
205
+ @spaces.GPU
206
+ def dolphin_model_chat(model, processor, prompt, image):
207
+ """Core inference function for the Dolphin model, supports batching."""
208
+ is_batch = isinstance(image, list)
209
+
210
+ images = image if is_batch else [image]
211
+ prompts = prompt if isinstance(prompt, list) else [prompt] * len(images)
212
+
213
+ batch_inputs = processor(images, return_tensors="pt", padding=True)
214
+ pixel_values = batch_inputs.pixel_values.to(device)
215
+ if torch.cuda.is_available():
216
+ pixel_values = pixel_values.half()
217
+
218
+ prompts = [f"<s>{p} <Answer/>" for p in prompts]
219
+ prompt_inputs = tokenizer_k(prompts, add_special_tokens=False, return_tensors="pt")
220
+ prompt_ids = prompt_inputs.input_ids.to(device)
221
+ attention_mask = prompt_inputs.attention_mask.to(device)
222
+
223
  outputs = model.generate(
224
  pixel_values=pixel_values,
225
+ decoder_input_ids=prompt_ids,
226
+ decoder_attention_mask=attention_mask,
 
227
  max_length=4096,
228
+ pad_token_id=tokenizer_k.pad_token_id,
229
+ eos_token_id=tokenizer_k.eos_token_id,
230
  use_cache=True,
231
+ bad_words_ids=[[tokenizer_k.unk_token_id]],
232
  return_dict_in_generate=True,
 
 
 
233
  )
 
234
 
235
+ sequences = tokenizer_k.batch_decode(outputs.sequences, skip_special_tokens=False)
 
236
 
237
+ results = []
238
+ for i, seq in enumerate(sequences):
239
+ cleaned = seq.replace(prompts[i], "").replace("<pad>", "").replace("</s>", "").strip()
240
+ results.append(cleaned)
241
+
242
+ return results[0] if not is_batch else results
243
+
244
+ @spaces.GPU
245
+ def process_element_batch_dolphin(elements, prompt, model, processor, max_batch_size=16):
246
+ """Processes a batch of cropped image elements with the same prompt."""
247
+ results = []
248
+ for i in range(0, len(elements), max_batch_size):
249
+ batch_elements = elements[i:i+max_batch_size]
250
+ crops_list = [elem["crop"] for elem in batch_elements]
251
+ prompts_list = [prompt] * len(crops_list)
252
+
253
+ batch_results = dolphin_model_chat(model, processor, prompts_list, crops_list)
254
+
255
+ for j, result in enumerate(batch_results):
256
+ elem = batch_elements[j]
257
+ results.append({
258
+ "label": elem["label"],
259
+ "bbox": elem["bbox"],
260
+ "text": result.strip(),
261
+ "reading_order": elem["reading_order"],
262
+ })
263
+ return results
264
+
265
+ @spaces.GPU
266
+ def run_dolphin_image_pipeline(pil_image, model, processor):
267
+ """Runs the full two-stage pipeline for a single image."""
268
+ try:
269
+ # Stage 1: Layout Analysis
270
+ print("Dolphin: Running layout analysis...")
271
+ layout_output = dolphin_model_chat(model, processor, "Parse the reading order of this document.", pil_image)
272
+
273
+ # Stage 2: Element Recognition
274
+ print("Dolphin: Parsing layout and processing elements...")
275
+ padded_image, dims = prepare_image_dolphin(pil_image)
276
+ layout_results = parse_layout_string_dolphin(layout_output)
277
+
278
+ text_elements, table_elements, figure_results = [], [], []
279
+ previous_box = None
280
+ reading_order = 0
281
+
282
+ for bbox, label in layout_results:
283
+ try:
284
+ x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = process_coordinates_dolphin(
285
+ bbox, padded_image, dims, previous_box
286
+ )
287
+ cropped = padded_image[y1:y2, x1:x2]
288
+
289
+ if cropped.size > 0:
290
+ pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
291
+ element_info = {"crop": pil_crop, "label": label, "bbox": [orig_x1, orig_y1, orig_x2, orig_y2], "reading_order": reading_order}
292
+
293
+ if label == "fig":
294
+ buffered = io.BytesIO()
295
+ pil_crop.save(buffered, format="PNG")
296
+ img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
297
+ figure_results.append({"label": label, "bbox": element_info["bbox"], "text": img_base64, "reading_order": reading_order})
298
+ elif label == "tab":
299
+ table_elements.append(element_info)
300
+ else:
301
+ text_elements.append(element_info)
302
+ reading_order += 1
303
+ except Exception as e:
304
+ print(f"Dolphin: Error processing element with label {label}: {e}")
305
+ continue
306
+
307
+ recognition_results = figure_results.copy()
308
+ if text_elements:
309
+ print(f"Dolphin: Recognizing {len(text_elements)} text element(s)...")
310
+ recognition_results.extend(process_element_batch_dolphin(text_elements, "Read text in the image.", model, processor))
311
+ if table_elements:
312
+ print(f"Dolphin: Parsing {len(table_elements)} table(s)...")
313
+ recognition_results.extend(process_element_batch_dolphin(table_elements, "Parse the table in the image.", model, processor))
314
+
315
+ recognition_results.sort(key=lambda x: x.get("reading_order", 0))
316
+
317
+ # Stage 3: Generate Markdown
318
+ print("Dolphin: Generating final Markdown output...")
319
+ converter = MarkdownConverter()
320
+ markdown_output = converter.convert(recognition_results)
321
+ return f"**Markdown Output (from Dolphin):**\n\n{markdown_output}"
322
+ except Exception as e:
323
+ print(f"Error during Dolphin pipeline: {e}")
324
+ return f"An error occurred during the Dolphin processing pipeline: {e}"
325
+
326
+ # ------------------- End of Dolphin Specific Functions ------------------- #
327
+
328
 
329
  @spaces.GPU
330
  def generate_image(model_name: str, text: str, image: Image.Image,
 
334
  top_k: int = 50,
335
  repetition_penalty: float = 1.2):
336
  """Generate responses for image input using the selected model."""
337
+ if image is None:
338
+ yield "Please upload an image."
339
+ return
340
+
341
+ # --- Dolphin Specific Path (Non-streaming, multi-stage) ---
342
  if model_name == "ByteDance-s-Dolphin":
343
+ yield run_dolphin_image_pipeline(image, model_k, processor_k)
344
+ return
345
+
346
+ # --- Generic Path for Other Models (Streaming) ---
347
+ if model_name == "Nanonets-OCR-s":
348
+ processor, model = processor_m, model_m
349
+ elif model_name == "MonkeyOCR-Recognition":
350
+ processor, model = processor_g, model_g
351
+ elif model_name == "SmolDocling-256M-preview":
352
+ processor, model = processor_x, model_x
353
  else:
354
+ yield "Invalid model selected."
355
+ return
356
+
357
+ images = [image]
358
+ if model_name == "SmolDocling-256M-preview":
359
+ if "OTSL" in text or "code" in text:
360
+ images = [add_random_padding(img) for img in images]
361
+ if "OCR at text at" in text or "Identify element" in text or "formula" in text:
362
+ text = normalize_values(text, target_max=500)
363
+
364
+ messages = [{"role": "user", "content": [{"type": "image"}] * len(images) + [{"type": "text", "text": text}]}]
365
+ prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
366
+ inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
367
+
368
+ streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
369
+ generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k, "repetition_penalty": repetition_penalty}
370
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
371
+ thread.start()
372
+
373
+ buffer = ""
374
+ full_output = ""
375
+ for new_text in streamer:
376
+ full_output += new_text
377
+ buffer += new_text.replace("<|im_end|>", "")
378
+ yield buffer
379
+
380
+ if model_name == "SmolDocling-256M-preview":
381
+ cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
382
+ if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
383
+ if "<chart>" in cleaned_output:
384
+ cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
385
+ cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
386
+ doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
387
+ doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
388
+ markdown_output = doc.export_to_markdown()
389
+ yield f"**MD Output:**\n\n{markdown_output}"
390
  else:
391
+ yield cleaned_output
 
392
 
393
 
394
  @spaces.GPU
395
  def generate_video(model_name: str, text: str, video_path: str,
 
399
  top_k: int = 50,
400
  repetition_penalty: float = 1.2):
401
  """Generate responses for video input using the selected model."""
402
+ if video_path is None:
403
+ yield "Please upload a video."
404
+ return
405
+
406
+ frames_with_ts = downsample_video(video_path)
407
+ if not frames_with_ts:
408
+ yield "Could not extract frames from the video."
409
+ return
410
+ images = [frame for frame, _ in frames_with_ts]
411
+ timestamps = [ts for _, ts in frames_with_ts]
412
+
413
+ # --- Dolphin Specific Path (Batch processing frames) ---
414
  if model_name == "ByteDance-s-Dolphin":
415
+ if not text:
416
+ yield "Please provide a query for the video analysis (e.g., 'Describe what you see')."
417
  return
418
+ prompts = [text] * len(images)
419
+ yield "Analyzing video frames with Dolphin... (this may take a moment)"
420
+ results = dolphin_model_chat(model_k, processor_k, prompts, images)
421
+ full_output = "### Dolphin Video Analysis (per-frame)\n\n"
422
+ for i, res in enumerate(results):
423
+ full_output += f"**Frame at {timestamps[i]:.2f}s:**\n{res.strip()}\n\n---\n"
424
+ yield full_output
425
+ return
426
+
427
+ # --- Generic Path for Other Models (Streaming) ---
428
+ if model_name == "Nanonets-OCR-s":
429
+ processor, model = processor_m, model_m
430
+ elif model_name == "MonkeyOCR-Recognition":
431
+ processor, model = processor_g, model_g
432
+ elif model_name == "SmolDocling-256M-preview":
433
+ processor, model = processor_x, model_x
434
  else:
435
+ yield "Invalid model selected."
436
+ return
437
+
438
+ if model_name == "SmolDocling-256M-preview":
439
+ if "OTSL" in text or "code" in text:
440
+ images = [add_random_padding(img) for img in images]
441
+ if "OCR at text at" in text or "Identify element" in text or "formula" in text:
442
+ text = normalize_values(text, target_max=500)
443
+
444
+ messages = [{"role": "user", "content": [{"type": "image"}] * len(images) + [{"type": "text", "text": text}]}]
445
+ prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
446
+ inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
447
+
448
+ streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
449
+ generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k, "repetition_penalty": repetition_penalty}
450
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
451
+ thread.start()
452
+
453
+ buffer = ""
454
+ full_output = ""
455
+ for new_text in streamer:
456
+ full_output += new_text
457
+ buffer += new_text.replace("<|im_end|>", "")
458
+ yield buffer
459
+
460
+ if model_name == "SmolDocling-256M-preview":
461
+ cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
462
+ if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
463
+ if "<chart>" in cleaned_output:
464
+ cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
465
+ cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
466
+ doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
467
+ doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
468
+ markdown_output = doc.export_to_markdown()
469
+ yield f"**MD Output:**\n\n{markdown_output}"
470
  else:
471
+ yield cleaned_output
 
472
 
473
  # Define examples for image and video inference
474
  image_examples = [
 
495
  # Create the Gradio Interface
496
  with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
497
  gr.Markdown("# **[Core OCR](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
498
+ gr.Markdown("A multi-model OCR and Document AI interface. Select 'ByteDance-s-Dolphin' for advanced, two-stage document layout analysis on images.")
499
  with gr.Row():
500
  with gr.Column():
501
+ model_choice = gr.Radio(
502
+ choices=["Nanonets-OCR-s", "SmolDocling-256M-preview", "MonkeyOCR-Recognition", "ByteDance-s-Dolphin"],
503
+ label="Select Model",
504
+ value="Nanonets-OCR-s"
505
+ )
506
  with gr.Tabs():
507
  with gr.TabItem("Image Inference"):
508
+ image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here... For Dolphin, leave blank to run full document analysis or ask a question about the image.")
509
  image_upload = gr.Image(type="pil", label="Image")
510
  image_submit = gr.Button("Submit", elem_classes="submit-btn")
511
  gr.Examples(
 
520
  examples=video_examples,
521
  inputs=[video_query, video_upload]
522
  )
523
+ with gr.Accordion("Advanced options (for streaming models)", open=False):
524
  max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
525
  temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
526
  top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
527
  top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
528
  repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
529
  with gr.Column():
530
+ output = gr.Markdown(label="Output", interactive=False)
531
+
 
532
  image_submit.click(
533
  fn=generate_image,
534
  inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
 
541
  )
542
 
543
  if __name__ == "__main__":
544
+ demo.queue(max_size=30).launch(share=True)
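
For reference, a minimal sketch of invoking the new Dolphin path outside the Gradio UI, assuming app.py has been imported so its module-level globals (run_dolphin_image_pipeline, model_k, processor_k) are initialized as defined in the diff above; the input file name is hypothetical:

from PIL import Image
from app import run_dolphin_image_pipeline, model_k, processor_k

# Run the two-stage layout-analysis + element-recognition pipeline on one page image
page = Image.open("sample_page.png").convert("RGB")  # hypothetical input file
markdown = run_dolphin_image_pipeline(page, model_k, processor_k)
print(markdown)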