code27panda committed on
Commit
80e8620
·
verified ·
1 Parent(s): 2bd090b

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +574 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,574 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import time
4
+ import torch
5
+ import tempfile
6
+ from PIL import Image
7
+ from dotenv import load_dotenv
8
+ import logging
9
+ from datetime import datetime
10
+
11
# Logging: INFO and above, each record prefixed with a timestamp and level.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Configuration is read from the process environment; load_dotenv() is called
# first so values from a .env file (if any) are visible to os.getenv below.
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
CACHE_DIR = os.getenv(
    "CACHE_DIR",
    os.path.join(tempfile.gettempdir(), "smoldocling_cache"),
)

# The model cache directory must exist before anything is downloaded into it.
os.makedirs(CACHE_DIR, exist_ok=True)

# Optional dependency probe: Transformers + Hugging Face Hub.
# The flag records whether the imports succeeded so the UI can report it.
try:
    from transformers import AutoProcessor, AutoModelForVision2Seq
    from huggingface_hub import login
except ImportError:
    transformers_available = False
else:
    transformers_available = True

# Optional dependency probe: docling-core document types.
try:
    from docling_core.types.doc import DoclingDocument
    from docling_core.types.doc.document import DocTagsDocument
except ImportError:
    docling_available = False
else:
    docling_available = True

# Module-level slots that hold the most recently loaded processor and model.
processor = None
model = None
41
+
42
def check_dependencies():
    """Return the pip requirement strings for optional imports that failed.

    Reads the module-level ``transformers_available`` and
    ``docling_available`` flags. The result is an empty list when both
    flags are true.
    """
    requirements = (
        (transformers_available, "transformers huggingface_hub"),
        (docling_available, "docling-core"),
    )
    return [packages for available, packages in requirements if not available]
51
+
52
def get_available_devices():
    """List the selectable processing devices as display labels.

    "cpu" is always the first entry. When CUDA is available, one label of
    the form "cuda:<index> (<device name>)" follows for each CUDA device.
    """
    if not torch.cuda.is_available():
        return ["cpu"]

    gpu_labels = [
        f"cuda:{gpu_index} ({torch.cuda.get_device_name(gpu_index)})"
        for gpu_index in range(torch.cuda.device_count())
    ]
    return ["cpu", *gpu_labels]
60
+
61
def get_device_from_selection(selection):
    """Map a device display label back to a torch device string.

    Labels starting with "cuda:" are cut at the first space, so
    "cuda:0 (Some GPU)" becomes "cuda:0". Every other label maps to "cpu".
    """
    if not selection.startswith("cuda:"):
        return "cpu"
    device_id, _, _ = selection.partition(" ")
    return device_id
66
+
67
@st.cache_resource
def _load_model_for_device(device):
    """Load the SmolDocling processor and model onto ``device`` (cached).

    The parameter name deliberately has no leading underscore:
    ``st.cache_resource`` excludes underscore-prefixed parameters from the
    cache key, which would make every device share one cached model.
    """
    model_id = "ds4sd/SmolDocling-256M-preview"

    # Authenticate with Hugging Face only when a token was configured.
    if HF_TOKEN:
        login(token=HF_TOKEN)

    try:
        logger.info(f"Loading SmolDocling model on {device}...")
        loaded_processor = AutoProcessor.from_pretrained(
            model_id,
            cache_dir=CACHE_DIR,
        )
        loaded_model = AutoModelForVision2Seq.from_pretrained(
            model_id,
            # Half precision on CUDA devices, full precision otherwise.
            torch_dtype=torch.float16 if device.startswith("cuda") else torch.float32,
            cache_dir=CACHE_DIR,
        ).to(device)
        logger.info("Model loaded successfully")
        return loaded_processor, loaded_model
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        raise


def load_model(_device):
    """Return the ``(processor, model)`` pair for ``_device``.

    Args:
        _device: torch device string such as "cpu" or "cuda:0".

    Returns:
        Tuple of (processor, model), with the model moved to ``_device``.

    Raises:
        Exception: whatever the underlying load raised; it is logged first.

    The pair is cached per device by ``_load_model_for_device``. Previously
    the cache ignored the device, so changing the device after the first
    load returned a model that still lived on the original device. The
    module-level ``processor`` and ``model`` are updated on every call.
    """
    global processor, model
    processor, model = _load_model_for_device(_device)
    return processor, model


# Keep the cache-clearing hook that the cached function used to expose.
load_model.clear = _load_model_for_device.clear
92
+
93
def optimize_image(image, max_size=1600):
    """Downscale ``image`` so its longer side is at most ``max_size`` pixels.

    Args:
        image: object exposing ``size`` as ``(width, height)`` and a
            ``resize`` method (a PIL image in this app).
        max_size: upper bound, in pixels, for the longer side.

    Returns:
        The original object, untouched, when it already fits; otherwise the
        result of ``image.resize`` with the aspect ratio preserved.
    """
    width, height = image.size
    longest_side = max(width, height)
    if longest_side <= max_size:
        return image

    scale = max_size / longest_side
    # The longer side is pinned to exactly max_size. The shorter side is
    # truncated to an int and clamped to 1 px so that very elongated images
    # do not produce a zero-sized target, which resize() rejects.
    if width > height:
        new_size = (max_size, max(1, int(height * scale)))
    else:
        new_size = (max(1, int(width * scale)), max_size)

    return image.resize(new_size, Image.LANCZOS)
105
+
106
def process_single_image(image, prompt_text="Convert this page to docling.", device="cpu", show_progress=None):
    """Run the model on one image and export the result in several formats.

    Args:
        image: input image; it is passed through ``optimize_image`` first.
        prompt_text: instruction placed in the user message next to the image.
        device: torch device string used for model loading and the inputs.
        show_progress: not used by this function.

    Returns:
        dict with keys "doctags", "markdown", "html", "text" and
        "processing_time" (seconds, measured after the image is resized and
        including the model load).
    """
    global processor, model

    image = optimize_image(image)

    started_at = time.time()

    # Fetch the processor/model pair for the requested device.
    processor, model = load_model(device)

    # Single-turn chat: one user message holding an image slot and the prompt.
    chat = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt_text},
            ],
        },
    ]

    prompt = processor.apply_chat_template(chat, add_generation_prompt=True)
    inputs = processor(text=prompt, images=[image], return_tensors="pt").to(device)

    # No gradients are needed for generation.
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=1500,
            do_sample=False,  # no sampling
            num_beams=1,
            temperature=1.0,
        )

    # Drop the prompt tokens so only the newly generated part is decoded.
    new_token_ids = generated_ids[:, inputs.input_ids.shape[1]:]
    decoded = processor.batch_decode(new_token_ids, skip_special_tokens=False)
    doctags = decoded[0].lstrip()

    # Remove the end-of-utterance marker left in because special tokens are kept.
    doctags = doctags.replace("<end_of_utterance>", "").strip()

    # Build a docling document from the DocTags/image pair.
    doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([doctags], [image])
    doc = DoclingDocument(name="Document")
    doc.load_from_doctags(doctags_doc)

    # Export in the same order as before: markdown, HTML, plain text.
    md_content = doc.export_to_markdown()
    html_content = doc.export_to_html()
    plain_text = doc.export_to_text()

    return {
        "doctags": doctags,
        "markdown": md_content,
        "html": html_content,
        "text": plain_text,
        "processing_time": time.time() - started_at,
    }
179
+
180
def process_batch(images, prompt_text, device, progress_bar=None):
    """Process ``images`` one after another and collect the results.

    Args:
        images: sequence of images handed to ``process_single_image``.
        prompt_text: prompt used for every image.
        device: torch device string used for every image.
        progress_bar: optional object with a ``progress(value, text=...)``
            method; it is updated before and after each image.

    Returns:
        List of result dicts in the same order as ``images``.
    """
    total = len(images)
    results = []

    for position, image in enumerate(images, start=1):
        if progress_bar:
            progress_bar.progress((position - 1) / total, text=f"Processing image {position}/{total}")

        results.append(process_single_image(image, prompt_text, device))

        if progress_bar:
            progress_bar.progress(position / total, text=f"Processed {position}/{total} images")

    return results
196
+
197
def save_session_history(results):
    """Append one history entry per result to ``st.session_state.history``.

    The history list is created on first use. All entries added by one call
    share the same timestamp; each entry's id is its 1-based position in the
    history list at the time it is appended.
    """
    if 'history' not in st.session_state:
        st.session_state.history = []

    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    for position, result in enumerate(results, start=1):
        entry = {
            "id": len(st.session_state.history) + 1,
            "timestamp": timestamp,
            "type": f"Image {position}",
            "processing_time": result["processing_time"],
            "result": result,
        }
        st.session_state.history.append(entry)
212
+
213
def display_history():
    """Render the session's processing history, newest entry first.

    Shows an info message and returns when there is no history. Otherwise
    each entry gets an expander with its processing time and four tabs
    (Markdown, Text, DocTags, HTML), each with a download button.

    Every widget is given an explicit key derived from the entry id.
    Without keys, Streamlit derives widget identity from the widget's
    arguments, so two entries with identical extracted text (for example
    the same image processed twice) would collide and raise a
    duplicate-widget error.
    """
    if 'history' not in st.session_state or not st.session_state.history:
        st.info("No processing history available")
        return

    st.subheader("Processing History")

    for item in reversed(st.session_state.history):
        result = item['result']
        key_prefix = f"history_{item['id']}"

        with st.expander(f"#{item['id']} - {item['type']} ({item['timestamp']})"):
            st.write(f"Processing time: {item['processing_time']:.2f} seconds")
            tabs = st.tabs(["Markdown", "Text", "DocTags", "HTML"])

            with tabs[0]:
                st.markdown(result['markdown'])
                st.download_button(
                    "Download Markdown",
                    result['markdown'],
                    file_name=f"output_{item['id']}.md",
                    key=f"{key_prefix}_md_download",
                )

            with tabs[1]:
                st.text_area(
                    "Plain Text",
                    result['text'],
                    height=200,
                    key=f"{key_prefix}_text_area",
                )
                st.download_button(
                    "Download Text",
                    result['text'],
                    file_name=f"output_{item['id']}.txt",
                    key=f"{key_prefix}_text_download",
                )

            with tabs[2]:
                st.text_area(
                    "DocTags",
                    result['doctags'],
                    height=200,
                    key=f"{key_prefix}_doctags_area",
                )
                st.download_button(
                    "Download DocTags",
                    result['doctags'],
                    file_name=f"output_{item['id']}.dt",
                    key=f"{key_prefix}_doctags_download",
                )

            with tabs[3]:
                st.code(result['html'], language="html")
                st.download_button(
                    "Download HTML",
                    result['html'],
                    file_name=f"output_{item['id']}.html",
                    key=f"{key_prefix}_html_download",
                )
257
+
258
def _render_result_tabs(result, file_stem, key_prefix, text_label, height, with_subheaders=False):
    """Render the Markdown/Text/DocTags/HTML tabs for one result dict.

    Args:
        result: dict with "markdown", "text", "doctags" and "html" entries.
        file_stem: download file name without extension.
        key_prefix: prefix for the explicit widget keys, so results with
            identical content do not collide on auto-generated widget IDs.
        text_label: label of the plain-text area.
        height: pixel height of both text areas.
        with_subheaders: whether to show a subheader at the top of each tab.
    """
    tabs = st.tabs(["Markdown", "Text", "DocTags", "HTML"])

    with tabs[0]:
        if with_subheaders:
            st.subheader("Markdown Output")
        st.markdown(result["markdown"])
        st.download_button(
            "Download Markdown",
            result["markdown"],
            file_name=f"{file_stem}.md",
            key=f"{key_prefix}_md_download",
        )

    with tabs[1]:
        if with_subheaders:
            st.subheader("Plain Text Output")
        st.text_area(text_label, result["text"], height=height, key=f"{key_prefix}_text_area")
        st.download_button(
            "Download Text",
            result["text"],
            file_name=f"{file_stem}.txt",
            key=f"{key_prefix}_text_download",
        )

    with tabs[2]:
        if with_subheaders:
            st.subheader("DocTags Output")
        st.text_area("DocTags", result["doctags"], height=height, key=f"{key_prefix}_doctags_area")
        st.download_button(
            "Download DocTags",
            result["doctags"],
            file_name=f"{file_stem}.dt",
            key=f"{key_prefix}_doctags_download",
        )

    with tabs[3]:
        if with_subheaders:
            st.subheader("HTML Output")
        st.code(result["html"], language="html")
        st.download_button(
            "Download HTML",
            result["html"],
            file_name=f"{file_stem}.html",
            key=f"{key_prefix}_html_download",
        )


def main():
    """Build the Streamlit page: sidebar configuration, upload, processing, results.

    The sidebar collects the device, upload mode, prompt and maximum image
    size. The main area shows the history, the result of a single image, or
    the results of a batch, depending on which button was pressed.
    """
    # App configuration
    st.set_page_config(
        page_title="SmolDocling OCR App",
        page_icon="📄",
        layout="wide",
        initial_sidebar_state="expanded"
    )

    # Custom theme
    st.markdown("""
    <style>
    .main-header {
        font-size: 2.5rem;
        margin-bottom: 0.5rem;
    }
    .sub-header {
        font-size: 1.2rem;
        color: #666;
        margin-bottom: 2rem;
    }
    .stTabs [data-baseweb="tab-list"] {
        gap: 2px;
    }
    .stTabs [data-baseweb="tab"] {
        padding: 10px 16px;
        background-color: #f0f2f6;
    }
    .stTabs [aria-selected="true"] {
        background-color: #e6f0ff;
    }
    </style>
    """, unsafe_allow_html=True)

    # App header
    st.markdown('<p class="main-header">SmolDocling OCR App</p>', unsafe_allow_html=True)
    st.markdown('<p class="sub-header">Extract text from images using SmolDocling AI</p>', unsafe_allow_html=True)

    # Stop early when an optional dependency failed to import.
    missing_deps = check_dependencies()
    if missing_deps:
        st.error(f"Missing dependencies: {', '.join(missing_deps)}. Please install them to use this app.")
        st.info("Install with: pip install " + " ".join(missing_deps))
        st.stop()

    # Initialize session state
    if 'results' not in st.session_state:
        st.session_state.results = []

    # Bind every name the main area reads, so no branch below can hit an
    # unbound local (this replaces the earlier `'name' in locals()` checks).
    uploaded_file = None
    uploaded_files = None
    image = None
    process_button = False

    with st.sidebar:
        st.header("Configuration")

        # Device selection
        st.subheader("Processing Device")
        available_devices = get_available_devices()
        selected_device = st.selectbox(
            "Select processing device",
            available_devices,
            index=0 if len(available_devices) == 1 else 1,  # Default to CUDA if available
            help="Choose the device for model inference. GPU (CUDA) is recommended for faster processing."
        )
        device = get_device_from_selection(selected_device)

        st.info(f"Selected device: {selected_device}")

        if device == "cpu":
            st.warning("⚠️ CPU processing may be slow. Select a GPU device if available for faster performance.")

        # Memory management
        if device.startswith("cuda"):
            with st.expander("GPU Memory Management"):
                st.write("Current GPU Memory Usage:")
                if torch.cuda.is_available():
                    gpu_idx = int(device.split(":")[1]) if ":" in device else 0
                    allocated = torch.cuda.memory_allocated(gpu_idx) / (1024 ** 3)
                    reserved = torch.cuda.memory_reserved(gpu_idx) / (1024 ** 3)
                    st.progress(allocated / (torch.cuda.get_device_properties(gpu_idx).total_memory / (1024 ** 3)))
                    st.write(f"Allocated: {allocated:.2f} GB")
                    st.write(f"Reserved: {reserved:.2f} GB")

                    if st.button("Clear GPU Cache"):
                        torch.cuda.empty_cache()
                        st.success("GPU cache cleared")

        # Upload options
        st.subheader("Upload Options")
        upload_option = st.radio("Choose upload option:", ["Single Image", "Multiple Images"])

        # Advanced options
        with st.expander("Advanced Options"):
            task_type = st.selectbox(
                "Select task type",
                [
                    "Convert this page to docling.",
                    "Convert this table to OTSL.",
                    "Convert code to text.",
                    "Convert formula to latex.",
                    "Convert chart to OTSL.",
                    "Extract all section header elements on the page."
                ]
            )

            custom_prompt = st.text_area(
                "Custom prompt (optional)",
                value="",
                help="Provide a custom prompt if needed. Leave empty to use the selected task type."
            )

            max_image_size = st.slider(
                "Max image dimension (pixels)",
                min_value=800,
                max_value=3200,
                value=1600,
                step=100,
                help="Larger values may improve OCR quality but use more memory"
            )

        # A non-empty custom prompt overrides the selected task type.
        final_prompt = custom_prompt if custom_prompt else task_type

        # Upload controls
        st.subheader("Upload Image(s)")
        if upload_option == "Single Image":
            # PDFs are not offered: the file is opened with PIL's Image.open,
            # which cannot read PDF files, so such uploads could only fail.
            uploaded_file = st.file_uploader("Upload image", type=["jpg", "jpeg", "png"])

            if uploaded_file is not None:
                try:
                    image = Image.open(uploaded_file).convert("RGB")
                    st.image(image, caption="Uploaded Image", width=250)
                except Exception as e:
                    st.error(f"Error loading image: {str(e)}")
        else:
            uploaded_files = st.file_uploader(
                "Upload multiple images",
                type=["jpg", "jpeg", "png"],
                accept_multiple_files=True
            )

            if uploaded_files:
                st.success(f"{len(uploaded_files)} images uploaded")

        # Offer the process button only when there is something processable;
        # in single mode that means the upload actually decoded into an image.
        if upload_option == "Single Image":
            ready_to_process = image is not None
        else:
            ready_to_process = bool(uploaded_files)

        if ready_to_process:
            process_button = st.button("Process Image(s)", type="primary")

        # History button
        st.subheader("History")
        if st.button("Show Processing History"):
            st.session_state.show_history = True

        # About section
        with st.expander("About SmolDocling OCR"):
            st.write("""
            This app uses SmolDocling, a powerful OCR model for document understanding from Hugging Face Hub.

            The app extracts DocTags format and converts it to Markdown, HTML, and plain text for easy reading.

            Available tasks:
            - Convert pages to DocTags (general OCR)
            - Convert tables to OTSL
            - Convert code snippets to text
            - Convert formulas to LaTeX
            - Convert charts to OTSL
            - Extract section headers
            """)

    # Main content area
    if 'show_history' in st.session_state and st.session_state.show_history:
        display_history()
        # One-shot flag: the history is shown for this run only.
        st.session_state.show_history = False
    elif process_button and upload_option == "Single Image":
        with st.spinner("Processing image..."):
            try:
                progress_bar = st.progress(0, text="Preparing to process...")

                # Apply the slider value as optimize_image's default max_size.
                # `__defaults__` is the Python 3 attribute; the previously used
                # `func_defaults` only exists in Python 2 and had no effect.
                # NOTE(review): this mutates a module-level function, so the
                # value is shared by everything running in this process.
                optimize_image.__defaults__ = (max_image_size,)

                result = process_single_image(image, final_prompt, device)
                st.session_state.results = [result]

                # Save to history
                save_session_history(st.session_state.results)

                progress_bar.progress(1.0, text="Processing complete!")

                # Display results
                _render_result_tabs(
                    result,
                    file_stem="output",
                    key_prefix="single",
                    text_label="Extracted Text",
                    height=300,
                    with_subheaders=True,
                )

                st.success(f"Processing completed in {result['processing_time']:.2f} seconds on {selected_device}")
            except Exception as e:
                st.error(f"Error processing image: {str(e)}")
                logger.error(f"Error processing image: {str(e)}", exc_info=True)

    elif process_button and upload_option == "Multiple Images":
        try:
            images = [Image.open(file).convert("RGB") for file in uploaded_files]

            if images:
                with st.spinner(f"Processing {len(images)} images..."):
                    progress_bar = st.progress(0, text="Preparing to process...")

                    # Same default override as in single-image mode.
                    optimize_image.__defaults__ = (max_image_size,)

                    results = process_batch(images, final_prompt, device, progress_bar)
                    st.session_state.results = results

                    # Save to history
                    save_session_history(results)

                    progress_bar.progress(1.0, text="Processing complete!")

                    # Display results
                    st.subheader("Processing Results")

                    total_time = sum(result["processing_time"] for result in results)
                    avg_time = total_time / len(results)

                    st.write(f"Total processing time: {total_time:.2f} seconds on {selected_device}")
                    st.write(f"Average processing time: {avg_time:.2f} seconds per image")

                    # One expander per image: preview on the left, tabs on the right.
                    for number, (result, batch_image) in enumerate(zip(results, images), start=1):
                        with st.expander(f"Image {number} Results"):
                            col1, col2 = st.columns([1, 2])

                            with col1:
                                st.image(batch_image, caption=f"Image {number}", width=250)
                                st.write(f"Processing time: {result['processing_time']:.2f} seconds")

                            with col2:
                                _render_result_tabs(
                                    result,
                                    file_stem=f"output_{number}",
                                    key_prefix=f"batch_{number}",
                                    text_label="Plain Text",
                                    height=200,
                                )

                    st.success("All images processed successfully")
        except Exception as e:
            st.error(f"Error processing images: {str(e)}")
            logger.error(f"Error processing images: {str(e)}", exc_info=True)

    # Display a welcome message if no image has been uploaded
    if uploaded_file is None and not uploaded_files:
        st.info("👈 Upload an image using the sidebar to get started")


if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ torch
3
+ accelerate
4
+ transformers
5
+ docling-core
6
+ huggingface_hub
7
+ Pillow
8
+ python-dotenv