"""Streamlit front end that runs the SmolDocling model on uploaded images.

The app loads ``ds4sd/SmolDocling-256M-preview`` through Transformers,
generates DocTags for each uploaded image and converts them to Markdown,
HTML and plain text with docling-core.
"""

import logging
import os
import tempfile
import time
from datetime import datetime

import streamlit as st
import torch
from dotenv import load_dotenv
from PIL import Image

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
CACHE_DIR = os.getenv("CACHE_DIR", os.path.join(tempfile.gettempdir(), "smoldocling_cache"))

# Ensure cache directory exists
os.makedirs(CACHE_DIR, exist_ok=True)

# Import for Transformers approach
try:
    from transformers import AutoProcessor, AutoModelForVision2Seq
    from huggingface_hub import login
    transformers_available = True
except ImportError:
    transformers_available = False

try:
    from docling_core.types.doc import DoclingDocument
    from docling_core.types.doc.document import DocTagsDocument
    docling_available = True
except ImportError:
    docling_available = False

MODEL_ID = "ds4sd/SmolDocling-256M-preview"
DEFAULT_MAX_IMAGE_SIZE = 1600

# NOTE(review): the literal passed to ``str.replace`` was empty in the source
# this file was recovered from (angle-bracket tags had been stripped). The
# end-of-turn token of the model's chat template is assumed here - confirm
# against the model card.
END_OF_UTTERANCE_TOKEN = "<end_of_utterance>"

# Global variables for model caching
processor = None
model = None


def check_dependencies():
    """Return the pip package names of optional dependencies that failed to import.

    Returns:
        A list of strings suitable for ``pip install``; empty when both the
        Transformers and docling-core imports succeeded.
    """
    missing = []
    if not transformers_available:
        missing.append("transformers huggingface_hub")
    if not docling_available:
        missing.append("docling-core")
    return missing


def get_available_devices():
    """Return user-facing labels for the devices inference can run on.

    ``"cpu"`` is always first; each visible CUDA device follows as
    ``"cuda:<index> (<device name>)"``.
    """
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.extend(
            f"cuda:{i} ({torch.cuda.get_device_name(i)})"
            for i in range(torch.cuda.device_count())
        )
    return devices


def get_device_from_selection(selection):
    """Convert a label from :func:`get_available_devices` to a torch device string."""
    if selection.startswith("cuda:"):
        # Keep only the "cuda:X" part and drop the " (device name)" suffix.
        return selection.split(" ")[0]
    return "cpu"


@st.cache_resource
def _load_model_for_device(device):
    """Load the processor and model onto ``device``; cached once per device string."""
    # Authenticate with Hugging Face
    if HF_TOKEN:
        login(token=HF_TOKEN)

    try:
        logger.info(f"Loading SmolDocling model on {device}...")
        loaded_processor = AutoProcessor.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
        loaded_model = AutoModelForVision2Seq.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float16 if device.startswith("cuda") else torch.float32,
            cache_dir=CACHE_DIR,
        ).to(device)
        logger.info("Model loaded successfully")
        return loaded_processor, loaded_model
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        raise


def load_model(_device):
    """Load the SmolDocling processor and model for ``_device``, reusing a cached copy.

    Streamlit leaves underscore-prefixed parameters out of the cache key, so
    caching this function directly would hand back the model loaded for the
    first device ever requested. The cache therefore lives on a private helper
    that is keyed on the device string.

    Args:
        _device: torch device string such as ``"cpu"`` or ``"cuda:0"``.

    Returns:
        A ``(processor, model)`` tuple; the module-level ``processor`` and
        ``model`` globals are updated to the same objects.

    Raises:
        Exception: whatever the underlying load raised, after it is logged.
    """
    global processor, model
    processor, model = _load_model_for_device(_device)
    return processor, model


def optimize_image(image, max_size=1600):
    """Downscale ``image`` so its longest side is at most ``max_size`` pixels.

    The aspect ratio is preserved. Images that already fit are returned
    unchanged (the same object, not a copy).
    """
    width, height = image.size
    if max(width, height) <= max_size:
        return image

    if width > height:
        new_width = max_size
        new_height = int(height * (max_size / width))
    else:
        new_height = max_size
        new_width = int(width * (max_size / height))

    # Extreme aspect ratios could round the short side down to 0, which
    # ``resize`` rejects.
    new_size = (max(1, new_width), max(1, new_height))
    return image.resize(new_size, Image.LANCZOS)


def process_single_image(image, prompt_text="Convert this page to docling.", device="cpu",
                         show_progress=None, max_image_size=DEFAULT_MAX_IMAGE_SIZE):
    """Run SmolDocling on one image and export the result in several formats.

    Args:
        image: PIL image to process.
        prompt_text: instruction sent to the model alongside the image.
        device: torch device string the model and inputs are placed on.
        show_progress: unused; kept for backward compatibility.
        max_image_size: longest side, in pixels, the image is downscaled to
            before inference.

    Returns:
        A dict with keys ``"doctags"``, ``"markdown"``, ``"html"``, ``"text"``
        and ``"processing_time"`` (seconds, measured after image resizing).
    """
    # Optimize image
    image = optimize_image(image, max_size=max_image_size)

    start_time = time.time()

    # Load the model if not already loaded
    active_processor, active_model = load_model(device)

    # Create input messages
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt_text},
            ],
        },
    ]

    # Prepare inputs
    prompt = active_processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = active_processor(text=prompt, images=[image], return_tensors="pt")
    inputs = inputs.to(device)

    # Generate outputs; gradients are not needed for inference.
    with torch.no_grad():
        generated_ids = active_model.generate(
            **inputs,
            max_new_tokens=1500,
            do_sample=False,   # deterministic generation
            num_beams=1,       # single beam, i.e. greedy decoding
            temperature=1.0,   # no temperature scaling
        )

    # Drop the prompt tokens so only the newly generated ones are decoded.
    prompt_length = inputs.input_ids.shape[1]
    trimmed_generated_ids = generated_ids[:, prompt_length:]
    doctags = active_processor.batch_decode(
        trimmed_generated_ids,
        skip_special_tokens=False,
    )[0].lstrip()

    # Clean the output
    doctags = doctags.replace(END_OF_UTTERANCE_TOKEN, "").strip()

    # Populate document
    doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([doctags], [image])

    # Create a docling document
    doc = DoclingDocument(name="Document")
    doc.load_from_doctags(doctags_doc)

    md_content = doc.export_to_markdown()
    html_content = doc.export_to_html()
    plain_text = doc.export_to_text()

    processing_time = time.time() - start_time

    return {
        "doctags": doctags,
        "markdown": md_content,
        "html": html_content,
        "text": plain_text,
        "processing_time": processing_time,
    }


def process_batch(images, prompt_text, device, progress_bar=None,
                  max_image_size=DEFAULT_MAX_IMAGE_SIZE):
    """Process ``images`` one after another, updating an optional progress bar.

    Args:
        images: sequence of PIL images.
        prompt_text: instruction sent to the model for every image.
        device: torch device string.
        progress_bar: optional object with a Streamlit-style
            ``progress(value, text=...)`` method.
        max_image_size: longest side, in pixels, each image is downscaled to.

    Returns:
        A list of result dicts from :func:`process_single_image`, in input order.
    """
    results = []
    total = len(images)

    for idx, image in enumerate(images):
        if progress_bar is not None:
            progress_bar.progress(idx / total, text=f"Processing image {idx+1}/{total}")

        results.append(
            process_single_image(image, prompt_text, device, max_image_size=max_image_size)
        )

        if progress_bar is not None:
            progress_bar.progress((idx + 1) / total, text=f"Processed {idx+1}/{total} images")

    return results


def save_session_history(results):
    """Append ``results`` to ``st.session_state.history`` with ids and a timestamp."""
    if 'history' not in st.session_state:
        st.session_state.history = []

    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    for idx, result in enumerate(results):
        st.session_state.history.append({
            "id": len(st.session_state.history) + 1,
            "timestamp": timestamp,
            "type": "Image " + str(idx + 1),
            "processing_time": result["processing_time"],
            "result": result,
        })


def _render_result_tabs(result, file_stem, key_prefix=None, *, text_label="Plain Text",
                        height=200, show_headings=False):
    """Render the Markdown/Text/DocTags/HTML tabs and download buttons for one result.

    ``key_prefix`` gives the widgets explicit keys; it is required wherever
    this is called more than once per run, because two un-keyed text areas
    with identical content would otherwise collide.
    """
    def widget_key(suffix):
        return None if key_prefix is None else f"{key_prefix}_{suffix}"

    tabs = st.tabs(["Markdown", "Text", "DocTags", "HTML"])

    with tabs[0]:
        if show_headings:
            st.subheader("Markdown Output")
        st.markdown(result["markdown"])
        st.download_button(
            "Download Markdown",
            result["markdown"],
            file_name=f"{file_stem}.md",
            key=widget_key("download_md"),
        )

    with tabs[1]:
        if show_headings:
            st.subheader("Plain Text Output")
        st.text_area(text_label, result["text"], height=height, key=widget_key("text"))
        st.download_button(
            "Download Text",
            result["text"],
            file_name=f"{file_stem}.txt",
            key=widget_key("download_txt"),
        )

    with tabs[2]:
        if show_headings:
            st.subheader("DocTags Output")
        st.text_area("DocTags", result["doctags"], height=height, key=widget_key("doctags"))
        st.download_button(
            "Download DocTags",
            result["doctags"],
            file_name=f"{file_stem}.dt",
            key=widget_key("download_dt"),
        )

    with tabs[3]:
        if show_headings:
            st.subheader("HTML Output")
        st.code(result["html"], language="html")
        st.download_button(
            "Download HTML",
            result["html"],
            file_name=f"{file_stem}.html",
            key=widget_key("download_html"),
        )


def display_history():
    """Show every entry of the session history, newest first, in expanders."""
    if 'history' not in st.session_state or not st.session_state.history:
        st.info("No processing history available")
        return

    st.subheader("Processing History")

    for item in reversed(st.session_state.history):
        with st.expander(f"#{item['id']} - {item['type']} ({item['timestamp']})"):
            st.write(f"Processing time: {item['processing_time']:.2f} seconds")
            _render_result_tabs(
                item['result'],
                file_stem=f"output_{item['id']}",
                key_prefix=f"history_{item['id']}",
            )


def main():
    """Build the Streamlit page: sidebar configuration, uploads, processing and results."""
    # App configuration
    st.set_page_config(
        page_title="SmolDocling OCR App",
        page_icon="📄",
        layout="wide",
        initial_sidebar_state="expanded"
    )

    # Custom theme
    # NOTE(review): the markup inside this call was empty in the recovered
    # source (presumably a stripped <style> block) - restore the CSS if wanted.
    st.markdown(""" """, unsafe_allow_html=True)

    # App header
    # NOTE(review): the original tags around the two header strings were lost;
    # minimal <h1>/<p> markup is used so the text still renders.
    st.markdown('<h1>SmolDocling OCR App</h1>', unsafe_allow_html=True)
    st.markdown('<p>Extract text from images using SmolDocling AI</p>', unsafe_allow_html=True)

    # Check dependencies
    missing_deps = check_dependencies()
    if missing_deps:
        st.error(f"Missing dependencies: {', '.join(missing_deps)}. Please install them to use this app.")
        st.info("Install with: pip install " + " ".join(missing_deps))
        st.stop()

    # Initialize session state
    if 'results' not in st.session_state:
        st.session_state.results = []

    # Defined up front so the main area never touches an unbound name,
    # whichever upload mode is active and whether or not loading succeeded.
    uploaded_file = None
    uploaded_files = []
    image = None
    process_button = False

    # Create sidebar
    with st.sidebar:
        st.header("Configuration")

        # Device selection
        st.subheader("Processing Device")
        available_devices = get_available_devices()
        selected_device = st.selectbox(
            "Select processing device",
            available_devices,
            index=0 if len(available_devices) == 1 else 1,  # Default to CUDA if available
            help="Choose the device for model inference. GPU (CUDA) is recommended for faster processing."
        )
        device = get_device_from_selection(selected_device)

        # Model info
        st.info(f"Selected device: {selected_device}")
        if device == "cpu":
            st.warning("⚠️ CPU processing may be slow. Select a GPU device if available for faster performance.")

        # Memory management
        if device.startswith("cuda"):
            with st.expander("GPU Memory Management"):
                st.write("Current GPU Memory Usage:")
                if torch.cuda.is_available():
                    gpu_idx = int(device.split(":")[1]) if ":" in device else 0
                    gib = 1024 ** 3
                    allocated = torch.cuda.memory_allocated(gpu_idx) / gib
                    reserved = torch.cuda.memory_reserved(gpu_idx) / gib
                    total = torch.cuda.get_device_properties(gpu_idx).total_memory / gib
                    # st.progress only accepts values in [0, 1].
                    st.progress(min(1.0, allocated / total))
                    st.write(f"Allocated: {allocated:.2f} GB")
                    st.write(f"Reserved: {reserved:.2f} GB")

                    if st.button("Clear GPU Cache"):
                        torch.cuda.empty_cache()
                        st.success("GPU cache cleared")

        # Upload options
        st.subheader("Upload Options")
        upload_option = st.radio("Choose upload option:", ["Single Image", "Multiple Images"])

        # Advanced options
        with st.expander("Advanced Options"):
            task_type = st.selectbox(
                "Select task type",
                [
                    "Convert this page to docling.",
                    "Convert this table to OTSL.",
                    "Convert code to text.",
                    "Convert formula to latex.",
                    "Convert chart to OTSL.",
                    "Extract all section header elements on the page."
                ]
            )

            custom_prompt = st.text_area(
                "Custom prompt (optional)",
                value="",
                help="Provide a custom prompt if needed. Leave empty to use the selected task type."
            )

            max_image_size = st.slider(
                "Max image dimension (pixels)",
                min_value=800,
                max_value=3200,
                value=1600,
                step=100,
                help="Larger values may improve OCR quality but use more memory"
            )

        final_prompt = custom_prompt if custom_prompt else task_type

        # Upload controls
        st.subheader("Upload Image(s)")
        if upload_option == "Single Image":
            # "pdf" is not offered: Pillow cannot read PDF files, so such an
            # upload could only ever end in a load error.
            uploaded_file = st.file_uploader("Upload image", type=["jpg", "jpeg", "png"])
            if uploaded_file is not None:
                try:
                    image = Image.open(uploaded_file).convert("RGB")
                    st.image(image, caption="Uploaded Image", width=250)
                except Exception as e:
                    image = None
                    st.error(f"Error loading image: {str(e)}")
        else:
            uploaded_files = st.file_uploader(
                "Upload multiple images",
                type=["jpg", "jpeg", "png"],
                accept_multiple_files=True
            )
            if uploaded_files:
                st.success(f"{len(uploaded_files)} images uploaded")

        # Process button: only offered when there is something processable.
        if upload_option == "Single Image":
            has_input = image is not None
        else:
            has_input = bool(uploaded_files)
        if has_input:
            process_button = st.button("Process Image(s)", type="primary")

        # History button
        st.subheader("History")
        if st.button("Show Processing History"):
            st.session_state.show_history = True

        # About section
        with st.expander("About SmolDocling OCR"):
            st.write("""
            This app uses SmolDocling, a powerful OCR model for document understanding from Hugging Face Hub.
            The app extracts DocTags format and converts it to Markdown, HTML, and plain text for easy reading.

            Available tasks:
            - Convert pages to DocTags (general OCR)
            - Convert tables to OTSL
            - Convert code snippets to text
            - Convert formulas to LaTeX
            - Convert charts to OTSL
            - Extract section headers
            """)

    # Main content area
    if st.session_state.get("show_history"):
        display_history()
        st.session_state.show_history = False

    elif upload_option == "Single Image" and image is not None and process_button:
        with st.spinner("Processing image..."):
            try:
                progress_bar = st.progress(0, text="Preparing to process...")

                result = process_single_image(
                    image, final_prompt, device, max_image_size=max_image_size
                )
                st.session_state.results = [result]

                # Save to history
                save_session_history(st.session_state.results)

                progress_bar.progress(1.0, text="Processing complete!")

                # Display results
                _render_result_tabs(
                    result,
                    file_stem="output",
                    text_label="Extracted Text",
                    height=300,
                    show_headings=True,
                )

                st.success(f"Processing completed in {result['processing_time']:.2f} seconds on {selected_device}")

            except Exception as e:
                st.error(f"Error processing image: {str(e)}")
                logger.error(f"Error processing image: {str(e)}", exc_info=True)

    elif upload_option == "Multiple Images" and uploaded_files and process_button:
        try:
            images = [Image.open(file).convert("RGB") for file in uploaded_files]

            if images:
                with st.spinner(f"Processing {len(images)} images..."):
                    progress_bar = st.progress(0, text="Preparing to process...")

                    results = process_batch(
                        images, final_prompt, device, progress_bar,
                        max_image_size=max_image_size,
                    )
                    st.session_state.results = results

                    # Save to history
                    save_session_history(results)

                    progress_bar.progress(1.0, text="Processing complete!")

                    # Display results
                    st.subheader("Processing Results")

                    total_time = sum(result["processing_time"] for result in results)
                    avg_time = total_time / len(results)

                    st.write(f"Total processing time: {total_time:.2f} seconds on {selected_device}")
                    st.write(f"Average processing time: {avg_time:.2f} seconds per image")

                    # A fresh token per batch keeps widget keys unique across
                    # runs, so a new batch never inherits an old widget's state.
                    run_token = time.time_ns()

                    for idx, (result, batch_image) in enumerate(zip(results, images)):
                        with st.expander(f"Image {idx+1} Results"):
                            col1, col2 = st.columns([1, 2])

                            with col1:
                                st.image(batch_image, caption=f"Image {idx+1}", width=250)
                                st.write(f"Processing time: {result['processing_time']:.2f} seconds")

                            with col2:
                                _render_result_tabs(
                                    result,
                                    file_stem=f"output_{idx+1}",
                                    key_prefix=f"batch_{run_token}_{idx}",
                                )

                    st.success("All images processed successfully")

        except Exception as e:
            st.error(f"Error processing images: {str(e)}")
            logger.error(f"Error processing images: {str(e)}", exc_info=True)

    # Display a welcome message if no image has been uploaded
    if uploaded_file is None and not uploaded_files:
        st.info("👈 Upload an image using the sidebar to get started")


if __name__ == "__main__":
    main()