trevorpfiz committed
Commit 4aa9a45 · Parent: ae9d014

fix: unexpected keyword argument 'file_name'

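The crash came from passing a file_name= keyword to gr.DownloadButton, which the constructor does not accept. The buttons are now created with only a label, and each download handler writes its content to a file under TMP_DIR and returns that path as the button's value. A minimal sketch of the pattern, assuming Gradio's DownloadButton takes a file path as value (the TMP_DIR constant and the gr.DownloadButton.update(...) return mirror the handlers in this diff; export_jsonl is an illustrative name, not a function in the repo):

import os

import gradio as gr

TMP_DIR = "/tmp/previewspace"  # same scratch directory the app already uses
os.makedirs(TMP_DIR, exist_ok=True)


def export_jsonl(content: str):
    # Persist the generated text to disk, then hand the file path to the
    # DownloadButton instead of configuring a file name at construction time.
    out_path = os.path.join(TMP_DIR, "results.jsonl")
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(content)
    return gr.DownloadButton.update(value=out_path)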
src/vlm_playground/app.py CHANGED
@@ -1,5 +1,7 @@
 from .preview_app import create_blocks_app
 
+# from .preview_app_local import create_blocks_app as create_blocks_app_local
+
 
 def run() -> None:
     demo = create_blocks_app()
src/vlm_playground/preview_app.py CHANGED
@@ -512,12 +512,8 @@ def create_blocks_app():
                         processed_view = gr.Image(type="pil", height=520)
 
                 with gr.Row():
-                    download_jsonl = gr.DownloadButton(
-                        label="Download JSONL", file_name="results.jsonl"
-                    )
-                    download_markdown = gr.DownloadButton(
-                        label="Download Markdown", file_name="results.md"
-                    )
+                    download_jsonl = gr.DownloadButton(label="Download JSONL")
+                    download_markdown = gr.DownloadButton(label="Download Markdown")
 
         # ===== Handlers =====
         def on_template_change(choice: str) -> str:
@@ -734,7 +730,10 @@ def create_blocks_app():
                     obj = {"page": i + 1, "layout": res["layout_result"]}
                     lines.append(json.dumps(obj, ensure_ascii=False))
             content = "\n".join(lines) if lines else ""
-            return gr.DownloadButton.update(value=content.encode("utf-8"))
+            out_path = os.path.join(TMP_DIR, "results.jsonl")
+            with open(out_path, "w", encoding="utf-8") as f:
+                f.write(content)
+            return gr.DownloadButton.update(value=out_path)
 
         def download_current_markdown(state: Dict[str, Any]):
             if not state.get("parsed"):
@@ -744,7 +743,10 @@ def create_blocks_app():
                 if res and res.get("markdown"):
                     chunks.append(f"## Page {i + 1}\n\n{res['markdown']}")
             content = "\n\n---\n\n".join(chunks) if chunks else ""
-            return gr.DownloadButton.update(value=content.encode("utf-8"))
+            out_path = os.path.join(TMP_DIR, "results.md")
+            with open(out_path, "w", encoding="utf-8") as f:
+                f.write(content)
+            return gr.DownloadButton.update(value=out_path)
 
         # Wire events
         template.change(on_template_change, inputs=[template], outputs=[prompt_text])
src/vlm_playground/preview_app_local.py ADDED
@@ -0,0 +1,786 @@
import gc
import types
import sys
import hashlib
import json
import math
import os
import re
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple

import fitz  # PyMuPDF
import gradio as gr
import requests
import torch
from huggingface_hub import snapshot_download
from PIL import Image, ImageDraw, ImageFont
from qwen_vl_utils import process_vision_info
from transformers import AutoModelForCausalLM, AutoProcessor

from .utils.constants import IMAGE_FACTOR, MAX_PIXELS, MIN_PIXELS
from .utils.prompts import dict_promptmode_to_prompt

APP_TITLE = "PreviewSpace — VLM Playground (Local)"
TMP_DIR = "/tmp/previewspace"
MODELS_DIR = os.path.join(TMP_DIR, "models")
DOTS_REPO_ID = "rednote-hilab/dots.ocr"
DOTS_LOCAL_DIR = os.path.join(MODELS_DIR, "dots.ocr")

LOCAL_DEFAULT_MAX_NEW_TOKENS = 2048

os.makedirs(TMP_DIR, exist_ok=True)
os.makedirs(MODELS_DIR, exist_ok=True)

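# Resize helpers: round_by_factor snaps a dimension to the nearest multiple of
# `factor`, and smart_resize rescales (height, width) so the total pixel count
# stays within [MIN_PIXELS, MAX_PIXELS] while both sides remain factor-aligned.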
def round_by_factor(number: int, factor: int) -> int:
    return round(number / factor) * factor


def smart_resize(
    height: int,
    width: int,
    factor: int = IMAGE_FACTOR,
    min_pixels: int = MIN_PIXELS,
    max_pixels: int = MAX_PIXELS,
) -> Tuple[int, int]:
    if max(height, width) / min(height, width) > 200:
        raise ValueError("absolute aspect ratio must be smaller than 200")
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))

    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = round_by_factor(height / beta, factor)
        w_bar = round_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = round_by_factor(height * beta, factor)
        w_bar = round_by_factor(width * beta, factor)
    return int(h_bar), int(w_bar)

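# Input loading: fetch_image accepts a URL, a local path, or a PIL image and
# normalizes it to RGB; load_images_from_pdf rasterizes each PDF page at 2x
# zoom via PyMuPDF.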
def fetch_image(image_input: Any) -> Image.Image:
    if isinstance(image_input, str):
        if image_input.startswith(("http://", "https://")):
            response = requests.get(image_input, timeout=60)
            image = Image.open(BytesIO(response.content)).convert("RGB")
        else:
            image = Image.open(image_input).convert("RGB")
    elif isinstance(image_input, Image.Image):
        image = image_input.convert("RGB")
    else:
        raise ValueError(f"Invalid image input type: {type(image_input)}")
    return image


def load_images_from_pdf(pdf_path: str) -> List[Image.Image]:
    images: List[Image.Image] = []
    pdf_document = fitz.open(pdf_path)
    try:
        for page_idx in range(len(pdf_document)):
            page = pdf_document.load_page(page_idx)
            pix = page.get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
            img_data = pix.tobytes("ppm")
            image = Image.open(BytesIO(img_data)).convert("RGB")
            images.append(image)
    finally:
        pdf_document.close()
    return images

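# file_checksum streams the uploaded file through SHA-256 in 1 MiB chunks; the
# digest is stored in doc_state and later forms part of the per-page cache key.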
def file_checksum(path: str, chunk_size: int = 1 << 20) -> str:
    hasher = hashlib.sha256()
    with open(path, "rb") as f:
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                break
            hasher.update(chunk)
    return hasher.hexdigest()

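# draw_layout_on_image overlays the predicted layout on a copy of the page:
# one colored bounding box per element plus a small category label, with font
# fallbacks for macOS, Linux, and PIL's built-in default.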
def draw_layout_on_image(image: Image.Image, layout_data: List[Dict]) -> Image.Image:
    img = image.copy()
    draw = ImageDraw.Draw(img)
    colors = {
        "Caption": "#FF6B6B",
        "Footnote": "#4ECDC4",
        "Formula": "#45B7D1",
        "List-item": "#96CEB4",
        "Page-footer": "#FFEAA7",
        "Page-header": "#DDA0DD",
        "Picture": "#FFD93D",
        "Section-header": "#6C5CE7",
        "Table": "#FD79A8",
        "Text": "#74B9FF",
        "Title": "#E17055",
    }

    try:
        try:
            font = ImageFont.truetype(
                "/System/Library/Fonts/Supplemental/Arial Bold.ttf", 12
            )
        except Exception:
            try:
                font = ImageFont.truetype(
                    "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 12
                )
            except Exception:
                font = ImageFont.load_default()

        for item in layout_data:
            bbox = item.get("bbox")
            category = item.get("category")
            if not bbox or not category:
                continue
            color = colors.get(category, "#000000")
            draw.rectangle(bbox, outline=color, width=2)
            label = str(category)
            label_bbox = draw.textbbox((0, 0), label, font=font)
            label_w = label_bbox[2] - label_bbox[0]
            label_h = label_bbox[3] - label_bbox[1]
            x1, y1 = int(bbox[0]), int(bbox[1])
            lx = x1
            ly = max(0, y1 - label_h - 2)
            draw.rectangle([lx, ly, lx + label_w + 4, ly + label_h + 2], fill=color)
            draw.text((lx + 2, ly + 1), label, fill="white", font=font)
    except Exception:
        pass
    return img

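# is_arabic_text is a heuristic RTL check: it collects header and paragraph
# lines from the markdown and returns True when more than half of the
# alphabetic characters fall in the Arabic Unicode ranges.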
def is_arabic_text(text: str) -> bool:
    if not text:
        return False
    header_pattern = r"^#{1,6}\s+(.+)$"
    paragraph_pattern = r"^(?!#{1,6}\s|!\[|```|\||\s*[-*+]\s|\s*\d+\.\s)(.+)$"
    content_lines: List[str] = []
    for line in text.split("\n"):
        s = line.strip()
        if not s:
            continue
        m = re.match(header_pattern, s)
        if m:
            content_lines.append(m.group(1))
            continue
        if re.match(paragraph_pattern, s):
            content_lines.append(s)
    if not content_lines:
        return False
    combined = " ".join(content_lines)
    arabic = 0
    total = 0
    for ch in combined:
        if ch.isalpha():
            total += 1
            if (
                ("\u0600" <= ch <= "\u06ff")
                or ("\u0750" <= ch <= "\u077f")
                or ("\u08a0" <= ch <= "\u08ff")
            ):
                arabic += 1
    if total == 0:
        return False
    return (arabic / total) > 0.5

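# extract_json tries three strategies in order: parse the whole response,
# parse the substring between the first '{' and the last '}', then parse the
# contents of any ```json fenced block. Returns None if all of them fail.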
def extract_json(text: str) -> Optional[Dict[str, Any]]:
    if not text:
        return None
    try:
        return json.loads(text)
    except Exception:
        pass
    brace_start = text.find("{")
    brace_end = text.rfind("}")
    if 0 <= brace_start < brace_end:
        snippet = text[brace_start : brace_end + 1]
        try:
            return json.loads(snippet)
        except Exception:
            pass
    fenced = re.findall(r"```json\s*([\s\S]*?)\s*```", text)
    for block in fenced:
        try:
            return json.loads(block)
        except Exception:
            continue
    return None

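# The model and processor are process-wide singletons, loaded lazily on first
# use. ensure_model_loaded downloads the dots.ocr snapshot into DOTS_LOCAL_DIR,
# pre-registers the 'transformers_modules' packages so the dotted repo name
# does not break dynamic module imports, picks float16 on Apple MPS, bfloat16
# on CUDA, float32 on CPU, and moves the model to MPS when available.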
model: Optional[AutoModelForCausalLM] = None
processor: Optional[AutoProcessor] = None


def ensure_model_loaded() -> Tuple[AutoModelForCausalLM, AutoProcessor]:
    global model, processor
    if model is not None and processor is not None:
        return model, processor

    os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")
    snapshot_download(
        repo_id=DOTS_REPO_ID,
        local_dir=DOTS_LOCAL_DIR,
        local_dir_use_symlinks=False,
    )

    # Work around transformers dynamic module parent package issue with repo name containing a dot
    # Ensure 'transformers_modules' and 'transformers_modules.dots' exist as packages
    if "transformers_modules" not in sys.modules:
        pkg = types.ModuleType("transformers_modules")
        pkg.__path__ = []  # type: ignore[attr-defined]
        sys.modules["transformers_modules"] = pkg
    if "transformers_modules.dots" not in sys.modules:
        subpkg = types.ModuleType("transformers_modules.dots")
        subpkg.__path__ = []  # type: ignore[attr-defined]
        sys.modules["transformers_modules.dots"] = subpkg

    use_mps = torch.backends.mps.is_available()
    dtype = (
        torch.float16
        if use_mps
        else (torch.bfloat16 if torch.cuda.is_available() else torch.float32)
    )

    model = AutoModelForCausalLM.from_pretrained(
        DOTS_LOCAL_DIR,
        torch_dtype=dtype,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    if use_mps:
        model.to("mps")

    proc = AutoProcessor.from_pretrained(DOTS_LOCAL_DIR, trust_remote_code=True)
    processor = proc
    return model, processor

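# run_inference builds a single-turn chat message (image + prompt), applies the
# processor's chat template, prepares vision inputs with process_vision_info,
# generates greedily up to max_new_tokens, and decodes only the newly generated
# tokens (the prompt portion of each sequence is sliced off before decoding).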
def run_inference(
    image: Image.Image,
    prompt_text: str,
    max_new_tokens: int = LOCAL_DEFAULT_MAX_NEW_TOKENS,
) -> str:
    mdl, proc = ensure_model_loaded()
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt_text},
            ],
        }
    ]
    text = proc.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = proc(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    device = (
        "mps"
        if torch.backends.mps.is_available()
        else ("cuda" if torch.cuda.is_available() else "cpu")
    )
    inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}
    with torch.no_grad():
        generated_ids = mdl.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            do_sample=False,
            temperature=0.1,
        )
    trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
    ]
    output_text = processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return output_text[0] if output_text else ""

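# process_single_image runs one page end to end: inference, JSON extraction,
# box drawing, and markdown conversion. If no JSON can be recovered, the raw
# model output is used as the markdown fallback.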
def process_single_image(
    image: Image.Image,
    prompt_text: str,
    max_new_tokens: int,
) -> Dict[str, Any]:
    img = fetch_image(image)
    raw = run_inference(img, prompt_text, max_new_tokens=max_new_tokens)
    result: Dict[str, Any] = {
        "original_image": img,
        "processed_image": img,
        "raw_output": raw,
        "layout_result": None,
        "markdown": None,
    }
    data = extract_json(raw)
    if isinstance(data, dict):
        result["layout_result"] = data
        items = data.get("elements", data.get("elements_list", data.get("content", [])))
        if isinstance(items, list):
            result["processed_image"] = draw_layout_on_image(img, items)
            result["markdown"] = layoutjson2md(img, items)
    if result["markdown"] is None:
        result["markdown"] = raw
    return result

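# layoutjson2md converts layout elements to markdown in reading order (sorted
# by bbox top, then left): titles and section headers become headings, list
# items become bullets, HTML tables pass through unchanged, formulas are
# wrapped in $$ when they look like LaTeX, and page headers/footers and
# pictures are skipped.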
def layoutjson2md(
    image: Image.Image, layout_data: List[Dict], text_key: str = "text"
) -> str:
    lines: List[str] = []
    try:
        items = sorted(
            layout_data,
            key=lambda x: (
                x.get("bbox", [0, 0, 0, 0])[1],
                x.get("bbox", [0, 0, 0, 0])[0],
            ),
        )
        for item in items:
            category = item.get("category", "")
            text = item.get(text_key, "")
            if category == "Title" and text:
                lines.append(f"# {text}\n")
            elif category == "Section-header" and text:
                lines.append(f"## {text}\n")
            elif category == "List-item" and text:
                lines.append(f"- {text}\n")
            elif category == "Table" and text:
                if text.strip().startswith("<"):
                    lines.append(text + "\n")
                else:
                    lines.append(f"**Table:** {text}\n")
            elif category == "Formula" and text:
                if text.strip().startswith("$") or "\\" in text:
                    lines.append(f"$$\n{text}\n$$\n")
                else:
                    lines.append(f"**Formula:** {text}\n")
            elif category == "Caption" and text:
                lines.append(f"*{text}*\n")
            elif category in ["Page-header", "Page-footer"]:
                continue
            elif category == "Picture":
                continue
            elif text:
                lines.append(f"{text}\n")
            lines.append("")
    except Exception:
        return json.dumps(layout_data, ensure_ascii=False)
    return "\n".join(lines)

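# create_blocks_app builds the Gradio UI: a three-column layout (upload and
# prompt controls, page preview with navigation, tabbed results) plus two
# gr.State objects: doc_state holds the current document, cache_state caches
# previously parsed pages.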
def create_blocks_app():
    css = """
    .main-container { max-width: 1500px; margin: 0 auto; }
    .header-text { text-align: center; color: #1f2937; margin-bottom: 12px; }
    .page-info { text-align: center; padding: 8px 16px; border-radius: 20px; font-weight: 600; }
    .process-button { border: none !important; color: white !important; font-weight: 700 !important; }
    """

    with gr.Blocks(theme=gr.themes.Soft(), css=css, title=APP_TITLE) as demo:
        doc_state = gr.State(
            {
                "images": [],
                "current_page": 0,
                "total_pages": 0,
                "file_type": None,
                "checksum": None,
                "results": [],
                "parsed": False,
            }
        )

        cache_state = gr.State({})

        gr.HTML(
            """
            <div class=\"header-text\">
                <h2>VLM Playground — dots.ocr (Local)</h2>
                <p>Optimized defaults for Apple Silicon / CPU dev.</p>
            </div>
            """
        )

        with gr.Row(elem_classes=["main-container"]):
            with gr.Column(scale=4):
                file_input = gr.File(
                    label="Upload PDF or Image",
                    file_types=[
                        ".pdf",
                        ".png",
                        ".jpg",
                        ".jpeg",
                        ".bmp",
                        ".tiff",
                        ".webp",
                    ],
                    type="filepath",
                )

                with gr.Group():
                    template = gr.Dropdown(
                        label="Prompt Template",
                        choices=["Layout Extraction"],
                        value="Layout Extraction",
                    )
                    prompt_text = gr.Textbox(
                        label="Current Prompt",
                        value=dict_promptmode_to_prompt.get("prompt_layout_all_en", ""),
                        lines=6,
                    )

                with gr.Row():
                    parse_button = gr.Button(
                        "Parse", variant="primary", elem_classes=["process-button"]
                    )
                    clear_button = gr.Button("Clear")

                with gr.Accordion("Advanced", open=False):
                    max_new_tokens = gr.Slider(
                        minimum=256,
                        maximum=8192,
                        value=LOCAL_DEFAULT_MAX_NEW_TOKENS,
                        step=128,
                        label="Max new tokens",
                    )
                    page_range = gr.Textbox(
                        label="Page selection",
                        placeholder="e.g., 1-3,5 (blank = current page, 'all' = all pages)",
                    )

            with gr.Column(scale=5):
                preview_image = gr.Image(label="Page Preview", type="pil", height=520)
                with gr.Row():
                    prev_btn = gr.Button("◀ Prev")
                    page_info = gr.HTML('<div class="page-info">No file</div>')
                    next_btn = gr.Button("Next ▶")
                with gr.Row():
                    page_jump = gr.Number(value=1, label="Page #", precision=0)
                    jump_btn = gr.Button("Go")

            with gr.Column(scale=6):
                with gr.Tabs():
                    with gr.Tab("Markdown Render"):
                        md_render = gr.Markdown(
                            value="Upload and parse to view results", height=520
                        )
                    with gr.Tab("Raw Markdown"):
                        md_raw = gr.Textbox(value="", lines=20)
                    with gr.Tab("Current Page JSON"):
                        json_view = gr.JSON(value=None)
                    with gr.Tab("Processed Image"):
                        processed_view = gr.Image(type="pil", height=520)

                with gr.Row():
                    download_jsonl = gr.DownloadButton(label="Download JSONL")
                    download_markdown = gr.DownloadButton(label="Download Markdown")

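        # ----- Event handlers (defined as closures over the components above) -----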
        def on_template_change(choice: str) -> str:
            return dict_promptmode_to_prompt.get("prompt_layout_all_en", "")

        def on_file_change(path: Optional[str]):
            if not path or not os.path.exists(path):
                return (
                    {
                        "images": [],
                        "current_page": 0,
                        "total_pages": 0,
                        "file_type": None,
                        "checksum": None,
                        "results": [],
                        "parsed": False,
                    },
                    None,
                    '<div class="page-info">No file</div>',
                )
            checksum = file_checksum(path)
            ext = os.path.splitext(path)[1].lower()
            if ext == ".pdf":
                images = load_images_from_pdf(path)
                state = {
                    "images": images,
                    "current_page": 0,
                    "total_pages": len(images),
                    "file_type": "pdf",
                    "checksum": checksum,
                    "results": [None] * len(images),
                    "parsed": False,
                }
                return (
                    state,
                    images[0] if images else None,
                    f'<div class="page-info">Page 1 / {len(images)}</div>',
                )
            else:
                image = Image.open(path).convert("RGB")
                state = {
                    "images": [image],
                    "current_page": 0,
                    "total_pages": 1,
                    "file_type": "image",
                    "checksum": checksum,
                    "results": [None],
                    "parsed": False,
                }
                return state, image, '<div class="page-info">Page 1 / 1</div>'

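        # nav_page and jump_to_page return the full 7-tuple of outputs
        # (doc_state, preview image, page-info HTML, rendered markdown, raw
        # markdown, processed image, JSON) so each navigation event can refresh
        # every results panel at once.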
        def nav_page(state: Dict[str, Any], direction: str):
            if not state.get("images"):
                return (
                    state,
                    None,
                    '<div class="page-info">No file</div>',
                    "No results",
                    "",
                    None,
                    None,
                )
            if direction == "prev":
                state["current_page"] = max(0, state["current_page"] - 1)
            elif direction == "next":
                state["current_page"] = min(
                    state["total_pages"] - 1, state["current_page"] + 1
                )
            idx = state["current_page"]
            img = state["images"][idx]
            info = (
                f'<div class="page-info">Page {idx + 1} / {state["total_pages"]}</div>'
            )
            result = (
                state["results"][idx]
                if state.get("parsed") and idx < len(state["results"])
                else None
            )
            md = result.get("markdown") if result else "Page not processed yet"
            md_out = gr.update(value=md, rtl=True) if is_arabic_text(md) else md
            md_raw_text = md
            proc_img = result.get("processed_image") if result else None
            js = result.get("layout_result") if result else None
            return state, img, info, md_out, md_raw_text, proc_img, js

        def jump_to_page(state: Dict[str, Any], page_num: Any):
            if not state.get("images"):
                return (
                    state,
                    None,
                    '<div class="page-info">No file</div>',
                    "No results",
                    "",
                    None,
                    None,
                )
            try:
                n = int(page_num)
            except Exception:
                n = 1
            n = max(1, min(state["total_pages"], n))
            state["current_page"] = n - 1
            return nav_page(state, direction="stay")

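        # parse_pages interprets the page-selection string (blank = current page,
        # "all" = every page, otherwise comma-separated numbers and ranges like
        # "1-3,5"), then processes each selected page, reusing cached results
        # keyed by (file checksum, page index, prompt hash, max tokens).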
        def parse_pages(
            state: Dict[str, Any],
            prompt: str,
            max_tokens: int,
            selection: Optional[str],
        ):
            if not state.get("images"):
                return state, None, "No file", "No content", "", None, None

            indices: List[int] = []
            if not selection or selection.strip() == "":
                indices = [state["current_page"]]
            elif selection.strip().lower() == "all":
                indices = list(range(state["total_pages"]))
            else:
                parts = [p.strip() for p in selection.split(",") if p.strip()]
                for p in parts:
                    if "-" in p:
                        a, b = p.split("-", 1)
                        try:
                            a_i = max(1, int(a))
                            b_i = min(state["total_pages"], int(b))
                            for i in range(a_i - 1, b_i):
                                indices.append(i)
                        except Exception:
                            continue
                    else:
                        try:
                            i = max(1, min(state["total_pages"], int(p)))
                            indices.append(i - 1)
                        except Exception:
                            continue
            indices = sorted(
                set([i for i in indices if 0 <= i < state["total_pages"]])
            )

            results = state.get("results") or [None] * state["total_pages"]
            for i in indices:
                img = state["images"][i]
                prompt_hash = hashlib.sha256(prompt.encode("utf-8")).hexdigest()[:16]
                cache_key = (
                    state["checksum"],
                    i,
                    prompt_hash,
                    int(max_tokens),
                )
                cached = cache_state.value.get(cache_key)
                if cached:
                    results[i] = cached
                    continue
                res = process_single_image(
                    img,
                    prompt_text=prompt,
                    max_new_tokens=int(max_tokens),
                )
                results[i] = res
                cache_state.value[cache_key] = res
            state["results"] = results
            state["parsed"] = True

            idx = state["current_page"]
            curr = results[idx]
            md = curr.get("markdown") if curr else "No content"
            md_out = gr.update(value=md, rtl=True) if is_arabic_text(md) else md
            md_raw_text = md
            proc_img = curr.get("processed_image") if curr else None
            js = curr.get("layout_result") if curr else None
            info = (
                f'<div class="page-info">Page {idx + 1} / {state["total_pages"]}</div>'
            )
            prev = state["images"][idx]
            return state, prev, info, md_out, md_raw_text, proc_img, js

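        # clear_all resets the document state; the two download handlers write
        # the aggregated JSONL / markdown to files under TMP_DIR and return the
        # file path as the DownloadButton value (the fix this commit makes).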
        def clear_all():
            gc.collect()
            return (
                {
                    "images": [],
                    "current_page": 0,
                    "total_pages": 0,
                    "file_type": None,
                    "checksum": None,
                    "results": [],
                    "parsed": False,
                },
                None,
                '<div class="page-info">No file</div>',
                "Upload and parse to view results",
                "",
                None,
                None,
            )

        def download_current_jsonl(state: Dict[str, Any]):
            if not state.get("parsed"):
                return gr.DownloadButton.update(value=b"")
            lines: List[str] = []
            for i, res in enumerate(state.get("results", [])):
                if res and res.get("layout_result") is not None:
                    obj = {"page": i + 1, "layout": res["layout_result"]}
                    lines.append(json.dumps(obj, ensure_ascii=False))
            content = "\n".join(lines) if lines else ""
            out_path = os.path.join(TMP_DIR, "results.jsonl")
            with open(out_path, "w", encoding="utf-8") as f:
                f.write(content)
            return gr.DownloadButton.update(value=out_path)

        def download_current_markdown(state: Dict[str, Any]):
            if not state.get("parsed"):
                return gr.DownloadButton.update(value=b"")
            chunks: List[str] = []
            for i, res in enumerate(state.get("results", [])):
                if res and res.get("markdown"):
                    chunks.append(f"## Page {i + 1}\n\n{res['markdown']}")
            content = "\n\n---\n\n".join(chunks) if chunks else ""
            out_path = os.path.join(TMP_DIR, "results.md")
            with open(out_path, "w", encoding="utf-8") as f:
                f.write(content)
            return gr.DownloadButton.update(value=out_path)

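        # ----- Event wiring -----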
        template.change(on_template_change, inputs=[template], outputs=[prompt_text])
        file_input.change(
            on_file_change,
            inputs=[file_input],
            outputs=[doc_state, preview_image, page_info],
        )
        prev_btn.click(
            lambda s: nav_page(s, "prev"),
            inputs=[doc_state],
            outputs=[
                doc_state,
                preview_image,
                page_info,
                md_render,
                md_raw,
                processed_view,
                json_view,
            ],
        )
        next_btn.click(
            lambda s: nav_page(s, "next"),
            inputs=[doc_state],
            outputs=[
                doc_state,
                preview_image,
                page_info,
                md_render,
                md_raw,
                processed_view,
                json_view,
            ],
        )
        jump_btn.click(
            jump_to_page,
            inputs=[doc_state, page_jump],
            outputs=[
                doc_state,
                preview_image,
                page_info,
                md_render,
                md_raw,
                processed_view,
                json_view,
            ],
        )
        parse_button.click(
            parse_pages,
            inputs=[doc_state, prompt_text, max_new_tokens, page_range],
            outputs=[
                doc_state,
                preview_image,
                page_info,
                md_render,
                md_raw,
                processed_view,
                json_view,
            ],
        )
        clear_button.click(
            clear_all,
            outputs=[
                doc_state,
                preview_image,
                page_info,
                md_render,
                md_raw,
                processed_view,
                json_view,
            ],
        )
        download_jsonl.click(
            download_current_jsonl, inputs=[doc_state], outputs=[download_jsonl]
        )
        download_markdown.click(
            download_current_markdown, inputs=[doc_state], outputs=[download_markdown]
        )

    return demo