prithivMLmods committed
Commit b001080 · verified · 1 Parent(s): 129f25d

Update app.py

Files changed (1): app.py (+128 -307)
app.py CHANGED
@@ -18,387 +18,208 @@ from transformers import (
      AutoProcessor,
      TextIteratorStreamer,
  )
- from qwen_vl_utils import process_vision_info

- # Constants
- MIN_PIXELS = 3136
- MAX_PIXELS = 11289600
- IMAGE_FACTOR = 28
- MAX_INPUT_TOKEN_LENGTH = 2048
  device = "cuda" if torch.cuda.is_available() else "cpu"

- # Prompts
- prompt = """Please output the layout information from the image, including each layout element's bbox, its category, and the corresponding text content within the bbox.

  1. Bbox format: [x1, y1, x2, y2]
-
  2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
-
  3. Text Extraction & Formatting Rules:
-     - Picture: For the 'Picture' category, the text field should be omitted.
-     - Formula: Format its text as LaTeX.
-     - Table: Format its text as HTML.
-     - All Others (Text, Title, etc.): Format their text as Markdown.
-
  4. Constraints:
-     - The output text must be the original text from the image, with no translation.
      - All layout elements must be sorted according to human reading order.
-
- 5. Final Output: The entire output must be a single JSON object.
  """

- # Load models
  MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-062825"
  processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
  model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-     MODEL_ID_M,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
  ).to(device).eval()

  MODEL_ID_T = "prithivMLmods/Megalodon-OCR-Sync-0713"
  processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
  model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-     MODEL_ID_T,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
  ).to(device).eval()

  MODEL_ID_C = "nanonets/Nanonets-OCR-s"
  processor_c = AutoProcessor.from_pretrained(MODEL_ID_C, trust_remote_code=True)
  model_c = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-     MODEL_ID_C,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
  ).to(device).eval()

  MODEL_ID_G = "echo840/MonkeyOCR"
  SUBFOLDER = "Recognition"
  processor_g = AutoProcessor.from_pretrained(
-     MODEL_ID_G,
-     trust_remote_code=True,
-     subfolder=SUBFOLDER
  )
  model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-     MODEL_ID_G,
-     trust_remote_code=True,
-     subfolder=SUBFOLDER,
-     torch_dtype=torch.float16
  ).to(device).eval()

- # Utility functions
- def round_by_factor(number: int, factor: int) -> int:
-     return round(number / factor) * factor
-
- def smart_resize(
-     height: int,
-     width: int,
-     factor: int = 28,
-     min_pixels: int = 3136,
-     max_pixels: int = 11289600,
- ):
-     if max(height, width) / min(height, width) > 200:
-         raise ValueError(f"Aspect ratio too extreme: {max(height, width) / min(height, width)}")
-     h_bar = max(factor, round_by_factor(height, factor))
-     w_bar = max(factor, round_by_factor(width, factor))
-     if h_bar * w_bar > max_pixels:
-         beta = math.sqrt((height * width) / max_pixels)
-         h_bar = round_by_factor(height / beta, factor)
-         w_bar = round_by_factor(width / beta, factor)
-     elif h_bar * w_bar < min_pixels:
-         beta = math.sqrt(min_pixels / (height * width))
-         h_bar = round_by_factor(height * beta, factor)
-         w_bar = round_by_factor(width * beta, factor)
-     return h_bar, w_bar
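
A quick worked example of the resizing arithmetic, with made-up dimensions: a 1000x800 image (800,000 pixels) falls inside the [3136, 11289600] budget, so neither rescaling branch fires and only rounding to the factor of 28 applies.

# round_by_factor(1000, 28) = round(35.71...) * 28 = 36 * 28 = 1008
# round_by_factor(800, 28)  = round(28.57...) * 28 = 29 * 28 = 812
# smart_resize(1000, 800)  -> (1008, 812)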
-
- def fetch_image(image_input, min_pixels: int = None, max_pixels: int = None):
-     if isinstance(image_input, str):
-         if image_input.startswith(("http://", "https://")):
-             response = requests.get(image_input)
-             image = Image.open(BytesIO(response.content)).convert('RGB')
-         else:
-             image = Image.open(image_input).convert('RGB')
-     elif isinstance(image_input, Image.Image):
-         image = image_input.convert('RGB')
-     else:
-         raise ValueError(f"Invalid image input type: {type(image_input)}")
-     if min_pixels or max_pixels:
-         min_pixels = min_pixels or MIN_PIXELS
-         max_pixels = max_pixels or MAX_PIXELS
-         height, width = smart_resize(image.height, image.width, factor=IMAGE_FACTOR, min_pixels=min_pixels, max_pixels=max_pixels)
-         image = image.resize((width, height), Image.LANCZOS)
-     return image
-
- def is_arabic_text(text: str) -> bool:
-     if not text:
-         return False
-     header_pattern = r'^#{1,6}\s+(.+)$'
-     paragraph_pattern = r'^(?!#{1,6}\s|!\[|```|\||\s*[-*+]\s|\s*\d+\.\s)(.+)$'
-     content_text = []
-     for line in text.split('\n'):
-         line = line.strip()
-         if not line:
-             continue
-         header_match = re.match(header_pattern, line, re.MULTILINE)
-         if header_match:
-             content_text.append(header_match.group(1))
-             continue
-         if re.match(paragraph_pattern, line, re.MULTILINE):
-             content_text.append(line)
-     if not content_text:
-         return False
-     combined_text = ' '.join(content_text)
-     arabic_chars = 0
-     total_chars = 0
-     for char in combined_text:
-         if char.isalpha():
-             total_chars += 1
-             if ('\u0600' <= char <= '\u06FF') or ('\u0750' <= char <= '\u077F') or ('\u08A0' <= char <= '\u08FF'):
-                 arabic_chars += 1
-     return total_chars > 0 and (arabic_chars / total_chars) > 0.5
-
- def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = 'text') -> str:
-     import base64
-     from io import BytesIO
      markdown_lines = []
      try:
-         sorted_items = sorted(layout_data, key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0]))
          for item in sorted_items:
              category = item.get('category', '')
-             text = item.get(text_key, '')
-             bbox = item.get('bbox', [])
-             if category == 'Picture':
-                 if bbox and len(bbox) == 4:
-                     try:
-                         x1, y1, x2, y2 = bbox
-                         x1, y1 = max(0, int(x1)), max(0, int(y1))
-                         x2, y2 = min(image.width, int(x2)), min(image.height, int(y2))
-                         if x2 > x1 and y2 > y1:
-                             cropped_img = image.crop((x1, y1, x2, y2))
-                             buffer = BytesIO()
-                             cropped_img.save(buffer, format='PNG')
-                             img_data = base64.b64encode(buffer.getvalue()).decode()
-                             markdown_lines.append(f"![Image](data:image/png;base64,{img_data})\n")
-                         else:
-                             markdown_lines.append("![Image](Image region detected)\n")
-                     except Exception as e:
-                         print(f"Error processing image region: {e}")
-                         markdown_lines.append("![Image](Image detected)\n")
-                 else:
-                     markdown_lines.append("![Image](Image detected)\n")
-             elif not text:
-                 continue
-             elif category == 'Title':
-                 markdown_lines.append(f"# {text}\n")
-             elif category == 'Section-header':
-                 markdown_lines.append(f"## {text}\n")
-             elif category == 'Text':
-                 markdown_lines.append(f"{text}\n")
-             elif category == 'List-item':
-                 markdown_lines.append(f"- {text}\n")
              elif category == 'Table':
-                 if text.strip().startswith('<'):
                      markdown_lines.append(f"{text}\n")
-                 else:
-                     markdown_lines.append(f"**Table:** {text}\n")
-             elif category == 'Formula':
-                 if text.strip().startswith('$') or '\\' in text:
-                     markdown_lines.append(f"$$ \n{text}\n $$\n")
-                 else:
-                     markdown_lines.append(f"**Formula:** {text}\n")
-             elif category == 'Caption':
-                 markdown_lines.append(f"*{text}*\n")
-             elif category == 'Footnote':
-                 markdown_lines.append(f"^{text}^\n")
-             elif category in ['Page-header', 'Page-footer']:
-                 continue
              else:
-                 markdown_lines.append(f"{text}\n")
-         markdown_lines.append("")
      except Exception as e:
          print(f"Error converting to markdown: {e}")
-         return str(layout_data)
      return "\n".join(markdown_lines)

  @spaces.GPU
- def inference(model_name: str, image: Image.Image, text: str, max_new_tokens: int = 1024) -> str:
-     try:
-         if model_name == "Camel-Doc-OCR-062825":
-             processor = processor_m
-             model = model_m
-         elif model_name == "Megalodon-OCR-Sync-0713":
-             processor = processor_t
-             model = model_t
-         elif model_name == "Nanonets-OCR-s":
-             processor = processor_c
-             model = model_c
-         elif model_name == "MonkeyOCR-Recognition":
-             processor = processor_g
-             model = model_g
-         else:
-             raise ValueError(f"Invalid model selected: {model_name}")
-
-         if image is None:
-             yield "Please upload an image.", "Please upload an image."
-             return
-
-         messages = [{
-             "role": "user",
-             "content": [
-                 {"type": "image", "image": image},
-                 {"type": "text", "text": text},
-             ]
-         }]
-         prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-         inputs = processor(
-             text=[prompt_full],
-             images=[image],
-             return_tensors="pt",
-             padding=True,
-             truncation=False,
-             max_length=MAX_INPUT_TOKEN_LENGTH
-         ).to(device)
-         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-         generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
-         thread = Thread(target=model.generate, kwargs=generation_kwargs)
-         thread.start()
-         buffer = ""
-         for new_text in streamer:
-             buffer += new_text
-             buffer = buffer.replace("<|im_end|>", "")
-             time.sleep(0.01)
-             yield buffer, buffer
-     except Exception as e:
-         print(f"Error during inference: {e}")
-         traceback.print_exc()
-         yield f"Error during inference: {str(e)}", f"Error during inference: {str(e)}"
-
- def process_image(
-     model_name: str,
-     image: Image.Image,
-     min_pixels: Optional[int] = None,
-     max_pixels: Optional[int] = None,
-     max_new_tokens: int = 1024
- ):
-     try:
-         if min_pixels or max_pixels:
-             image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
-         buffer = ""
-         for raw_output, _ in inference(model_name, image, prompt, max_new_tokens):
-             buffer = raw_output
-             yield buffer, None  # Yield raw OCR stream and None for JSON during processing
          try:
-             json_match = re.search(r'```json\s*([\s\S]+?)\s*```', buffer)
-             json_str = json_match.group(1) if json_match else buffer
              layout_data = json.loads(json_str)
-             yield buffer, layout_data  # Final yield with raw OCR and parsed JSON
-         except json.JSONDecodeError:
-             print("Failed to parse JSON output, using raw output")
-             yield buffer, None  # If JSON parsing fails, yield raw OCR with no JSON
-     except Exception as e:
-         print(f"Error processing image: {e}")
-         traceback.print_exc()
-         yield f"Error processing image: {str(e)}", None
-
- def load_file_for_preview(file_path: str) -> Tuple[Optional[Image.Image], str]:
-     if not file_path or not os.path.exists(file_path):
-         return None, "No file selected"
-     file_ext = os.path.splitext(file_path)[1].lower()
-     try:
-         if file_ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
-             image = Image.open(file_path).convert('RGB')
-             return image, "Image loaded"
-         else:
-             return None, f"Unsupported file format: {file_ext}"
-     except Exception as e:
-         print(f"Error loading file: {e}")
-         return None, f"Error loading file: {str(e)}"
-
  def create_gradio_interface():
 
      css = """
      .main-container { max-width: 1400px; margin: 0 auto; }
-     .header-text { text-align: center; color: #2c3e50; margin-bottom: 20px; }
-     .process-button {
-         border: none !important;
-         color: white !important;
-         font-weight: bold !important;
-         background-color: blue !important;}
-     .process-button:hover {
-         background-color: darkblue !important;
-         transform: translateY(-2px) !important;
-         box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
-     .info-box { border: 1px solid #dee2e6; border-radius: 8px; padding: 15px; margin: 10px 0; }
-     .page-info { text-align: center; padding: 8px 16px; border-radius: 20px; font-weight: bold; margin: 10px 0; }
-     .model-status { padding: 10px; border-radius: 8px; margin: 10px 0; text-align: center; font-weight: bold; }
-     .status-ready { background: #d1edff; color: #0c5460; border: 1px solid #b8daff; }
      """
      with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
          gr.HTML("""
          <div class="title" style="text-align: center">
              <h1>Dot<span style="color: red;">●</span><strong></strong>OCR Comparator</h1>
              <p style="font-size: 1.1em; color: #6b7280; margin-bottom: 0.6em;">
-                 Advanced vision-language model for image to markdown document processing
              </p>
          </div>
          """)
          with gr.Row():
              with gr.Column(scale=1):
                  model_choice = gr.Radio(
                      choices=["Camel-Doc-OCR-062825", "MonkeyOCR-Recognition", "Nanonets-OCR-s", "Megalodon-OCR-Sync-0713"],
-                     label="Select Model",
-                     value="Camel-Doc-OCR-062825"
                  )
-                 file_input = gr.File(
-                     label="Upload Image",
-                     file_types=[".jpg", ".jpeg", ".png", ".bmp", ".tiff"],
-                     type="filepath"
                  )
-                 image_preview = gr.Image(label="Preview", type="pil", interactive=False, height=300)
                  with gr.Accordion("Advanced Settings", open=False):
-                     max_new_tokens = gr.Slider(minimum=1000, maximum=32000, value=24000, step=1000, label="Max New Tokens")
-                     min_pixels = gr.Number(value=MIN_PIXELS, label="Min Pixels")
-                     max_pixels = gr.Number(value=MAX_PIXELS, label="Max Pixels")
                  process_btn = gr.Button("🚀 Process Document", variant="primary", elem_classes=["process-button"], size="lg")
                  clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
              with gr.Column(scale=2):
-                 with gr.Tabs():
-                     with gr.Tab("📝 Extracted Content"):
-                         output = gr.Textbox(label="Raw OCR Stream", interactive=False, lines=10, show_copy_button=True)
-                     with gr.Tab("📋 Layout Analysis Results"):
-                         json_output = gr.JSON(label="Layout Analysis Results", value=None)
-         def process_document(model_name, file_path, max_tokens, min_pix, max_pix):
-             try:
-                 if not file_path:
-                     return "Please upload an image.", None
-                 image, status = load_file_for_preview(file_path)
-                 if image is None:
-                     return status, None
-                 for raw_output, layout_result in process_image(model_name, image, min_pixels=int(min_pix) if min_pix else None, max_pixels=int(max_pix) if max_pix else None, max_new_tokens=max_tokens):
-                     yield raw_output, layout_result
-             except Exception as e:
-                 error_msg = f"Error processing document: {str(e)}"
-                 print(error_msg)
-                 traceback.print_exc()
-                 yield error_msg, None
-         def handle_file_upload(file_path):
-             if not file_path:
-                 return None, "No file loaded"
-             image, page_info = load_file_for_preview(file_path)
-             return image, page_info
-         def clear_all():
-             return None, None, "No file loaded", None
-         file_input.change(handle_file_upload, inputs=[file_input], outputs=[image_preview, output])
          process_btn.click(
-             process_document,
-             inputs=[model_choice, file_input, max_new_tokens, min_pixels, max_pixels],
-             outputs=[output, json_output]
          )
          clear_btn.click(
-             clear_all,
-             outputs=[file_input, image_preview, output, json_output]
          )
      return demo

  if __name__ == "__main__":
      demo = create_gradio_interface()
-     demo.queue(max_size=10).launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True, show_error=True)
 
      AutoProcessor,
      TextIteratorStreamer,
  )

+ # --- Constants and Model Setup ---
+ MAX_INPUT_TOKEN_LENGTH = 4096

  device = "cuda" if torch.cuda.is_available() else "cpu"

+ # --- Prompts for Different Tasks ---
+ layout_prompt = """Please output the layout information from the image, including each layout element's bbox, its category, and the corresponding text content within the bbox.

  1. Bbox format: [x1, y1, x2, y2]

  2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].

  3. Text Extraction & Formatting Rules:
+     - For tables, provide the content in a structured JSON format.
+     - For all other elements, provide the plain text.

  4. Constraints:
+     - The output must be the original text from the image.
      - All layout elements must be sorted according to human reading order.
+ 5. Final Output: The entire output must be a single JSON object wrapped in ```json ... ```.

  """

+ ocr_prompt = "Perform precise OCR on the image. Extract all text content, maintaining the original structure, paragraphs, and tables as formatted markdown."
+
+ # --- Model Loading ---
  MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-062825"
  processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
  model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+     MODEL_ID_M, trust_remote_code=True, torch_dtype=torch.float16
  ).to(device).eval()

  MODEL_ID_T = "prithivMLmods/Megalodon-OCR-Sync-0713"
  processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
  model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+     MODEL_ID_T, trust_remote_code=True, torch_dtype=torch.float16
  ).to(device).eval()

  MODEL_ID_C = "nanonets/Nanonets-OCR-s"
  processor_c = AutoProcessor.from_pretrained(MODEL_ID_C, trust_remote_code=True)
  model_c = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+     MODEL_ID_C, trust_remote_code=True, torch_dtype=torch.float16
  ).to(device).eval()

  MODEL_ID_G = "echo840/MonkeyOCR"
  SUBFOLDER = "Recognition"
  processor_g = AutoProcessor.from_pretrained(
+     MODEL_ID_G, trust_remote_code=True, subfolder=SUBFOLDER
  )
  model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+     MODEL_ID_G, trust_remote_code=True, subfolder=SUBFOLDER, torch_dtype=torch.float16
  ).to(device).eval()

+ # --- Utility Functions ---
+ def layoutjson2md(layout_data: List[Dict]) -> str:
+     """Converts the structured JSON from Layout Analysis into formatted Markdown."""
      markdown_lines = []
      try:
+         # Sort items by reading order (top-to-bottom, left-to-right)
+         sorted_items = sorted(layout_data, key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0]))
          for item in sorted_items:
              category = item.get('category', '')
+             text = item.get('text', '')
+             if not text: continue
+
+             if category == 'Title': markdown_lines.append(f"# {text}\n")
+             elif category == 'Section-header': markdown_lines.append(f"## {text}\n")
              elif category == 'Table':
+                 # Handle structured table JSON
+                 if isinstance(text, dict) and 'header' in text and 'rows' in text:
+                     header = '| ' + ' | '.join(map(str, text['header'])) + ' |'
+                     separator = '| ' + ' | '.join(['---'] * len(text['header'])) + ' |'
+                     rows = ['| ' + ' | '.join(map(str, row)) + ' |' for row in text['rows']]
+                     markdown_lines.extend([header, separator] + rows)
+                     markdown_lines.append("\n")
+                 else:  # Fallback for simple text
                      markdown_lines.append(f"{text}\n")
              else:
+                 markdown_lines.append(f"{text}\n")
      except Exception as e:
          print(f"Error converting to markdown: {e}")
+         return "### Error converting JSON to Markdown."
      return "\n".join(markdown_lines)

+ # --- Core Application Logic ---
  @spaces.GPU
+ def process_document_stream(model_name: str, task_choice: str, image: Image.Image, max_new_tokens: int):
+     """
+     Main generator function that handles both OCR and Layout Analysis tasks.
+     """
+     if image is None:
+         yield "Please upload an image.", "Please upload an image.", None
+         return
+
+     # 1. Select prompt based on user's task choice
+     text_prompt = ocr_prompt if task_choice == "Content Extraction" else layout_prompt
+
+     # 2. Select model and processor
+     if model_name == "Camel-Doc-OCR-062825": processor, model = processor_m, model_m
+     elif model_name == "Megalodon-OCR-Sync-0713": processor, model = processor_t, model_t
+     elif model_name == "Nanonets-OCR-s": processor, model = processor_c, model_c
+     elif model_name == "MonkeyOCR-Recognition": processor, model = processor_g, model_g
+     else:
+         yield "Invalid model selected.", "Invalid model selected.", None
+         return
+
+     # 3. Prepare model inputs and streamer
+     messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": text_prompt}]}]
+     prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     inputs = processor(text=[prompt_full], images=[image], return_tensors="pt", padding=True, truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH).to(device)
+     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+     generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
+
+     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+     thread.start()
+
+     # 4. Stream raw output to the UI in real time
+     buffer = ""
+     for new_text in streamer:
+         buffer += new_text
+         buffer = buffer.replace("<|im_end|>", "")
+         time.sleep(0.01)
+         yield buffer, "⏳ Processing...", {"status": "streaming"}
+
+     # 5. Post-process the final buffer based on the selected task
+     if task_choice == "Content Extraction":
+         # For OCR, the buffer is the final result.
+         yield buffer, buffer, None
+     else:  # Layout Analysis
          try:
+             json_match = re.search(r'```json\s*([\s\S]+?)\s*```', buffer)
+             if not json_match:
+                 raise json.JSONDecodeError("JSON object not found in output.", buffer, 0)
+
+             json_str = json_match.group(1)
              layout_data = json.loads(json_str)
+             markdown_content = layoutjson2md(layout_data)
+
+             yield buffer, markdown_content, layout_data
+         except Exception as e:
+             error_md = f"❌ **Error:** Failed to parse Layout JSON.\n\n**Details:**\n`{str(e)}`"
+             error_json = {"error": "ProcessingError", "details": str(e), "raw_output": buffer}
+             yield buffer, error_md, error_json
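
To exercise the Layout Analysis post-processing in isolation, a minimal sketch with a made-up buffer (the regex is the same one used above):

import json
import re

buffer = '```json\n[{"bbox": [0, 0, 10, 10], "category": "Text", "text": "hi"}]\n```'
match = re.search(r'```json\s*([\s\S]+?)\s*```', buffer)
layout_data = json.loads(match.group(1)) if match else None
print(layout_data[0]["text"])  # -> hi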
+
+ # --- Gradio UI Definition ---
  def create_gradio_interface():
+     """Builds and returns the Gradio web interface."""
      css = """
      .main-container { max-width: 1400px; margin: 0 auto; }
+     .process-button { border: none !important; color: white !important; font-weight: bold !important; background-color: blue !important;}
+     .process-button:hover { background-color: darkblue !important; transform: translateY(-2px) !important; box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
      """
      with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
          gr.HTML("""
          <div class="title" style="text-align: center">
              <h1>Dot<span style="color: red;">●</span><strong></strong>OCR Comparator</h1>
              <p style="font-size: 1.1em; color: #6b7280; margin-bottom: 0.6em;">
+                 Advanced Vision-Language Model for Image Content and Layout Extraction
              </p>
          </div>
          """)
+
          with gr.Row():
+             # Left Column (Inputs)
              with gr.Column(scale=1):
                  model_choice = gr.Radio(
                      choices=["Camel-Doc-OCR-062825", "MonkeyOCR-Recognition", "Nanonets-OCR-s", "Megalodon-OCR-Sync-0713"],
+                     label="Select Model", value="Camel-Doc-OCR-062825"
                  )
+                 task_choice = gr.Radio(
+                     choices=["Content Extraction", "Layout Analysis"],
+                     label="Select Task", value="Content Extraction"
                  )
+                 image_input = gr.Image(label="Upload Image", type="pil", sources=['upload'])
                  with gr.Accordion("Advanced Settings", open=False):
+                     max_new_tokens = gr.Slider(minimum=512, maximum=8192, value=4096, step=256, label="Max New Tokens")
+
                  process_btn = gr.Button("🚀 Process Document", variant="primary", elem_classes=["process-button"], size="lg")
                  clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
+
+             # Right Column (Outputs)
              with gr.Column(scale=2):
+                 with gr.Tabs() as tabs:
+                     with gr.Tab("📝 Extracted Content", id=0):
+                         raw_output_stream = gr.Textbox(label="Raw Model Output Stream", interactive=False, lines=15, show_copy_button=True)
+                         with gr.Accordion("(Formatted Result)", open=True):
+                             markdown_output = gr.Markdown(label="Formatted Markdown")
+
+                     with gr.Tab("📋 Layout Analysis Results", id=1):
+                         json_output = gr.JSON(label="Structured Layout Data (JSON)")
+
+         # Event Handlers
+         def clear_all_outputs():
+             return None, "Raw output will appear here.", "Formatted results will appear here.", None
+
          process_btn.click(
+             fn=process_document_stream,
+             inputs=[model_choice, task_choice, image_input, max_new_tokens],
+             outputs=[raw_output_stream, markdown_output, json_output]
          )
          clear_btn.click(
+             clear_all_outputs,
+             outputs=[image_input, raw_output_stream, markdown_output, json_output]
          )
      return demo

  if __name__ == "__main__":
      demo = create_gradio_interface()
+     demo.queue().launch(server_name="0.0.0.0", server_port=7860, show_error=True)