xfey committed on
Commit dfb1341 · 1 Parent(s): a75eb7f

[init] update application file

.gitignore ADDED
@@ -0,0 +1,72 @@
+ # Python-related
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # Environment files
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Editor files
+ .vscode/
+ .idea/
+ *.suo
+ *.ntvs*
+ *.njsproj
+ *.sln
+ *.sw?
+
+ # Logs and databases
+ *.log
+ *.sqlite
+ *.db
+
+ # System files
+ .DS_Store
+ Thumbs.db
+
+ # Testing
+ htmlcov/
+ .tox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Output files
+ *.csv
+ *.json
+ *.xlsx
+ # *.pdf
+ out/
+ output/
+
+ # Jupyter notebooks
+ .ipynb_checkpoints
app.py ADDED
@@ -0,0 +1,481 @@
+ import os
+ import tempfile
+ import time
+ import uuid
+
+ import cv2
+ import gradio as gr
+ import pymupdf
+ import spaces
+ import torch
+ from gradio_pdf import PDF
+ from loguru import logger
+ from PIL import Image
+ from transformers import AutoProcessor, VisionEncoderDecoderModel
+
+ from utils.utils import prepare_image, parse_layout_string, process_coordinates, ImageDimensions
+
+ # Load the external CSS file
+ def load_css():
+     css_path = os.path.join(os.path.dirname(__file__), "static", "styles.css")
+     if os.path.exists(css_path):
+         with open(css_path, "r", encoding="utf-8") as f:
+             return f.read()
+     return ""
+
+ # Global variables holding the model
+ model = None
+ processor = None
+ tokenizer = None
+
+ # Initialize the model automatically
+ @spaces.GPU
+ def initialize_model():
+     """Initialize the Hugging Face model"""
+     global model, processor, tokenizer
+
+     if model is None:
+         logger.info("Loading DOLPHIN model...")
+         model_id = "ByteDance/Dolphin"
+
+         # Load the processor and model
+         processor = AutoProcessor.from_pretrained(model_id)
+         model = VisionEncoderDecoderModel.from_pretrained(model_id)
+         model.eval()
+
+         # Set device and precision
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         model.to(device)
+         model = model.half()  # Use half precision
+
+         # Set the tokenizer
+         tokenizer = processor.tokenizer
+
+         logger.info(f"Model loaded successfully on {device}")
+
+     return "Model ready"
+
+ # Initialize the model automatically at startup
+ logger.info("Initializing model at startup...")
+ try:
+     initialize_model()
+     logger.info("Model initialization completed")
+ except Exception as e:
+     logger.error(f"Model initialization failed: {e}")
+     # The model will be re-initialized on first use
+
+ # Model inference function
+ @spaces.GPU
+ def model_chat(prompt, image):
+     """Run inference with the model"""
+     global model, processor, tokenizer
+
+     # Make sure the model is initialized
+     if model is None:
+         initialize_model()
+
+     # Check whether this is a batched call
+     is_batch = isinstance(image, list)
+
+     if not is_batch:
+         images = [image]
+         prompts = [prompt]
+     else:
+         images = image
+         prompts = prompt if isinstance(prompt, list) else [prompt] * len(images)
+
+     # Prepare the images
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     batch_inputs = processor(images, return_tensors="pt", padding=True)
+     batch_pixel_values = batch_inputs.pixel_values.half().to(device)
+
+     # Prepare the prompts
+     prompts = [f"<s>{p} <Answer/>" for p in prompts]
+     batch_prompt_inputs = tokenizer(
+         prompts,
+         add_special_tokens=False,
+         return_tensors="pt"
+     )
+
+     batch_prompt_ids = batch_prompt_inputs.input_ids.to(device)
+     batch_attention_mask = batch_prompt_inputs.attention_mask.to(device)
+
+     # Generate text
+     outputs = model.generate(
+         pixel_values=batch_pixel_values,
+         decoder_input_ids=batch_prompt_ids,
+         decoder_attention_mask=batch_attention_mask,
+         min_length=1,
+         max_length=4096,
+         pad_token_id=tokenizer.pad_token_id,
+         eos_token_id=tokenizer.eos_token_id,
+         use_cache=True,
+         bad_words_ids=[[tokenizer.unk_token_id]],
+         return_dict_in_generate=True,
+         do_sample=False,
+         num_beams=1,
+         repetition_penalty=1.1
+     )
+
+     # Process the output
+     sequences = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)
+
+     # Strip the prompt text from the output
+     results = []
+     for i, sequence in enumerate(sequences):
+         cleaned = sequence.replace(prompts[i], "").replace("<pad>", "").replace("</s>", "").strip()
+         results.append(cleaned)
+
+     # Return a single result or the batch of results
+     if not is_batch:
+         return results[0]
+     return results
+
+ # Process a batch of elements
+ @spaces.GPU
+ def process_element_batch(elements, prompt, max_batch_size=16):
+     """Process a batch of elements of the same type"""
+     results = []
+
+     # Determine the batch size
+     batch_size = min(len(elements), max_batch_size)
+
+     # Process in batches
+     for i in range(0, len(elements), batch_size):
+         batch_elements = elements[i:i+batch_size]
+         crops_list = [elem["crop"] for elem in batch_elements]
+
+         # Use the same prompt for every element
+         prompts_list = [prompt] * len(crops_list)
+
+         # Batch inference
+         batch_results = model_chat(prompts_list, crops_list)
+
+         # Collect the results
+         for j, result in enumerate(batch_results):
+             elem = batch_elements[j]
+             results.append({
+                 "label": elem["label"],
+                 "bbox": elem["bbox"],
+                 "text": result.strip(),
+                 "reading_order": elem["reading_order"],
+             })
+
+     return results
+
+ # Clean up temporary files
+ def cleanup_temp_file(file_path):
+     """Safely delete a temporary file"""
+     try:
+         if file_path and os.path.exists(file_path):
+             os.unlink(file_path)
+     except Exception as e:
+         logger.warning(f"Failed to cleanup temp file {file_path}: {e}")
+
+ def to_pdf(file_path):
+     """Convert the input file to PDF format"""
+     if file_path is None:
+         return None
+
+     with pymupdf.open(file_path) as f:
+         if f.is_pdf:
+             return file_path
+         else:
+             pdf_bytes = f.convert_to_pdf()
+             # Use a temporary file instead of saving to disk
+             with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp_file:
+                 tmp_file.write(pdf_bytes)
+                 return tmp_file.name
+
+ @spaces.GPU(duration=120)
+ def process_document(file_path):
+     """Main document-processing function - integrates the full inference logic"""
+     if file_path is None:
+         return "", "", {}, {}
+
+     start_time = time.time()
+     original_file_path = file_path
+
+     # Make sure the model is initialized
+     if model is None:
+         initialize_model()
+
+     # Convert to PDF (if needed)
+     converted_file_path = to_pdf(file_path)
+     temp_file_created = converted_file_path != original_file_path
+
+     try:
+         logger.info(f"Processing document: {file_path}")
+
+         # Process the page
+         recognition_results = process_page(converted_file_path)
+
+         # Generate the Markdown content
+         md_content = generate_markdown(recognition_results)
+
+         # Compute the processing time
+         processing_time = time.time() - start_time
+
+         debug_info = {
+             "original_file": original_file_path,
+             "converted_file": converted_file_path,
+             "temp_file_created": temp_file_created,
+             "status": "success",
+             "processing_time": f"{processing_time:.2f}s",
+             "total_elements": len(recognition_results)
+         }
+
+         processing_data = {
+             "pages": [{"elements": recognition_results}],
+             "total_elements": len(recognition_results),
+             "processing_time": f"{processing_time:.2f}s"
+         }
+
+         logger.info(f"Document processed successfully in {processing_time:.2f}s")
+         return md_content, md_content, processing_data, debug_info
+
+     except Exception as e:
+         logger.error(f"Error processing document: {str(e)}")
+         error_info = {
+             "original_file": original_file_path,
+             "converted_file": converted_file_path,
+             "temp_file_created": temp_file_created,
+             "status": "error",
+             "error": str(e)
+         }
+         return f"# Processing Error\n\nAn error occurred while processing the document: {str(e)}", "", {}, error_info
+
+     finally:
+         # Clean up the temporary file
+         if temp_file_created:
+             cleanup_temp_file(converted_file_path)
+
+ def process_page(image_path):
+     """Process a single document page"""
+     # Stage 1: page-level layout parsing
+     # process_document always passes a PDF path here, and PIL cannot open PDFs,
+     # so render the first page with pymupdf before handing it to the model
+     with pymupdf.open(image_path) as doc:
+         pix = doc[0].get_pixmap(dpi=200)
+         pil_image = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
+     layout_output = model_chat("Parse the reading order of this document.", pil_image)
+
+     # Stage 2: element-level content parsing
+     padded_image, dims = prepare_image(pil_image)
+     recognition_results = process_elements(layout_output, padded_image, dims)
+
+     return recognition_results
+
+ def process_elements(layout_results, padded_image, dims, max_batch_size=16):
+     """Parse all document elements"""
+     layout_results = parse_layout_string(layout_results)
+
+     # Store the different element types separately
+     text_elements = []    # Text elements
+     table_elements = []   # Table elements
+     figure_results = []   # Figure elements (no processing needed)
+     previous_box = None
+     reading_order = 0
+
+     # Collect elements to process and group them by type
+     for bbox, label in layout_results:
+         try:
+             # Adjust the coordinates
+             x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = process_coordinates(
+                 bbox, padded_image, dims, previous_box
+             )
+
+             # Crop and parse the element
+             cropped = padded_image[y1:y2, x1:x2]
+             if cropped.size > 0:
+                 if label == "fig":
+                     # For figure regions, add an empty text result directly
+                     figure_results.append(
+                         {
+                             "label": label,
+                             "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
+                             "text": "",
+                             "reading_order": reading_order,
+                         }
+                     )
+                 else:
+                     # Prepare the element for parsing
+                     pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
+                     element_info = {
+                         "crop": pil_crop,
+                         "label": label,
+                         "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
+                         "reading_order": reading_order,
+                     }
+
+                     # Group by type
+                     if label == "tab":
+                         table_elements.append(element_info)
+                     else:  # Text elements
+                         text_elements.append(element_info)
+
+             reading_order += 1
+
+         except Exception as e:
+             logger.error(f"Error processing bbox with label {label}: {str(e)}")
+             continue
+
+     # Initialize the results list
+     recognition_results = figure_results.copy()
+
+     # Process text elements (in batches)
+     if text_elements:
+         text_results = process_element_batch(text_elements, "Read text in the image.", max_batch_size)
+         recognition_results.extend(text_results)
+
+     # Process table elements (in batches)
+     if table_elements:
+         table_results = process_element_batch(table_elements, "Parse the table in the image.", max_batch_size)
+         recognition_results.extend(table_results)
+
+     # Sort by reading order
+     recognition_results.sort(key=lambda x: x.get("reading_order", 0))
+
+     return recognition_results
+
+ def generate_markdown(recognition_results):
+     """Generate Markdown content from the recognition results"""
+     markdown_parts = []
+
+     for result in recognition_results:
+         text = result.get("text", "").strip()
+         label = result.get("label", "")
+
+         if text:
+             if label == "tab":
+                 # Table content
+                 markdown_parts.append(f"\n{text}\n")
+             else:
+                 # Regular text content
+                 markdown_parts.append(text)
+
+     return "\n\n".join(markdown_parts)
+
+ # LaTeX rendering configuration
+ latex_delimiters = [
+     {"left": "$$", "right": "$$", "display": True},
+     {"left": "$", "right": "$", "display": False},
+     {"left": "\\[", "right": "\\]", "display": True},
+     {"left": "\\(", "right": "\\)", "display": False},
+ ]
+
+ # Load the custom CSS
+ custom_css = load_css()
+
+ # Read the page header
+ with open("header.html", "r", encoding="utf-8") as file:
+     header = file.read()
+
+ # Build the Gradio interface
+ with gr.Blocks(css=custom_css, title="Dolphin Document Parser") as demo:
+     gr.HTML(header)
+
+     with gr.Row():
+         # Sidebar - file upload and controls
+         with gr.Column(scale=1, elem_classes="sidebar"):
+             # File upload component
+             file = gr.File(
+                 label="Choose PDF or image file",
+                 file_types=[".pdf", ".png", ".jpeg", ".jpg"],
+                 elem_id="file-upload"
+             )
+
+             gr.HTML("选择文件后,点击处理按钮开始解析<br>After selecting the file, click the Process button to start parsing")
+
+             with gr.Row(elem_classes="action-buttons"):
+                 submit_btn = gr.Button("处理文档/Process Document", variant="primary")
+                 clear_btn = gr.ClearButton(value="清空/Clear")
+
+             # Processing status display
+             status_display = gr.Textbox(
+                 label="Processing Status",
+                 value="Ready to process documents",
+                 interactive=False,
+                 max_lines=2
+             )
+
+             # Example files
+             example_root = os.path.join(os.path.dirname(__file__), "examples")
+             if os.path.exists(example_root):
+                 gr.HTML("示例文件/Example Files")
+                 example_files = [
+                     os.path.join(example_root, f)
+                     for f in os.listdir(example_root)
+                     if not f.endswith(".py")
+                 ]
+
+                 examples = gr.Examples(
+                     examples=example_files,
+                     inputs=file,
+                     examples_per_page=10,
+                     elem_id="example-files"
+                 )
+
+         # Main content area
+         with gr.Column(scale=7):
+             with gr.Row(elem_classes="main-content"):
+                 # Preview panel
+                 with gr.Column(scale=1, elem_classes="preview-panel"):
+                     gr.HTML("文件预览/Preview")
+                     pdf_show = PDF(label="", interactive=False, visible=True, height=600)
+                     debug_output = gr.JSON(label="Debug Info", height=100)
+
+                 # Output panel
+                 with gr.Column(scale=1, elem_classes="output-panel"):
+                     with gr.Tabs():
+                         with gr.Tab("Markdown [Render]"):
+                             md_render = gr.Markdown(
+                                 label="",
+                                 height=700,
+                                 show_copy_button=True,
+                                 latex_delimiters=latex_delimiters,
+                                 line_breaks=True,
+                             )
+                         with gr.Tab("Markdown [Content]"):
+                             md_content = gr.TextArea(lines=30, show_copy_button=True)
+                         with gr.Tab("Processing Data"):
+                             json_output = gr.JSON(label="", height=700)
+
+     # Event handling
+     file.change(fn=to_pdf, inputs=file, outputs=pdf_show)
+
+     # Document processing
+     def process_with_status(file_path):
+         """Process the document and update the status"""
+         if file_path is None:
+             return "", "", {}, {}, "Please select a file first"
+
+         # Set the status to processing
+         status = "Processing document..."
+
+         # Run the document processing
+         md_render_result, md_content_result, json_result, debug_result = process_document(file_path)
+
+         # Update the completion status
+         if "Processing Error" in md_render_result:
+             status = "Processing failed - see debug info"
+         else:
+             status = "Processing completed successfully"
+
+         return md_render_result, md_content_result, json_result, debug_result, status
+
+     submit_btn.click(
+         fn=process_with_status,
+         inputs=[file],
+         outputs=[md_render, md_content, json_output, debug_output, status_display],
+     )
+
+     # Clear everything
+     def reset_all():
+         return None, None, "", "", {}, {}, "Ready to process documents"
+
+     clear_btn.click(
+         fn=reset_all,
+         inputs=[],
+         outputs=[file, pdf_show, md_render, md_content, json_output, debug_output, status_display]
+     )
+
+ # Launch the app
+ if __name__ == "__main__":
+     demo.launch()
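
For reference, the pipeline in app.py can also be driven without the Gradio UI. A minimal headless sketch, assuming the pinned requirements are installed, a GPU is available, and that importing `app` is acceptable (the import triggers the startup model initialization; the example path comes from the `examples/` files added in this commit):

```python
# Sketch only - not part of the commit. Importing app loads ByteDance/Dolphin
# at startup, so the first import is slow.
from app import process_document

md_render, md_text, data, debug = process_document("examples/page_3.jpeg")
print(debug["status"], debug["processing_time"])  # e.g. "success" "12.34s"
print(md_text[:300])                              # start of the generated Markdown
```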
examples/page_1.pdf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8984e6b0bffa46e13809b4969e2be559df89e2cf9d6b3d7fb1a78f25aed8e570
+ size 1523572
examples/page_2.pdf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c4f4785470676739e2998f04bfc8daaf2e7ae227bf374614f07821ec5a315143
+ size 1478409
examples/page_3.jpeg ADDED

Git LFS Details

  • SHA256: fe6e35a3c888c77ec36cf48cb762556e489e288d30a457a353ac6bba6fab9251
  • Pointer size: 131 Bytes
  • Size of remote file: 449 kB
examples/page_4.png ADDED

Git LFS Details

  • SHA256: 497cdabe38a4db8318284c0f8963304a876ceceebb796059903703834e4713ed
  • Pointer size: 131 Bytes
  • Size of remote file: 372 kB
examples/page_5.jpg ADDED

Git LFS Details

  • SHA256: 17cdc261fcd7eb8db4a0bdfb56dc2b1f77c8890956f8451f810695e115f6f894
  • Pointer size: 131 Bytes
  • Size of remote file: 641 kB
examples/page_6.jpg ADDED

Git LFS Details

  • SHA256: 0e4dfe55790db38d64ff0d4cf2707859e2d17d4c6e254e398fa21ab4239fd6ec
  • Pointer size: 131 Bytes
  • Size of remote file: 975 kB
header.html ADDED
@@ -0,0 +1,447 @@
+ <!DOCTYPE html>
+ <html lang="en" style="color-scheme: light;">
+ <head>
+     <meta charset="UTF-8">
+     <meta name="viewport" content="width=device-width, initial-scale=1.0">
+     <meta name="color-scheme" content="light">
+     <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css">
+     <style>
+     :root {
+         /* Primary colors */
+         --primary-color: #dceaf6;
+         --primary-light: #f8f9fa;
+         --primary-dark: #9ec9e3;
+
+         /* Accent colors */
+         --accent-color: #bfe2f8;
+         --accent-light: #dceaf6;
+
+         /* Background colors */
+         --bg-color: #e8eff5;
+         --card-bg: #ffffff;
+
+         /* Text colors */
+         --dark-text: #2b2d42;
+         --light-text: #f8f9fa;
+         --muted-text: rgba(43, 45, 66, 0.7);
+
+         /* Borders and shadows */
+         --border-color: rgba(168, 168, 168, 0.432);
+         --card-shadow: 0 4px 20px rgba(104, 104, 104, 0.1);
+
+         /* Interactive states */
+         --hover-bg: rgba(255, 255, 255, 0.5);
+         --active-color: #bfe2f8;
+     }
+
+     .header-container {
+         display: flex;
+         flex-direction: row;
+         justify-content: space-between;
+         align-items: flex-start;
+         background: linear-gradient(135deg,
+             #e4deff 0%,
+             #d8f7ff 100%
+         );
+         padding: 1.8rem;
+         border-radius: 12px;
+         margin-bottom: 1.5rem;
+         box-shadow: var(--card-shadow);
+         position: relative;
+         overflow: hidden;
+     }
+
+     .header-container::before {
+         content: '';
+         position: absolute;
+         top: 0;
+         left: 0;
+         right: 0;
+         bottom: 0;
+         background: linear-gradient(135deg,
+             rgba(255, 255, 255, 0.2) 0%,
+             rgba(255, 255, 255, 0) 100%
+         );
+         pointer-events: none;
+     }
+
+     .header-content {
+         display: flex;
+         flex-direction: column;
+         align-items: center;
+         text-align: center;
+         max-width: 100%;
+         width: 100%;
+     }
+
+     .header-buttons {
+         display: none;
+     }
+
+     .logo-title-container {
+         display: flex;
+         flex-direction: column;
+         align-items: center;
+         margin-bottom: 1.5rem;
+         max-width: 100%;
+         text-align: center;
+     }
+
+     .logo {
+         width: 350px;
+         height: auto;
+         margin-bottom: 1rem;
+         margin-right: 0;
+     }
+
+     .header-title {
+         font-size: 2.2rem;
+         font-weight: 700;
+         color: var(--dark-text);
+         margin: 0;
+         font-family: 'Poppins', 'Segoe UI', sans-serif;
+         line-height: 1.2;
+         text-align: center;
+         max-width: 100%;
+     }
+
+     .header-subtitle {
+         font-size: 1.1rem;
+         color: var(--muted-text);
+         margin: 0 0 1.5rem 0;
+         line-height: 1.6;
+         max-width: 100%;
+         text-align: center;
+         margin-left: auto;
+         margin-right: auto;
+     }
+
+     .link-button {
+         display: flex;
+         align-items: center;
+         padding: 0.7rem 1.2rem;
+         background-color: var(--hover-bg);
+         border-radius: 8px;
+         color: var(--dark-text) !important;
+         text-decoration: none !important;
+         font-weight: 700;
+         font-size: 1.1rem;
+         transition: all 0.3s ease;
+         backdrop-filter: blur(5px);
+         border: 1px solid var(--border-color);
+         width: 100%;
+         margin-bottom: 0.5rem;
+     }
+
+     .link-button:hover {
+         background-color: var(--hover-bg);
+         transform: translateY(-2px);
+         box-shadow: 0 5px 15px rgba(0, 0, 0, 0.1);
+         text-decoration: none !important;
+         color: var(--dark-text) !important;
+     }
+
+     .link-button i {
+         margin-right: 0.8rem;
+         font-size: 1.2rem;
+         color: var(--primary-dark);
+         min-width: 20px;
+         text-align: center;
+     }
+
+     .link-button * {
+         text-decoration: none !important;
+         color: inherit !important;
+     }
+
+     .feature-grid {
+         display: flex;
+         flex-direction: row;
+         align-items: flex-start;
+         justify-content: center;
+         margin-top: 1.5rem;
+         width: 100%;
+         margin-left: auto;
+         margin-right: auto;
+     }
+
+     .feature-card {
+         flex: 1;
+         padding: 1rem 1rem;
+         background-color: transparent;
+         border: none;
+         box-shadow: none;
+         transition: none;
+         text-align: center;
+         position: relative;
+     }
+
+     .feature-card:hover {
+         transform: none;
+         box-shadow: none;
+     }
+
+     .feature-separator {
+         width: 1px;
+         align-self: stretch;
+         background-color: var(--border-color);
+         margin: 0 1rem;
+     }
+
+     .feature-icon {
+         font-size: 2rem;
+         color: var(--primary-dark);
+         margin-bottom: 1rem;
+     }
+
+     .feature-title {
+         font-weight: 600;
+         color: var(--dark-text);
+         margin-bottom: 0.8rem;
+         font-size: 1.2rem;
+     }
+
+     .feature-desc {
+         font-size: 0.85rem;
+         color: var(--muted-text);
+         line-height: 1.5;
+     }
+
+     /* New navigation button styles */
+     .nav-buttons {
+         display: flex;
+         flex-direction: row;
+         align-items: center;
+         justify-content: center;
+         margin-top: 1rem;
+         margin-bottom: 2rem;
+         background-color: rgba(255, 255, 255, 0.7);
+         border-radius: 12px;
+         border: 1px solid var(--border-color);
+         padding: 0.5rem 1rem;
+         max-width: none;
+         width: auto;
+         align-self: center;
+         margin-left: auto;
+         margin-right: auto;
+     }
+
+     .nav-link {
+         display: flex;
+         align-items: center;
+         padding: 0.5rem 1rem;
+         color: var(--dark-text) !important;
+         text-decoration: none !important;
+         font-weight: 600;
+         font-size: 1rem;
+         transition: all 0.3s ease;
+     }
+
+     .nav-link:hover {
+         transform: translateY(-3px);
+         color: var(--primary-dark) !important;
+         background-color: rgba(255, 255, 255, 0.8);
+     }
+
+     .nav-link i {
+         margin-right: 0.5rem;
+         font-size: 1.1rem;
+         color: var(--primary-dark);
+     }
+
+     .nav-separator {
+         height: 20px;
+         width: 1px;
+         background-color: var(--border-color);
+         margin: 0 0.5rem;
+     }
+
+     @media (max-width: 960px) {
+         .header-container {
+             flex-direction: column;
+             padding: 1.5rem;
+         }
+
+         .header-content {
+             max-width: 100%;
+             margin-bottom: 2rem;
+         }
+
+         .header-buttons {
+             width: 100%;
+             margin-left: 0;
+         }
+
+         .logo-title-container {
+             flex-direction: column;
+             align-items: center;
+         }
+
+         .logo {
+             width: 250px;
+             margin-bottom: 1rem;
+             margin-right: 0;
+         }
+
+         .header-title {
+             font-size: 1.8rem;
+         }
+
+         .feature-grid {
+             flex-direction: column;
+         }
+
+         .feature-card {
+             width: 100%;
+             padding: 1rem 0;
+         }
+
+         .feature-separator {
+             width: 100%;
+             height: 1px;
+             margin: 0.5rem 0;
+         }
+
+         .nav-buttons {
+             flex-wrap: wrap;
+             justify-content: center;
+         }
+
+         .nav-link {
+             padding: 0.5rem;
+             font-size: 0.9rem;
+         }
+
+         .nav-separator {
+             display: none;
+         }
+
+         .feature-desc {
+             font-size: 0.9rem;
+         }
+     }
+
+     /* Styles that disable dark mode */
+     @media (prefers-color-scheme: dark) {
+         /* Force light-mode colors */
+         .header-container,
+         .header-content,
+         .logo-title-container,
+         .header-title,
+         .header-subtitle,
+         .nav-buttons,
+         .nav-link,
+         .feature-grid,
+         .feature-card,
+         .feature-title,
+         .feature-desc,
+         body,
+         * {
+             color-scheme: light !important; /* Force the light color scheme */
+             color: var(--dark-text) !important; /* Force dark text */
+             background-color: initial; /* Keep the original background color */
+         }
+     }
+
+     /* Global style override to make sure all text uses the colors we specify */
+     body, p, h1, h2, h3, h4, h5, h6, span, div, a {
+         color: var(--dark-text) !important;
+     }
+
+     .feature-desc {
+         color: var(--muted-text) !important;
+     }
+
+     /* Make sure icon colors are not affected by dark mode either */
+     .feature-icon i, .nav-link i {
+         color: var(--primary-dark) !important;
+     }
+
+     /* Navigation link hover effect */
+     .nav-link:hover {
+         color: var(--primary-dark) !important;
+     }
+     </style>
+ </head>
+ <body>
+     <div class="header-container">
+         <div class="header-content">
+             <div class="logo-title-container">
+                 <img src="https://raw.githubusercontent.com/bytedance/Dolphin/master/assets/dolphin.png" alt="Dolphin Logo" class="logo">
+                 <h1 class="header-title">Document Image Parsing via Heterogeneous Anchor Prompting</h1>
+             </div>
+
+             <p class="header-subtitle">
+                 A novel multimodal document image parsing model, following an analyze-then-parse paradigm for parallel decoding
+                 <!-- <br>
+                 Stage 1: Comprehensive page-level layout analysis by generating element sequence in natural reading order.
+                 <br>
+                 Stage 2: Efficient parallel parsing of document elements using heterogeneous anchors and task-specific prompts. -->
+             </p>
+
+             <!-- New navigation buttons -->
+             <div class="nav-buttons">
+                 <!-- <a href="https://mineru.org.cn/home?source=huggingface" class="nav-link">
+                     <i class="fas fa-home"></i> 主页/Homepage
+                 </a> -->
+                 <!-- <div class="nav-separator"></div> -->
+                 <a href="https://arxiv.org/abs/2505.14059" class="nav-link">
+                     <i class="fas fa-file-alt"></i> 论文/Paper
+                 </a>
+                 <div class="nav-separator"></div>
+                 <a href="https://huggingface.co/ByteDance/Dolphin" class="nav-link">
+                     <i class="fas fa-cube"></i> 模型/Model
+                 </a>
+                 <div class="nav-separator"></div>
+                 <a href="https://github.com/bytedance/Dolphin" class="nav-link">
+                     <i class="fas fa-code"></i> 代码/Code
+                 </a>
+                 <div class="nav-separator"></div>
+                 <a href="https://opensource.org/licenses/MIT" class="nav-link">
+                     <i class="fas fa-balance-scale"></i> 许可证/License
+                 </a>
+             </div>
+
+             <div class="feature-grid">
+                 <div class="feature-card">
+                     <div class="feature-icon"><i class="fas fa-file-import"></i></div>
+                     <div class="feature-title">支持格式/Support Format</div>
+                     <div class="feature-desc">支持多页PDF、单页图像<br>Multi-page PDF, single document image (JPEG/PNG)</div>
+                 </div>
+
+                 <div class="feature-separator"></div>
+
+                 <div class="feature-card">
+                     <div class="feature-icon"><i class="fas fa-feather-alt"></i></div>
+                     <div class="feature-title">轻量级模型/Lightweight Model</div>
+                     <div class="feature-desc">Dolphin模型参数量322M,高效易部署<br>Lightweight (322M) and efficient, easy to deploy</div>
+                 </div>
+
+                 <div class="feature-separator"></div>
+
+                 <div class="feature-card">
+                     <div class="feature-icon"><i class="fas fa-tasks"></i></div>
+                     <div class="feature-title">并行解析/Parallel Parsing</div>
+                     <div class="feature-desc">Dolphin并行解析多个文本块<br>Parsing several text blocks in a batch for speed up</div>
+                 </div>
+
+                 <div class="feature-separator"></div>
+
+                 <div class="feature-card">
+                     <div class="feature-icon"><i class="fas fa-superscript"></i></div>
+                     <div class="feature-title">公式和表格/Formula and Table</div>
+                     <div class="feature-desc">支持公式(LaTeX格式)、表格(HTML格式)输出<br>Support formulas (LaTeX format) and tables (HTML format)</div>
+                 </div>
+             </div>
+
+             <!-- Disclaimer -->
+             <p style="
+                 font-size: 0.8rem;
+                 color: var(--muted-text) !important;
+                 margin-top: 1.5rem;
+                 text-align: center;
+             ">内容由 AI 生成,请仔细甄别</p>
+         </div>
+     </div>
+ </body>
+ </html>
inference_hugg.py ADDED
@@ -0,0 +1,287 @@
+ """
+ Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ SPDX-License-Identifier: MIT
+ """
+
+ import argparse
+ import glob
+ import os
+
+ import cv2
+ import torch
+ from PIL import Image
+ from transformers import AutoProcessor, VisionEncoderDecoderModel
+
+ from utils.utils import *
+
+
+ class DOLPHIN:
+     def __init__(self, model_id_or_path):
+         """Initialize the Hugging Face model
+
+         Args:
+             model_id_or_path: Path to local model or Hugging Face model ID
+         """
+         # Load model from local path or Hugging Face hub
+         self.processor = AutoProcessor.from_pretrained(model_id_or_path)
+         self.model = VisionEncoderDecoderModel.from_pretrained(model_id_or_path)
+         self.model.eval()
+
+         # Set device and precision
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.model.to(self.device)
+         self.model = self.model.half()  # Always use half precision by default
+
+         # Set tokenizer
+         self.tokenizer = self.processor.tokenizer
+
+     def chat(self, prompt, image):
+         """Process an image or batch of images with the given prompt(s)
+
+         Args:
+             prompt: Text prompt or list of prompts to guide the model
+             image: PIL Image or list of PIL Images to process
+
+         Returns:
+             Generated text or list of texts from the model
+         """
+         # Check if we're dealing with a batch
+         is_batch = isinstance(image, list)
+
+         if not is_batch:
+             # Single image, wrap it in a list for consistent processing
+             images = [image]
+             prompts = [prompt]
+         else:
+             # Batch of images
+             images = image
+             prompts = prompt if isinstance(prompt, list) else [prompt] * len(images)
+
+         # Prepare image
+         batch_inputs = self.processor(images, return_tensors="pt", padding=True)
+         batch_pixel_values = batch_inputs.pixel_values.half().to(self.device)
+
+         # Prepare prompt
+         prompts = [f"<s>{p} <Answer/>" for p in prompts]
+         batch_prompt_inputs = self.tokenizer(
+             prompts,
+             add_special_tokens=False,
+             return_tensors="pt"
+         )
+
+         batch_prompt_ids = batch_prompt_inputs.input_ids.to(self.device)
+         batch_attention_mask = batch_prompt_inputs.attention_mask.to(self.device)
+
+         # Generate text
+         outputs = self.model.generate(
+             pixel_values=batch_pixel_values,
+             decoder_input_ids=batch_prompt_ids,
+             decoder_attention_mask=batch_attention_mask,
+             min_length=1,
+             max_length=4096,
+             pad_token_id=self.tokenizer.pad_token_id,
+             eos_token_id=self.tokenizer.eos_token_id,
+             use_cache=True,
+             bad_words_ids=[[self.tokenizer.unk_token_id]],
+             return_dict_in_generate=True,
+             do_sample=False,
+             num_beams=1,
+             repetition_penalty=1.1
+         )
+
+         # Process output
+         sequences = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)
+
+         # Clean prompt text from output
+         results = []
+         for i, sequence in enumerate(sequences):
+             cleaned = sequence.replace(prompts[i], "").replace("<pad>", "").replace("</s>", "").strip()
+             results.append(cleaned)
+
+         # Return a single result for single image input
+         if not is_batch:
+             return results[0]
+         return results
+
+
+ def process_page(image_path, model, save_dir, max_batch_size=None):
+     """Parse document images with two stages"""
+     # Stage 1: Page-level layout and reading order parsing
+     pil_image = Image.open(image_path).convert("RGB")
+     layout_output = model.chat("Parse the reading order of this document.", pil_image)
+
+     # Stage 2: Element-level content parsing
+     padded_image, dims = prepare_image(pil_image)
+     recognition_results = process_elements(layout_output, padded_image, dims, model, max_batch_size)
+
+     # Save outputs
+     json_path = save_outputs(recognition_results, image_path, save_dir)
+
+     return json_path, recognition_results
+
+
+ def process_elements(layout_results, padded_image, dims, model, max_batch_size=None):
+     """Parse all document elements with parallel decoding"""
+     layout_results = parse_layout_string(layout_results)
+
+     # Store text and table elements separately
+     text_elements = []    # Text elements
+     table_elements = []   # Table elements
+     figure_results = []   # Image elements (no processing needed)
+     previous_box = None
+     reading_order = 0
+
+     # Collect elements to process and group by type
+     for bbox, label in layout_results:
+         try:
+             # Adjust coordinates
+             x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = process_coordinates(
+                 bbox, padded_image, dims, previous_box
+             )
+
+             # Crop and parse element
+             cropped = padded_image[y1:y2, x1:x2]
+             if cropped.size > 0:
+                 if label == "fig":
+                     # For figure regions, add empty text result immediately
+                     figure_results.append(
+                         {
+                             "label": label,
+                             "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
+                             "text": "",
+                             "reading_order": reading_order,
+                         }
+                     )
+                 else:
+                     # Prepare element for parsing
+                     pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
+                     element_info = {
+                         "crop": pil_crop,
+                         "label": label,
+                         "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
+                         "reading_order": reading_order,
+                     }
+
+                     # Group by type
+                     if label == "tab":
+                         table_elements.append(element_info)
+                     else:  # Text elements
+                         text_elements.append(element_info)
+
+             reading_order += 1
+
+         except Exception as e:
+             print(f"Error processing bbox with label {label}: {str(e)}")
+             continue
+
+     # Initialize results list
+     recognition_results = figure_results.copy()
+
+     # Process text elements (in batches)
+     if text_elements:
+         text_results = process_element_batch(text_elements, model, "Read text in the image.", max_batch_size)
+         recognition_results.extend(text_results)
+
+     # Process table elements (in batches)
+     if table_elements:
+         table_results = process_element_batch(table_elements, model, "Parse the table in the image.", max_batch_size)
+         recognition_results.extend(table_results)
+
+     # Sort elements by reading order
+     recognition_results.sort(key=lambda x: x.get("reading_order", 0))
+
+     return recognition_results
+
+
+ def process_element_batch(elements, model, prompt, max_batch_size=None):
+     """Process elements of the same type in batches"""
+     results = []
+
+     # Determine batch size
+     batch_size = len(elements)
+     if max_batch_size is not None and max_batch_size > 0:
+         batch_size = min(batch_size, max_batch_size)
+
+     # Process in batches
+     for i in range(0, len(elements), batch_size):
+         batch_elements = elements[i:i+batch_size]
+         crops_list = [elem["crop"] for elem in batch_elements]
+
+         # Use the same prompt for all elements in the batch
+         prompts_list = [prompt] * len(crops_list)
+
+         # Batch inference
+         batch_results = model.chat(prompts_list, crops_list)
+
+         # Add results
+         for j, result in enumerate(batch_results):
+             elem = batch_elements[j]
+             results.append({
+                 "label": elem["label"],
+                 "bbox": elem["bbox"],
+                 "text": result.strip(),
+                 "reading_order": elem["reading_order"],
+             })
+
+     return results
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Document processing tool using DOLPHIN model")
+     parser.add_argument("--input_path", type=str, default="./demo", help="Path to input image or directory of images")
+     parser.add_argument(
+         "--save_dir",
+         type=str,
+         default=None,
+         help="Directory to save parsing results (default: same as input directory)",
+     )
+     parser.add_argument(
+         "--max_batch_size",
+         type=int,
+         default=16,
+         help="Maximum number of document elements to parse in a single batch (default: 16)",
+     )
+     args = parser.parse_args()
+
+     # Load Model
+     model = DOLPHIN("ByteDance/Dolphin")
+
+     # Collect Document Images
+     if os.path.isdir(args.input_path):
+         image_files = []
+         for ext in [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"]:
+             image_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}")))
+         image_files = sorted(image_files)
+     else:
+         if not os.path.exists(args.input_path):
+             raise FileNotFoundError(f"Input path {args.input_path} does not exist")
+         image_files = [args.input_path]
+
+     save_dir = args.save_dir or (
+         args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path)
+     )
+     setup_output_dirs(save_dir)
+
+     total_samples = len(image_files)
+     print(f"\nTotal samples to process: {total_samples}")
+
+     # Process All Document Images
+     for image_path in image_files:
+         print(f"\nProcessing {image_path}")
+         try:
+             json_path, recognition_results = process_page(
+                 image_path=image_path,
+                 model=model,
+                 save_dir=save_dir,
+                 max_batch_size=args.max_batch_size,
+             )
+
+             print(f"Processing completed. Results saved to {save_dir}")
+
+         except Exception as e:
+             print(f"Error processing {image_path}: {str(e)}")
+             continue
+
+
+ if __name__ == "__main__":
+     main()
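
The DOLPHIN wrapper above can also be exercised directly, without the CLI. A minimal sketch, assuming `ByteDance/Dolphin` downloads successfully; the image path is illustrative:

```python
# Sketch only - stage-1 layout parsing with the DOLPHIN class from this file.
from PIL import Image
from inference_hugg import DOLPHIN

model = DOLPHIN("ByteDance/Dolphin")
page = Image.open("demo/some_page.png").convert("RGB")  # illustrative path
layout = model.chat("Parse the reading order of this document.", page)
print(layout)  # element sequence in natural reading order
```

The equivalent batch run through the CLI is `python inference_hugg.py --input_path ./demo --max_batch_size 16`.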
pyproject.toml ADDED
@@ -0,0 +1,16 @@
+ [tool.black]
+ line-length = 120
+ include = '\.pyi?$'
+ exclude = '''
+ /(
+     \.git
+   | \.hg
+   | \.mypy_cache
+   | \.tox
+   | \.venv
+   | _build
+   | buck-out
+   | build
+   | dist
+ )/
+ '''
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ gradio==5.24.0
+ gradio_pdf==0.0.22
+ pymupdf==1.25.5
+ loguru==0.7.3
+ torch==2.1.0
+ transformers==4.47.0
+ opencv-python==4.11.0.86
+ opencv-python-headless==4.5.5.64
+ Pillow==9.3.0
+ numpy==1.24.4
+ spaces
+ albumentations==1.4.0
+ requests==2.32.3
+ httpx==0.23.0
static/styles.css ADDED
@@ -0,0 +1,306 @@
+ :root {
+     /* Primary colors */
+     --primary-color: #dceaf6;
+     --primary-light: #f8f9fa;
+     --primary-dark: #9ec9e3;
+
+     /* Accent colors */
+     --accent-color: #bfe2f8;
+     --accent-light: #dceaf6;
+
+     /* Background colors */
+     --bg-color: #e8eff5;
+     --card-bg: #ffffff;
+
+     /* Text colors */
+     --dark-text: #2b2d42;
+     --light-text: #f8f9fa;
+     --muted-text: rgba(43, 45, 66, 0.7);
+
+     /* Borders and shadows */
+     --border-color: rgba(168, 168, 168, 0.432);
+     --card-shadow: 0 4px 20px rgba(104, 104, 104, 0.1);
+
+     /* Interactive states */
+     --hover-bg: rgba(255, 255, 255, 0.5);
+     --active-color: #bfe2f8;
+ }
+
+ body {
+     font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
+     background-color: var(--bg-color);
+ }
+
+ /* Card styles */
+ .gradio-container {
+     max-width: 95% !important;
+     width: 95% !important;
+     margin-left: auto !important;
+     margin-right: auto !important;
+ }
+
+ /* Panel styles */
+ .panel {
+     border-radius: 12px !important;
+     border: 1px solid var(--border-color) !important;
+     box-shadow: var(--card-shadow) !important;
+     background-color: var(--card-bg) !important;
+     padding: 1.5rem !important;
+ }
+
+ /* Button styles */
+ button.primary {
+     border-radius: 8px !important;
+ }
+
+ button {
+     border-radius: 8px !important;
+     border: 1px solid var(--border-color) !important;
+     background-color: var(--hover-bg) !important;
+     color: var(--dark-text) !important;
+     transition: all 0.3s ease !important;
+ }
+
+ button:hover {
+     transform: translateY(-2px) !important;
+     box-shadow: 0 5px 15px rgba(0, 0, 0, 0.1) !important;
+     background-color: var(--hover-bg) !important;
+ }
+
+ /* File upload area */
+ .file-preview {
+     border-radius: 8px !important;
+     border: 1px dashed var(--border-color) !important;
+ }
+
+ .file-preview:hover {
+     border-color: var(--primary-dark) !important;
+ }
+
+ /* Make sure all link buttons are styled correctly */
+ .header-buttons a,
+ .header-buttons a:hover,
+ .header-buttons a:visited,
+ .header-buttons a:active {
+     text-decoration: none !important;
+     color: var(--dark-text) !important;
+ }
+
+ /* Override any possible inline styles */
+ .header-buttons a[style] {
+     text-decoration: none !important;
+     color: var(--dark-text) !important;
+ }
+
+ /* Make sure no element inside a link has an underline */
+ .header-buttons a *,
+ .header-buttons a:hover * {
+     text-decoration: none !important;
+ }
+
+ /* Hide the page footer */
+ footer, .footer, .footer-links, .gradio-footer {
+     display: none !important;
+ }
+
+ /* Hide the bottom toolbar */
+ .gradio-container > div:last-child {
+     display: none !important;
+ }
+
+ /* Hide the bottom API and settings buttons */
+ .fixed-bottom {
+     display: none !important;
+ }
+
+ /* Hide the Gradio branding */
+ .gr-prose p:last-child {
+     display: none !important;
+ }
+
+ /* Hide any remaining bottom elements */
+ [class*="footer"], [id*="footer"], [class*="bottom-bar"], [id*="bottom-bar"] {
+     display: none !important;
+ }
+
+ /* Sidebar styles */
+ .sidebar {
+     background-color: var(--card-bg);
+     border-radius: 12px;
+     border: 1px solid var(--border-color);
+     box-shadow: var(--card-shadow);
+     padding: 1rem;
+     margin-right: 1rem;
+ }
+
+ /* Upload button styles */
+ .upload-button {
+     display: flex;
+     align-items: center;
+     justify-content: center;
+     border: 2px dashed var(--border-color);
+     padding: 1rem;
+     margin-bottom: 1rem;
+     cursor: pointer;
+     transition: all 0.3s ease;
+ }
+
+ .upload-button:hover {
+     border-color: var(--primary-dark);
+     background-color: rgba(158, 201, 227, 0.1);
+ }
+
+ .upload-button i {
+     font-size: 1.5rem;
+     color: var(--primary-dark);
+     margin-right: 0.5rem;
+ }
+
+ /* Example file list styles */
+ .example-list {
+     list-style-type: none;
+     padding: 0;
+     margin: 0;
+ }
+
+ .example-item {
+     display: flex;
+     align-items: center;
+     padding: 0.5rem;
+     border-radius: 8px;
+     margin-bottom: 0.5rem;
+     cursor: pointer;
+     transition: all 0.3s ease;
+ }
+
+ .example-item:hover {
+     background-color: rgba(158, 201, 227, 0.1);
+ }
+
+ .example-item i {
+     font-size: 1.2rem;
+     color: var(--primary-dark);
+     margin-right: 0.5rem;
+ }
+
+ .example-item-name {
+     white-space: nowrap;
+     overflow: hidden;
+     text-overflow: ellipsis;
+ }
+
+ /* Cancel and confirm button styles */
+ .action-buttons {
+     display: flex;
+     justify-content: flex-end;
+ }
+
+ /* Clear button styles */
+ button[value="清空/Clear"] {
+     color: #e74c3c !important;
+ }
+
+ /* Hide the raw file upload component */
+ .file-upload {
+     display: none !important;
+ }
+
+ /* Main content styles */
+ .main-content {
+     display: flex;
+     flex: 1;
+ }
+
+ /* Preview panel styles */
+ .preview-panel {
+     flex: 1;
+     background-color: var(--card-bg);
+     border-radius: 12px;
+     border: 1px solid var(--border-color);
+     box-shadow: var(--card-shadow);
+     padding: 1rem;
+     margin-right: 1rem;
+ }
+
+ /* Output panel styles */
+ .output-panel {
+     flex: 1;
+     background-color: var(--card-bg);
+     border-radius: 12px;
+     border: 1px solid var(--border-color);
+     box-shadow: var(--card-shadow);
+     padding: 1rem;
+ }
+
+ /* Responsive layout */
+ @media (max-width: 768px) {
+     .main-content {
+         flex-direction: column;
+     }
+
+     .sidebar, .preview-panel, .output-panel {
+         margin-right: 0;
+         margin-bottom: 1rem;
+         width: 100%;
+     }
+ }
+
+ /* Polish the file upload component */
+ #file-upload {
+     margin-bottom: 1.5rem;
+ }
+
+ #file-upload .file-preview {
+     border: 2px dashed var(--border-color);
+     padding: 1.5rem;
+     transition: all 0.3s ease;
+     text-align: center;
+ }
+
+ #file-upload .file-preview:hover {
+     border-color: var(--primary-dark);
+     background-color: rgba(158, 201, 227, 0.1);
+ }
+
+ /* Hide the original label */
+ #file-upload .label-wrap {
+     display: none;
+ }
+
+ /* Polish the example file list */
+ #example-files .gr-samples-table {
+     border: none;
+     background: transparent;
+ }
+
+ #example-files .gr-samples-table td {
+     border: none;
+     padding: 0.5rem;
+     transition: all 0.3s ease;
+     border-radius: 8px;
+ }
+
+ #example-files .gr-samples-table tr:hover td {
+     background-color: rgba(158, 201, 227, 0.1);
+ }
+
+ #example-files .gr-samples-table td a {
+     display: flex;
+     align-items: center;
+     color: var(--dark-text);
+     text-decoration: none;
+ }
+
+ #example-files .gr-samples-table td a::before {
+     content: "\f1c1";
+     font-family: "Font Awesome 6 Free";
+     font-weight: 900;
+     margin-right: 0.5rem;
+     color: var(--primary-dark);
+     font-size: 1.2rem;
+ }
+
+ /* Hide the pagination controls */
+ #example-files .gr-samples-pagination {
+     display: none;
+ }
utils/markdown_utils.py ADDED
@@ -0,0 +1,442 @@
+ """
+ Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ SPDX-License-Identifier: MIT
+ """
+
+ import re
+ import base64
+ from typing import List, Dict, Any, Optional
+
+
+ """
+ Example input:
+ [
+     {"label": "tab", "bbox": [0.176, 0.74, 0.824, 0.82], "text": "<table><tr><td></td><td>HellaSwag</td><td>Obqa</td><td>WinoGrande</td><td>ARC-c</td><td>ARC-e</td><td>boolq</td><td>piqa</td><td>Avg</td></tr><tr><td>OPT-1.3B</td><td>53.65</td><td>33.40</td><td>59.59</td><td>29.44</td><td>50.80</td><td>60.83</td><td>72.36</td><td>51.44</td></tr><tr><td>Pythia-1.0B</td><td>47.16</td><td>31.40</td><td>53.43</td><td>27.05</td><td>48.99</td><td>57.83</td><td>69.21</td><td>48.30</td></tr><tr><td>Pythia-1.4B</td><td>52.01</td><td>33.20</td><td>57.38</td><td>28.50</td><td>54.00</td><td>63.27</td><td>70.95</td><td>51.33</td></tr><tr><td>TinyLlama-1.1B</td><td>59.20</td><td>36.00</td><td>59.12</td><td>30.10</td><td>55.25</td><td>57.83</td><td>73.29</td><td>52.99</td></tr></table>", "reading_order": 6},
+     {"label": "cap", "bbox": [0.28, 0.729, 0.711, 0.74], "text": "Table 2: Zero-shot performance on commonsense reasoning tasks", "reading_order": 7},
+     {"label": "para", "bbox": [0.176, 0.848, 0.826, 0.873], "text": "We of performance during training We tracked the accuracy of TinyLlama on common-\nsense reasoning benchmarks during its pre-training, as shown in Fig. 2 . Generally, the performance of", "reading_order": 8},
+     {"label": "fnote", "bbox": [0.176, 0.88, 0.824, 0.912], "text": "${ }^{4}$ Due to a bug in the config file, the learning rate did not decrease immediately after warmup and remained at\nthe maximum value for several steps before we fixed this.", "reading_order": 9},
+     {"label": "foot", "bbox": [0.496, 0.939, 0.501, 0.95], "text": "14", "reading_order": 10}
+ ]
+ """
+
+
+ def extract_table_from_html(html_string):
+     """Extract and clean table tags from HTML string"""
+     try:
+         table_pattern = re.compile(r'<table.*?>.*?</table>', re.DOTALL)
+         tables = table_pattern.findall(html_string)
+         tables = [re.sub(r'<table[^>]*>', '<table>', table) for table in tables]
+         return '\n'.join(tables)
+     except Exception as e:
+         print(f"extract_table_from_html error: {str(e)}")
+         return f"<table><tr><td>Error extracting table: {str(e)}</td></tr></table>"
+
+
+ class MarkdownConverter:
+     """Convert structured recognition results to Markdown format"""
+
+     def __init__(self):
+         # Define heading levels for different section types
+         self.heading_levels = {
+             'title': '#',
+             'sec': '##',
+             'sub_sec': '###'
+         }
+
+         # Define which labels need special handling
+         self.special_labels = {
+             'tab', 'fig', 'title', 'sec', 'sub_sec',
+             'list', 'formula', 'reference', 'alg'
+         }
+
+     def try_remove_newline(self, text: str) -> str:
+         try:
+             # Preprocess text to handle line breaks
+             text = text.strip()
+             text = text.replace('-\n', '')
+
+             # Handle Chinese text line breaks
+             def is_chinese(char):
+                 return '\u4e00' <= char <= '\u9fff'
+
+             lines = text.split('\n')
+             processed_lines = []
+
+             # Process all lines except the last one
+             for i in range(len(lines)-1):
+                 current_line = lines[i].strip()
+                 next_line = lines[i+1].strip()
+
+                 # Always add the current line, but determine if we need a newline
+                 if current_line:  # If current line is not empty
+                     if next_line:  # If next line is not empty
+                         # For Chinese text handling
+                         if is_chinese(current_line[-1]) and is_chinese(next_line[0]):
+                             processed_lines.append(current_line)
+                         else:
+                             processed_lines.append(current_line + ' ')
+                     else:
+                         # Next line is empty, add current line with newline
+                         processed_lines.append(current_line + '\n')
+                 else:
+                     # Current line is empty, add an empty line
+                     processed_lines.append('\n')
+
+             # Add the last line
+             if lines and lines[-1].strip():
+                 processed_lines.append(lines[-1].strip())
+
+             text = ''.join(processed_lines)
+
+             return text
+         except Exception as e:
+             print(f"try_remove_newline error: {str(e)}")
+             return text  # Return original text on error
+
+     def _handle_text(self, text: str) -> str:
+         """
+         Process regular text content, preserving paragraph structure
+         """
+         try:
+             if not text:
+                 return ""
+
+             if text.strip().startswith("\\begin{array}") and text.strip().endswith("\\end{array}"):
+                 text = "$$" + text + "$$"
+             elif ("_{" in text or "^{" in text or "\\" in text or "_ {" in text or "^ {" in text) and ("$" not in text) and ("\\begin" not in text):
+                 text = "$" + text + "$"
+
+             # Process formulas in text before handling other text processing
+             text = self._process_formulas_in_text(text)
+
+             text = self.try_remove_newline(text)
+
+             # Return processed text
+             return text
+         except Exception as e:
+             print(f"_handle_text error: {str(e)}")
+             return text  # Return original text on error
+
120
+ def _process_formulas_in_text(self, text: str) -> str:
121
+ """
122
+ Process mathematical formulas in text by iteratively finding and replacing formulas.
123
+ - Identify inline and block formulas
124
+ - Replace newlines within formulas with \\
125
+ """
126
+ try:
127
+ # Define formula delimiters and their corresponding patterns
128
+ delimiters = [
129
+ ('$$', '$$'), # Block formula with $$
130
+ ('\\[', '\\]'), # Block formula with \[ \]
131
+ ('$', '$'), # Inline formula with $
132
+ ('\\(', '\\)') # Inline formula with \( \)
133
+ ]
134
+
135
+ # Process the text by iterating through each delimiter type
136
+ result = text
137
+
138
+ for start_delim, end_delim in delimiters:
139
+ # Create a pattern that matches from start to end delimiter
140
+ # Using a custom approach to avoid issues with nested delimiters
141
+ current_pos = 0
142
+ processed_parts = []
143
+
144
+ while current_pos < len(result):
145
+ # Find the next start delimiter
146
+ start_pos = result.find(start_delim, current_pos)
147
+ if start_pos == -1:
148
+ # No more formulas of this type
149
+ processed_parts.append(result[current_pos:])
150
+ break
151
+
152
+ # Add text before the formula
153
+ processed_parts.append(result[current_pos:start_pos])
154
+
155
+ # Find the matching end delimiter
156
+ end_pos = result.find(end_delim, start_pos + len(start_delim))
157
+ if end_pos == -1:
158
+ # No matching end delimiter, treat as regular text
159
+ processed_parts.append(result[start_pos:])
160
+ break
161
+
162
+ # Extract the formula content (without delimiters)
163
+ formula_content = result[start_pos + len(start_delim):end_pos]
164
+
165
+ # Process the formula content - replace newlines with \\
166
+ processed_formula = formula_content.replace('\n', ' \\\\ ')
167
+
168
+ # Add the processed formula with its delimiters
169
+ processed_parts.append(f"{start_delim}{processed_formula}{end_delim}")
170
+
171
+ # Move past this formula
172
+ current_pos = end_pos + len(end_delim)
173
+
174
+ # Update the result with processed text
175
+ result = ''.join(processed_parts)
176
+ return result
177
+ except Exception as e:
178
+ print(f"_process_formulas_in_text error: {str(e)}")
179
+ return text # Return original text on error
180
+
181
+ def _remove_newline_in_heading(self, text: str) -> str:
182
+ """
183
+ Remove newline in heading
184
+ """
185
+ try:
186
+ # Handle Chinese text line breaks
187
+ def is_chinese(char):
188
+ return '\u4e00' <= char <= '\u9fff'
189
+
190
+ # Check if the text contains Chinese characters
191
+ if any(is_chinese(char) for char in text):
192
+ return text.replace('\n', '')
193
+ else:
194
+ return text.replace('\n', ' ')
195
+
196
+ except Exception as e:
197
+ print(f"_remove_newline_in_heading error: {str(e)}")
198
+ return text
199
+
200
+     def _handle_heading(self, text: str, label: str) -> str:
+         """
+         Convert section headings to appropriate markdown format
+         """
+         try:
+             level = self.heading_levels.get(label, '#')
+             text = text.strip()
+             text = self._remove_newline_in_heading(text)
+             text = self._handle_text(text)
+             return f"{level} {text}\n\n"
+         except Exception as e:
+             print(f"_handle_heading error: {str(e)}")
+             return f"# Error processing heading: {text}\n\n"
+ 
+     def _handle_list_item(self, text: str) -> str:
+         """
+         Convert list items to markdown list format
+         """
+         try:
+             return f"- {text.strip()}\n"
+         except Exception as e:
+             print(f"_handle_list_item error: {str(e)}")
+             return f"- Error processing list item: {text}\n"
+ 
+     def _handle_figure(self, text: str, section_count: int) -> str:
+         """
+         Convert base64 encoded image to markdown image syntax
+         """
+         try:
+             if text.startswith("data:image/") or (";" in text and "," in text):
+                 # Already a data URI, embed it directly
+                 return f"![Figure {section_count}]({text})\n\n"
+             else:
+                 # Raw base64: build a data URI, assuming PNG when the format
+                 # is not specified
+                 data_uri = f"data:image/png;base64,{text}"
+                 return f"![Figure {section_count}]({data_uri})\n\n"
+         except Exception as e:
+             print(f"_handle_figure error: {str(e)}")
+             return f"*[Error processing figure: {str(e)}]*\n\n"
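+ 
+     # Example (illustrative): a raw base64 payload "iVBORw0..." for section 3
+     # becomes "![Figure 3](data:image/png;base64,iVBORw0...)\n\n".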
+ 
+     def _handle_table(self, text: str) -> str:
+         """
+         Convert table content to markdown format
+         """
+         try:
+             markdown_content = []
+             if '<table' in text.lower() or '<tr' in text.lower():
+                 markdown_table = extract_table_from_html(text)
+                 markdown_content.append(markdown_table + "\n")
+             else:
+                 # Fall back to a naive whitespace split for plain-text tables
+                 table_lines = text.split('\n')
+                 if table_lines:
+                     col_count = len(table_lines[0].split()) if table_lines[0] else 1
+                     header = '| ' + ' | '.join(table_lines[0].split()) + ' |'
+                     markdown_content.append(header)
+                     markdown_content.append('| ' + ' | '.join(['---'] * col_count) + ' |')
+                     for line in table_lines[1:]:
+                         cells = line.split()
+                         while len(cells) < col_count:
+                             cells.append('')
+                         markdown_content.append('| ' + ' | '.join(cells) + ' |')
+             return '\n'.join(markdown_content) + '\n\n'
+         except Exception as e:
+             print(f"_handle_table error: {str(e)}")
+             return f"*[Error processing table: {str(e)}]*\n\n"
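+ 
+     # Example (illustrative) for the plain-text fallback:
+     #   "Name Score\nAda 10"  ->
+     #   "| Name | Score |\n| --- | --- |\n| Ada | 10 |\n\n"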
+ 
+     def _handle_algorithm(self, text: str) -> str:
+         """
+         Process algorithm blocks with proper formatting
+         """
+         try:
+             # Remove algorithm environment tags if present
+             text = re.sub(r'\\begin\{algorithm\}(.*?)\\end\{algorithm\}', r'\1', text, flags=re.DOTALL)
+             text = text.replace('\\begin{algorithm}', '').replace('\\end{algorithm}', '')
+             text = text.replace('\\begin{algorithmic}', '').replace('\\end{algorithmic}', '')
+ 
+             # Process the algorithm text
+             lines = text.strip().split('\n')
+ 
+             # Check if there's a caption or label
+             caption = ""
+             algorithm_text = []
+ 
+             for line in lines:
+                 if '\\caption' in line:
+                     # Extract caption text
+                     caption_match = re.search(r'\\caption\{(.*?)\}', line)
+                     if caption_match:
+                         caption = f"**{caption_match.group(1)}**\n\n"
+                     continue
+                 elif '\\label' in line:
+                     continue  # Skip label lines
+                 else:
+                     algorithm_text.append(line)
+ 
+             # Join the algorithm text and wrap in code block
+             formatted_text = '\n'.join(algorithm_text)
+ 
+             # Return the formatted algorithm with caption
+             return f"{caption}```\n{formatted_text}\n```\n\n"
+         except Exception as e:
+             print(f"_handle_algorithm error: {str(e)}")
+             return f"*[Error processing algorithm: {str(e)}]*\n\n{text}\n\n"
+ 
+     def _handle_formula(self, text: str) -> str:
+         """
+         Handle formula-specific content
+         """
+         try:
+             # Process the formula content
+             processed_text = self._process_formulas_in_text(text)
+ 
+             # For formula blocks, ensure they're properly formatted in markdown
+             if '$$' not in processed_text and '\\[' not in processed_text:
+                 # If no block formula delimiters are present, wrap in $$ for block formula
+                 processed_text = f'$${processed_text}$$'
+ 
+             return f"{processed_text}\n\n"
+         except Exception as e:
+             print(f"_handle_formula error: {str(e)}")
+             return f"*[Error processing formula: {str(e)}]*\n\n"
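+ 
+     # Example (illustrative): "E = mc^{2}" carries no block delimiters, so it
+     # is returned as "$$E = mc^{2}$$\n\n".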
+ 
+     def convert(self, recognition_results: List[Dict[str, Any]]) -> str:
+         """
+         Convert recognition results to markdown format
+         """
+         try:
+             markdown_content = []
+ 
+             for section_count, result in enumerate(recognition_results):
+                 try:
+                     label = result.get('label', '')
+                     text = result.get('text', '').strip()
+ 
+                     # Skip empty text
+                     if not text:
+                         continue
+ 
+                     # Handle different content types
+                     if label in {'title', 'sec', 'sub_sec'}:
+                         markdown_content.append(self._handle_heading(text, label))
+                     elif label == 'list':
+                         markdown_content.append(self._handle_list_item(text))
+                     elif label == 'fig':
+                         markdown_content.append(self._handle_figure(text, section_count))
+                     elif label == 'tab':
+                         markdown_content.append(self._handle_table(text))
+                     elif label == 'alg':
+                         markdown_content.append(self._handle_algorithm(text))
+                     elif label == 'formula':
+                         markdown_content.append(self._handle_formula(text))
+                     elif label not in self.special_labels:
+                         # Handle regular text (paragraphs, etc.)
+                         processed_text = self._handle_text(text)
+                         markdown_content.append(f"{processed_text}\n\n")
+                 except Exception as e:
+                     print(f"Error processing item {section_count}: {str(e)}")
+                     # Add a placeholder for the failed item
+                     markdown_content.append("*[Error processing content]*\n\n")
+ 
+             # Join all content and apply post-processing
+             result = ''.join(markdown_content)
+             return self._post_process(result)
+         except Exception as e:
+             print(f"convert error: {str(e)}")
+             return f"Error generating markdown content: {str(e)}"
+ 
+     def _post_process(self, markdown_content: str) -> str:
+         """
+         Apply post-processing fixes to the generated markdown content
+         """
+         try:
+             # Handle author information
+             author_pattern = re.compile(r'\\author\{(.*?)\}', re.DOTALL)
+ 
+             def process_author_match(match):
+                 # Extract author content
+                 author_content = match.group(1)
+                 # Process the author content
+                 return self._handle_text(author_content)
+ 
+             # Replace \author{...} with processed content
+             markdown_content = author_pattern.sub(process_author_match, markdown_content)
+ 
+             # Handle special case where author is inside math environment
+             math_author_pattern = re.compile(r'\$(\\author\{.*?\})\$', re.DOTALL)
+             match = math_author_pattern.search(markdown_content)
+             if match:
+                 # Extract the author command
+                 author_cmd = match.group(1)
+                 # Extract content from author command
+                 author_content_match = re.search(r'\\author\{(.*?)\}', author_cmd, re.DOTALL)
+                 if author_content_match:
+                     # Get author content and process it
+                     author_content = author_content_match.group(1)
+                     processed_content = self._handle_text(author_content)
+                     # Replace the entire $\author{...}$ block with processed content
+                     markdown_content = markdown_content.replace(match.group(0), processed_content)
+ 
+             # Replace LaTeX abstract environment with plain text
+             markdown_content = re.sub(r'\\begin\{abstract\}(.*?)\\end\{abstract\}',
+                                       r'**Abstract** \1',
+                                       markdown_content,
+                                       flags=re.DOTALL)
+ 
+             # Replace standalone \begin{abstract} (without matching end)
+             markdown_content = re.sub(r'\\begin\{abstract\}',
+                                       r'**Abstract**',
+                                       markdown_content)
+ 
+             # Replace LaTeX equation numbers with tag format, handling cases with extra backslashes
+             markdown_content = re.sub(r'\\eqno\{\((.*?)\)\}',
+                                       r'\\tag{\1}',
+                                       markdown_content)
+ 
+             # Convert the opening \[ of a multi-line display formula to $$
+             markdown_content = markdown_content.replace("\\[ \\\\", "$$ \\\\")
+ 
+             # Convert the matching closing \] to $$
+             markdown_content = markdown_content.replace("\\\\ \\]", "\\\\ $$")
+ 
+             # Fix other common LaTeX issues
+             replacements = [
+                 # Fix spacing issues in subscripts and superscripts
+                 # (the caret must be escaped so it matches literally)
+                 (r'_ {', r'_{'),
+                 (r'\^ {', r'^{'),
+ 
+                 # Collapse runs of three or more newlines into a blank line
+                 (r'\n{3,}', r'\n\n')
+             ]
+ 
+             for old, new in replacements:
+                 markdown_content = re.sub(old, new, markdown_content)
+ 
+             return markdown_content
+         except Exception as e:
+             print(f"_post_process error: {str(e)}")
+             return markdown_content  # Return original content if post-processing fails
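+ 
+ 
+ # Minimal usage sketch (hypothetical input; output shape depends on the
+ # heading_levels and special_labels configured on the class above):
+ #   converter = MarkdownConverter()
+ #   md = converter.convert([
+ #       {"label": "sec", "text": "Introduction"},
+ #       {"label": "para", "text": "DOLPHIN reads documents page by page."},
+ #   ])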
utils/utils.py ADDED
@@ -0,0 +1,367 @@
+ """
+ Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ SPDX-License-Identifier: MIT
+ """
+ 
+ import copy
+ import json
+ import os
+ import re
+ from dataclasses import dataclass
+ from typing import List, Tuple
+ 
+ import albumentations as alb
+ import cv2
+ import numpy as np
+ from albumentations.pytorch import ToTensorV2
+ from PIL import Image
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+ from torchvision.transforms.functional import resize
+ 
+ from utils.markdown_utils import MarkdownConverter
+ 
+ 
+ def alb_wrapper(transform):
+     def f(im):
+         return transform(image=np.asarray(im))["image"]
+ 
+     return f
+ 
+ 
+ test_transform = alb_wrapper(
+     alb.Compose(
+         [
+             alb.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
+             ToTensorV2(),
+         ]
+     )
+ )
+ 
+ 
+ def check_coord_valid(x1, y1, x2, y2, image_size=None, abs_coord=True):
+     # print(f"check_coord_valid: {x1}, {y1}, {x2}, {y2}, {image_size}, {abs_coord}")
+     if x2 <= x1 or y2 <= y1:
+         return False, f"[{x1}, {y1}, {x2}, {y2}]"
+     if x1 < 0 or y1 < 0:
+         return False, f"[{x1}, {y1}, {x2}, {y2}]"
+     if not abs_coord:
+         if x2 > 1 or y2 > 1:
+             return False, f"[{x1}, {y1}, {x2}, {y2}]"
+     elif image_size is not None:  # has image size
+         if x2 > image_size[0] or y2 > image_size[1]:
+             return False, f"[{x1}, {y1}, {x2}, {y2}]"
+     return True, None
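+ 
+ # Examples (illustrative):
+ #   check_coord_valid(0.1, 0.1, 0.9, 0.9, abs_coord=False) -> (True, None)
+ #   check_coord_valid(50, 80, 40, 120) -> (False, "[50, 80, 40, 120]")  # x2 <= x1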
+ 
+ 
+ def adjust_box_edges(image, boxes: List[List[float]], max_pixels=15, threshold=0.2):
+     """
+     Widen box edges so they come to rest on clean background.
+ 
+     Args:
+         image: cv2 image object, or a path to one
+         boxes: list of boxes [[x1, y1, x2, y2]] in absolute coordinates
+     """
+     if isinstance(image, str):
+         image = cv2.imread(image)
+     img_h, img_w = image.shape[:2]
+     new_boxes = []
+     for box in boxes:
+         best_box = copy.deepcopy(box)
+ 
+         def check_edge(img, current_box, i, is_vertical):
+             # Score an edge by the black/white transition activity along it;
+             # a low score means the edge sits on uniform background
+             edge = current_box[i]
+             gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+             _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
+ 
+             if is_vertical:
+                 line = binary[current_box[1] : current_box[3] + 1, edge]
+             else:
+                 line = binary[edge, current_box[0] : current_box[2] + 1]
+ 
+             transitions = np.abs(np.diff(line))
+             return np.sum(transitions) / len(transitions)
+ 
+         # Only widen the box
+         edges = [(0, -1, True), (2, 1, True), (1, -1, False), (3, 1, False)]
+ 
+         current_box = copy.deepcopy(box)
+         # Make sure the box is within the image
+         current_box[0] = min(max(current_box[0], 0), img_w - 1)
+         current_box[1] = min(max(current_box[1], 0), img_h - 1)
+         current_box[2] = min(max(current_box[2], 0), img_w - 1)
+         current_box[3] = min(max(current_box[3], 0), img_h - 1)
+ 
+         for i, direction, is_vertical in edges:
+             best_score = check_edge(image, current_box, i, is_vertical)
+             if best_score <= threshold:
+                 continue
+             # Step the edge outward up to max_pixels, keeping the position
+             # with the lowest transition score
+             for step in range(max_pixels):
+                 current_box[i] += direction
+                 if i == 0 or i == 2:
+                     current_box[i] = min(max(current_box[i], 0), img_w - 1)
+                 else:
+                     current_box[i] = min(max(current_box[i], 0), img_h - 1)
+                 score = check_edge(image, current_box, i, is_vertical)
+ 
+                 if score < best_score:
+                     best_score = score
+                     best_box = copy.deepcopy(current_box)
+ 
+                 if score <= threshold:
+                     break
+         new_boxes.append(best_box)
+ 
+     return new_boxes
+ 
+ 
+ def parse_layout_string(bbox_str):
+     """Parse layout string using regular expressions"""
+     pattern = r"\[(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+)\]\s*(\w+)"
+     matches = re.finditer(pattern, bbox_str)
+ 
+     parsed_results = []
+     for match in matches:
+         coords = [float(match.group(i)) for i in range(1, 5)]
+         label = match.group(5).strip()
+         parsed_results.append((coords, label))
+ 
+     return parsed_results
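+ 
+ # Example (illustrative layout string, as the model might emit one):
+ #   parse_layout_string("[0.1, 0.2, 0.5, 0.4] para [0.1, 0.5, 0.9, 0.9] tab")
+ #   -> [([0.1, 0.2, 0.5, 0.4], 'para'), ([0.1, 0.5, 0.9, 0.9], 'tab')]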
+ 
+ 
+ @dataclass
+ class ImageDimensions:
+     """Class to store image dimensions"""
+     original_w: int
+     original_h: int
+     padded_w: int
+     padded_h: int
+ 
+ 
+ def map_to_original_coordinates(x1, y1, x2, y2, dims: ImageDimensions) -> Tuple[int, int, int, int]:
+     """Map coordinates from padded image back to original image
+ 
+     Args:
+         x1, y1, x2, y2: Coordinates in padded image
+         dims: Image dimensions object
+ 
+     Returns:
+         tuple: (x1, y1, x2, y2) coordinates in original image
+     """
+     try:
+         # Calculate padding offsets
+         top = (dims.padded_h - dims.original_h) // 2
+         left = (dims.padded_w - dims.original_w) // 2
+ 
+         # Map back to original coordinates
+         orig_x1 = max(0, x1 - left)
+         orig_y1 = max(0, y1 - top)
+         orig_x2 = min(dims.original_w, x2 - left)
+         orig_y2 = min(dims.original_h, y2 - top)
+ 
+         # Ensure we have a valid box (width and height > 0)
+         if orig_x2 <= orig_x1:
+             orig_x2 = min(orig_x1 + 1, dims.original_w)
+         if orig_y2 <= orig_y1:
+             orig_y2 = min(orig_y1 + 1, dims.original_h)
+ 
+         return int(orig_x1), int(orig_y1), int(orig_x2), int(orig_y2)
+     except Exception as e:
+         print(f"map_to_original_coordinates error: {str(e)}")
+         # Return safe coordinates
+         return 0, 0, min(100, dims.original_w), min(100, dims.original_h)
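+ 
+ # Worked example (illustrative): an 800x1000 page padded to 1000x1000 gives
+ # left = 100, top = 0, so (150, 50, 450, 300) maps back to (50, 50, 350, 300).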
+ 
+ 
+ def map_to_relevant_coordinates(abs_coords, dims: ImageDimensions):
+     """
+     Convert absolute coordinates to relative (normalized) coordinates,
+     e.g. [100, 100, 200, 200] on a 1000x500 image -> (0.1, 0.2, 0.2, 0.4)
+     """
+     try:
+         x1, y1, x2, y2 = abs_coords
+         return round(x1 / dims.original_w, 3), round(y1 / dims.original_h, 3), round(x2 / dims.original_w, 3), round(y2 / dims.original_h, 3)
+     except Exception as e:
+         print(f"map_to_relevant_coordinates error: {str(e)}")
+         return 0.0, 0.0, 1.0, 1.0  # Return full image coordinates
+ 
+ 
+ def process_coordinates(coords, padded_image, dims: ImageDimensions, previous_box=None):
+     """Process and adjust coordinates
+ 
+     Args:
+         coords: Normalized coordinates [x1, y1, x2, y2]
+         padded_image: Padded image
+         dims: Image dimensions object
+         previous_box: Previous box coordinates for overlap adjustment
+ 
+     Returns:
+         tuple: (x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box)
+     """
+     try:
+         # Convert normalized coordinates to absolute coordinates
+         x1, y1 = int(coords[0] * dims.padded_w), int(coords[1] * dims.padded_h)
+         x2, y2 = int(coords[2] * dims.padded_w), int(coords[3] * dims.padded_h)
+ 
+         # Ensure coordinates are within image bounds before adjustment
+         x1 = max(0, min(x1, dims.padded_w - 1))
+         y1 = max(0, min(y1, dims.padded_h - 1))
+         x2 = max(0, min(x2, dims.padded_w))
+         y2 = max(0, min(y2, dims.padded_h))
+ 
+         # Ensure width and height are at least 1 pixel
+         if x2 <= x1:
+             x2 = min(x1 + 1, dims.padded_w)
+         if y2 <= y1:
+             y2 = min(y1 + 1, dims.padded_h)
+ 
+         # Extend box boundaries
+         new_boxes = adjust_box_edges(padded_image, [[x1, y1, x2, y2]])
+         x1, y1, x2, y2 = new_boxes[0]
+ 
+         # Ensure coordinates are still within image bounds after adjustment
+         x1 = max(0, min(x1, dims.padded_w - 1))
+         y1 = max(0, min(y1, dims.padded_h - 1))
+         x2 = max(0, min(x2, dims.padded_w))
+         y2 = max(0, min(y2, dims.padded_h))
+ 
+         # Ensure width and height are at least 1 pixel after adjustment
+         if x2 <= x1:
+             x2 = min(x1 + 1, dims.padded_w)
+         if y2 <= y1:
+             y2 = min(y1 + 1, dims.padded_h)
+ 
+         # Check for overlap with previous box and adjust
+         if previous_box is not None:
+             prev_x1, prev_y1, prev_x2, prev_y2 = previous_box
+             if (x1 < prev_x2 and x2 > prev_x1) and (y1 < prev_y2 and y2 > prev_y1):
+                 y1 = prev_y2
+                 # Ensure y1 is still valid
+                 y1 = min(y1, dims.padded_h - 1)
+                 # Make sure y2 is still greater than y1
+                 if y2 <= y1:
+                     y2 = min(y1 + 1, dims.padded_h)
+ 
+         # Update previous box
+         new_previous_box = [x1, y1, x2, y2]
+ 
+         # Map to original coordinates
+         orig_x1, orig_y1, orig_x2, orig_y2 = map_to_original_coordinates(
+             x1, y1, x2, y2, dims
+         )
+ 
+         return x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box
+     except Exception as e:
+         print(f"process_coordinates error: {str(e)}")
+         # Return safe values
+         orig_x1, orig_y1, orig_x2, orig_y2 = 0, 0, min(100, dims.original_w), min(100, dims.original_h)
+         return 0, 0, 100, 100, orig_x1, orig_y1, orig_x2, orig_y2, [0, 0, 100, 100]
+ 
+ 
+ def prepare_image(image) -> Tuple[np.ndarray, ImageDimensions]:
+     """Load and prepare image with padding while maintaining aspect ratio
+ 
+     Args:
+         image: PIL image
+ 
+     Returns:
+         tuple: (padded_image, image_dimensions)
+     """
+     try:
+         # Convert PIL image to OpenCV format
+         image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
+         original_h, original_w = image.shape[:2]
+ 
+         # Calculate padding to make square image
+         max_size = max(original_h, original_w)
+         top = (max_size - original_h) // 2
+         bottom = max_size - original_h - top
+         left = (max_size - original_w) // 2
+         right = max_size - original_w - left
+ 
+         # Apply padding
+         padded_image = cv2.copyMakeBorder(image, top, bottom, left, right,
+                                           cv2.BORDER_CONSTANT, value=(0, 0, 0))
+ 
+         padded_h, padded_w = padded_image.shape[:2]
+ 
+         dimensions = ImageDimensions(
+             original_w=original_w,
+             original_h=original_h,
+             padded_w=padded_w,
+             padded_h=padded_h
+         )
+ 
+         return padded_image, dimensions
+     except Exception as e:
+         print(f"prepare_image error: {str(e)}")
+         # Create a minimal valid image and dimensions
+         h, w = image.height, image.width
+         dimensions = ImageDimensions(
+             original_w=w,
+             original_h=h,
+             padded_w=w,
+             padded_h=h
+         )
+         # Return a black image of the same size
+         return np.zeros((h, w, 3), dtype=np.uint8), dimensions
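+ 
+ # Example (illustrative): an 800x600 PIL image is padded to 800x800 with 100
+ # black pixels on top and bottom, so dims = ImageDimensions(800, 600, 800, 800).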
+ 
+ 
+ def setup_output_dirs(save_dir):
+     """Create necessary output directories"""
+     os.makedirs(save_dir, exist_ok=True)
+     os.makedirs(os.path.join(save_dir, "markdown"), exist_ok=True)
+     os.makedirs(os.path.join(save_dir, "recognition_json"), exist_ok=True)
+ 
+ 
+ def save_outputs(recognition_results, image_path, save_dir):
+     """Save JSON and markdown outputs"""
+     basename = os.path.splitext(os.path.basename(image_path))[0]
+ 
+     # Save JSON file
+     json_path = os.path.join(save_dir, "recognition_json", f"{basename}.json")
+     with open(json_path, "w", encoding="utf-8") as f:
+         json.dump(recognition_results, f, ensure_ascii=False, indent=2)
+ 
+     # Generate and save markdown file
+     markdown_converter = MarkdownConverter()
+     markdown_content = markdown_converter.convert(recognition_results)
+     markdown_path = os.path.join(save_dir, "markdown", f"{basename}.md")
+     with open(markdown_path, "w", encoding="utf-8") as f:
+         f.write(markdown_content)
+ 
+     return json_path
+ 
+ 
+ def crop_margin(img: Image.Image) -> Image.Image:
+     """Crop margins from image"""
+     try:
+         width, height = img.size
+         if width == 0 or height == 0:
+             print("Warning: Image has zero width or height")
+             return img
+ 
+         data = np.array(img.convert("L"))
+         data = data.astype(np.uint8)
+         max_val = data.max()
+         min_val = data.min()
+         if max_val == min_val:
+             return img
+         data = (data - min_val) / (max_val - min_val) * 255
+         gray = 255 * (data < 200).astype(np.uint8)
+ 
+         coords = cv2.findNonZero(gray)  # Find all non-zero points (text)
+         if coords is None:
+             return img
+         a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
+ 
+         # Ensure crop coordinates are within image bounds
+         a = max(0, a)
+         b = max(0, b)
+         w = min(w, width - a)
+         h = min(h, height - b)
+ 
+         # Only crop if we have a valid region
+         if w > 0 and h > 0:
+             return img.crop((a, b, a + w, b + h))
+         return img
+     except Exception as e:
+         print(f"crop_margin error: {str(e)}")
+         return img  # Return original image on error
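+ 
+ 
+ # Minimal end-to-end sketch (hypothetical paths; recognition_results as
+ # produced by the app's parsing step):
+ #   setup_output_dirs("./out")
+ #   json_path = save_outputs(recognition_results, "page_1.png", "./out")
+ #   # -> writes ./out/recognition_json/page_1.json and ./out/markdown/page_1.md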