# SmolDocling OCR App — Streamlit front-end for the ds4sd/SmolDocling-256M-preview model.
import streamlit as st | |
import os | |
import time | |
import torch | |
import tempfile | |
from PIL import Image | |
from dotenv import load_dotenv | |
import logging | |
from datetime import datetime | |
# Set up logging
# Module-level config: timestamped INFO messages shared by the whole app.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Load environment variables
# Must run before the os.getenv calls below so a local .env file is honoured.
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")  # optional Hugging Face token (gated/private models)
# Model/processor download cache; defaults to a folder under the system temp dir.
CACHE_DIR = os.getenv("CACHE_DIR", os.path.join(tempfile.gettempdir(), "smoldocling_cache"))

# Ensure cache directory exists
os.makedirs(CACHE_DIR, exist_ok=True)
# Import for Transformers approach
# Optional dependencies are probed at import time; the boolean flags gate the
# UI (see check_dependencies) instead of crashing on a missing package.
try:
    from transformers import AutoProcessor, AutoModelForVision2Seq
    from huggingface_hub import login
    transformers_available = True
except ImportError:
    transformers_available = False

try:
    from docling_core.types.doc import DoclingDocument
    from docling_core.types.doc.document import DocTagsDocument
    docling_available = True
except ImportError:
    docling_available = False

# Global variables for model caching
# Populated lazily by load_model(); kept global so repeated requests reuse them.
processor = None
model = None
def check_dependencies():
    """Return pip requirement strings for every missing optional dependency."""
    requirements = [
        (transformers_available, "transformers huggingface_hub"),
        (docling_available, "docling-core"),
    ]
    return [package for available, package in requirements if not available]
def get_available_devices():
    """List selectable inference devices: always "cpu", plus any CUDA GPUs."""
    if not torch.cuda.is_available():
        return ["cpu"]
    gpu_labels = [
        f"cuda:{idx} ({torch.cuda.get_device_name(idx)})"
        for idx in range(torch.cuda.device_count())
    ]
    return ["cpu"] + gpu_labels
def get_device_from_selection(selection):
    """Map a UI device label back to a torch device string."""
    if not selection.startswith("cuda:"):
        return "cpu"
    # Labels look like "cuda:0 (NVIDIA ...)" — keep only the "cuda:0" token.
    device_token, _, _ = selection.partition(" ")
    return device_token
def load_model(_device):
    """Load the SmolDocling processor/model pair, reusing the cached pair when possible.

    Parameters:
        _device: Torch device string ("cpu" or "cuda:N") the model should run on.

    Returns:
        (processor, model) tuple ready for inference on `_device`.

    Raises:
        Exception: re-raises any error from the Hugging Face download/initialisation.
    """
    global processor, model
    # Fix: despite the "cache" docstring, the previous version re-downloaded and
    # re-instantiated the model on every call. Reuse the cached pair when it is
    # already loaded for the requested device.
    if (
        processor is not None
        and model is not None
        and getattr(load_model, "_loaded_device", None) == _device
    ):
        return processor, model
    # Authenticate with Hugging Face (needed for gated/private model access).
    if HF_TOKEN:
        login(token=HF_TOKEN)
    try:
        logger.info(f"Loading SmolDocling model on {_device}...")
        processor = AutoProcessor.from_pretrained(
            "ds4sd/SmolDocling-256M-preview",
            cache_dir=CACHE_DIR
        )
        model = AutoModelForVision2Seq.from_pretrained(
            "ds4sd/SmolDocling-256M-preview",
            # fp16 halves GPU memory; CPU kernels generally require fp32.
            torch_dtype=torch.float16 if _device.startswith("cuda") else torch.float32,
            cache_dir=CACHE_DIR
        ).to(_device)
        load_model._loaded_device = _device  # remember where the weights live
        logger.info("Model loaded successfully")
        return processor, model
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        raise
def optimize_image(image, max_size=1600):
    """Downscale `image` so its longest side is at most `max_size`, keeping aspect ratio."""
    width, height = image.size
    longest = max(width, height)
    if longest <= max_size:
        return image
    # The longer side becomes exactly max_size; the shorter side is scaled
    # proportionally (truncated to an int, matching PIL's expectations).
    ratio = max_size / longest
    if width > height:
        target = (max_size, int(height * ratio))
    else:
        target = (int(width * ratio), max_size)
    return image.resize(target, Image.LANCZOS)
def process_single_image(image, prompt_text="Convert this page to docling.", device="cpu", show_progress=None):
    """Run SmolDocling OCR on one PIL image and return every export format.

    Parameters:
        image: PIL.Image to process (resized internally via optimize_image).
        prompt_text: Task instruction inserted into the model's chat template.
        device: Torch device string ("cpu" or "cuda:N").
        show_progress: Unused here; accepted for call symmetry with batch code.

    Returns:
        dict with keys "doctags", "markdown", "html", "text" and
        "processing_time" (elapsed seconds, float).
    """
    global processor, model
    # Optimize image
    image = optimize_image(image)
    start_time = time.time()
    # Load the model if not already loaded
    processor, model = load_model(device)
    # Create input messages: one user turn containing the image placeholder
    # followed by the text prompt, in the chat format the processor expects.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt_text}
            ]
        },
    ]
    # Prepare inputs
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=[image], return_tensors="pt")
    inputs = inputs.to(device)
    # Generate outputs
    with torch.no_grad():  # Add this to save memory
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=1500,  # Increased for better results
            do_sample=False,  # Deterministic generation
            num_beams=1,  # Simple beam search
            temperature=1.0,  # No temperature scaling
        )
    # Drop the echoed prompt tokens so only the newly generated tail remains.
    prompt_length = inputs.input_ids.shape[1]
    trimmed_generated_ids = generated_ids[:, prompt_length:]
    # skip_special_tokens=False: DocTags markup itself is emitted as special tokens.
    doctags = processor.batch_decode(
        trimmed_generated_ids,
        skip_special_tokens=False,
    )[0].lstrip()
    # Clean the output
    doctags = doctags.replace("<end_of_utterance>", "").strip()
    # Populate document
    doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([doctags], [image])
    # Create a docling document
    doc = DoclingDocument(name="Document")
    doc.load_from_doctags(doctags_doc)
    # Export as markdown
    md_content = doc.export_to_markdown()
    # Export as HTML
    html_content = doc.export_to_html()
    # Get plain text
    plain_text = doc.export_to_text()
    processing_time = time.time() - start_time
    return {
        "doctags": doctags,
        "markdown": md_content,
        "html": html_content,
        "text": plain_text,
        "processing_time": processing_time
    }
def process_batch(images, prompt_text, device, progress_bar=None):
    """Run process_single_image over each image, updating an optional progress bar."""
    outputs = []
    count = len(images)
    for position, img in enumerate(images, start=1):
        if progress_bar:
            progress_bar.progress((position - 1) / count, text=f"Processing image {position}/{count}")
        outputs.append(process_single_image(img, prompt_text, device))
        if progress_bar:
            progress_bar.progress(position / count, text=f"Processed {position}/{count} images")
    return outputs
def save_session_history(results):
    """Append each result to st.session_state.history with a shared timestamp."""
    if 'history' not in st.session_state:
        st.session_state.history = []
    stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    history = st.session_state.history
    for position, outcome in enumerate(results, start=1):
        history.append({
            "id": len(history) + 1,
            "timestamp": stamp,
            "type": "Image " + str(position),
            "processing_time": outcome["processing_time"],
            "result": outcome,
        })
def display_history():
    """Render every stored processing result, newest entry first."""
    if 'history' not in st.session_state or not st.session_state.history:
        st.info("No processing history available")
        return
    st.subheader("Processing History")
    for entry in reversed(st.session_state.history):
        payload = entry['result']
        header = f"#{entry['id']} - {entry['type']} ({entry['timestamp']})"
        with st.expander(header):
            st.write(f"Processing time: {entry['processing_time']:.2f} seconds")
            md_tab, text_tab, doctags_tab, html_tab = st.tabs(["Markdown", "Text", "DocTags", "HTML"])
            with md_tab:
                st.markdown(payload['markdown'])
                st.download_button(
                    "Download Markdown",
                    payload['markdown'],
                    file_name=f"output_{entry['id']}.md"
                )
            with text_tab:
                st.text_area("Plain Text", payload['text'], height=200)
                st.download_button(
                    "Download Text",
                    payload['text'],
                    file_name=f"output_{entry['id']}.txt"
                )
            with doctags_tab:
                st.text_area("DocTags", payload['doctags'], height=200)
                st.download_button(
                    "Download DocTags",
                    payload['doctags'],
                    file_name=f"output_{entry['id']}.dt"
                )
            with html_tab:
                st.code(payload['html'], language="html")
                st.download_button(
                    "Download HTML",
                    payload['html'],
                    file_name=f"output_{entry['id']}.html"
                )
def main():
    """Streamlit entry point: configure the page, build the sidebar, run OCR.

    Side effects only (renders the UI); all persistent state lives in
    st.session_state.
    """
    # App configuration
    st.set_page_config(
        page_title="SmolDocling OCR App",
        page_icon="π",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    # Custom theme
    st.markdown("""
    <style>
    .main-header {
        font-size: 2.5rem;
        margin-bottom: 0.5rem;
    }
    .sub-header {
        font-size: 1.2rem;
        color: #666;
        margin-bottom: 2rem;
    }
    .stTabs [data-baseweb="tab-list"] {
        gap: 2px;
    }
    .stTabs [data-baseweb="tab"] {
        padding: 10px 16px;
        background-color: #f0f2f6;
    }
    .stTabs [aria-selected="true"] {
        background-color: #e6f0ff;
    }
    </style>
    """, unsafe_allow_html=True)
    # App header
    st.markdown('<p class="main-header">SmolDocling OCR App</p>', unsafe_allow_html=True)
    st.markdown('<p class="sub-header">Extract text from images using SmolDocling AI</p>', unsafe_allow_html=True)
    # Check dependencies
    missing_deps = check_dependencies()
    if missing_deps:
        st.error(f"Missing dependencies: {', '.join(missing_deps)}. Please install them to use this app.")
        st.info("Install with: pip install " + " ".join(missing_deps))
        st.stop()
    # Initialize session state
    if 'results' not in st.session_state:
        st.session_state.results = []
    # Fix: process_button was previously bound only when an upload existed,
    # risking a NameError when referenced later; initialise it up front.
    process_button = False
    # Create sidebar
    with st.sidebar:
        st.header("Configuration")
        # Device selection
        st.subheader("Processing Device")
        available_devices = get_available_devices()
        selected_device = st.selectbox(
            "Select processing device",
            available_devices,
            index=0 if len(available_devices) == 1 else 1,  # Default to CUDA if available
            help="Choose the device for model inference. GPU (CUDA) is recommended for faster processing."
        )
        device = get_device_from_selection(selected_device)
        # Model info
        st.info(f"Selected device: {selected_device}")
        if device == "cpu":
            st.warning("β οΈ CPU processing may be slow. Select a GPU device if available for faster performance.")
        # Memory management
        if device.startswith("cuda"):
            with st.expander("GPU Memory Management"):
                st.write("Current GPU Memory Usage:")
                if torch.cuda.is_available():
                    gpu_idx = int(device.split(":")[1]) if ":" in device else 0
                    allocated = torch.cuda.memory_allocated(gpu_idx) / (1024 ** 3)
                    reserved = torch.cuda.memory_reserved(gpu_idx) / (1024 ** 3)
                    st.progress(allocated / (torch.cuda.get_device_properties(gpu_idx).total_memory / (1024 ** 3)))
                    st.write(f"Allocated: {allocated:.2f} GB")
                    st.write(f"Reserved: {reserved:.2f} GB")
                    if st.button("Clear GPU Cache"):
                        torch.cuda.empty_cache()
                        st.success("GPU cache cleared")
        # Upload options
        st.subheader("Upload Options")
        upload_option = st.radio("Choose upload option:", ["Single Image", "Multiple Images"])
        # Advanced options
        with st.expander("Advanced Options"):
            task_type = st.selectbox(
                "Select task type",
                [
                    "Convert this page to docling.",
                    "Convert this table to OTSL.",
                    "Convert code to text.",
                    "Convert formula to latex.",
                    "Convert chart to OTSL.",
                    "Extract all section header elements on the page."
                ]
            )
            custom_prompt = st.text_area(
                "Custom prompt (optional)",
                value="",
                help="Provide a custom prompt if needed. Leave empty to use the selected task type."
            )
            max_image_size = st.slider(
                "Max image dimension (pixels)",
                min_value=800,
                max_value=3200,
                value=1600,
                step=100,
                help="Larger values may improve OCR quality but use more memory"
            )
        # A non-empty custom prompt overrides the preset task type.
        final_prompt = custom_prompt if custom_prompt else task_type
        # Upload controls
        st.subheader("Upload Image(s)")
        if upload_option == "Single Image":
            # Fix: "pdf" removed from accepted types — PIL's Image.open cannot
            # read PDFs, so accepting them only ever produced a load error.
            uploaded_file = st.file_uploader("Upload image", type=["jpg", "jpeg", "png"])
            if uploaded_file is not None:
                try:
                    image = Image.open(uploaded_file).convert("RGB")
                    st.image(image, caption="Uploaded Image", width=250)
                except Exception as e:
                    st.error(f"Error loading image: {str(e)}")
        else:
            uploaded_files = st.file_uploader(
                "Upload multiple images",
                type=["jpg", "jpeg", "png"],
                accept_multiple_files=True
            )
            if uploaded_files:
                st.success(f"{len(uploaded_files)} images uploaded")
        # Process button
        if (upload_option == "Single Image" and 'uploaded_file' in locals() and uploaded_file is not None) or \
           (upload_option == "Multiple Images" and 'uploaded_files' in locals() and uploaded_files):
            process_button = st.button("Process Image(s)", type="primary")
        # History button
        st.subheader("History")
        if st.button("Show Processing History"):
            st.session_state.show_history = True
        # About section
        with st.expander("About SmolDocling OCR"):
            st.write("""
            This app uses SmolDocling, a powerful OCR model for document understanding from Hugging Face Hub.
            The app extracts DocTags format and converts it to Markdown, HTML, and plain text for easy reading.
            Available tasks:
            - Convert pages to DocTags (general OCR)
            - Convert tables to OTSL
            - Convert code snippets to text
            - Convert formulas to LaTeX
            - Convert charts to OTSL
            - Extract section headers
            """)
    # Main content area
    if 'show_history' in st.session_state and st.session_state.show_history:
        display_history()
        st.session_state.show_history = False
    elif upload_option == "Single Image" and 'uploaded_file' in locals() and uploaded_file is not None and process_button:
        with st.spinner("Processing image..."):
            try:
                progress_bar = st.progress(0, text="Preparing to process...")
                # Fix: the original assigned to `optimize_image.func_defaults`,
                # the Python 2 attribute name; in Python 3 the writable default
                # tuple is `__defaults__`, so the slider silently had no effect.
                optimize_image.__defaults__ = (max_image_size,)
                result = process_single_image(image, final_prompt, device)
                st.session_state.results = [result]
                # Save to history
                save_session_history(st.session_state.results)
                progress_bar.progress(1.0, text="Processing complete!")
                # Display results
                tabs = st.tabs(["Markdown", "Text", "DocTags", "HTML"])
                with tabs[0]:
                    st.subheader("Markdown Output")
                    st.markdown(result["markdown"])
                    st.download_button(
                        "Download Markdown",
                        result["markdown"],
                        file_name="output.md"
                    )
                with tabs[1]:
                    st.subheader("Plain Text Output")
                    st.text_area("Extracted Text", result["text"], height=300)
                    st.download_button(
                        "Download Text",
                        result["text"],
                        file_name="output.txt"
                    )
                with tabs[2]:
                    st.subheader("DocTags Output")
                    st.text_area("DocTags", result["doctags"], height=300)
                    st.download_button(
                        "Download DocTags",
                        result["doctags"],
                        file_name="output.dt"
                    )
                with tabs[3]:
                    st.subheader("HTML Output")
                    st.code(result["html"], language="html")
                    st.download_button(
                        "Download HTML",
                        result["html"],
                        file_name="output.html"
                    )
                st.success(f"Processing completed in {result['processing_time']:.2f} seconds on {selected_device}")
            except Exception as e:
                st.error(f"Error processing image: {str(e)}")
                logger.error(f"Error processing image: {str(e)}", exc_info=True)
    elif upload_option == "Multiple Images" and 'uploaded_files' in locals() and uploaded_files and process_button:
        try:
            images = [Image.open(file).convert("RGB") for file in uploaded_files]
            if len(images) > 0:
                with st.spinner(f"Processing {len(images)} images..."):
                    progress_bar = st.progress(0, text="Preparing to process...")
                    # Fix: same Python 2 `func_defaults` bug as the single-image
                    # path — use `__defaults__` so the slider value takes effect.
                    optimize_image.__defaults__ = (max_image_size,)
                    results = process_batch(images, final_prompt, device, progress_bar)
                    st.session_state.results = results
                    # Save to history
                    save_session_history(results)
                    progress_bar.progress(1.0, text="Processing complete!")
                    # Display results
                    st.subheader("Processing Results")
                    total_time = sum(result["processing_time"] for result in results)
                    avg_time = total_time / len(results)
                    st.write(f"Total processing time: {total_time:.2f} seconds on {selected_device}")
                    st.write(f"Average processing time: {avg_time:.2f} seconds per image")
                    # Create tabs for each image
                    for idx, (result, image) in enumerate(zip(results, images)):
                        with st.expander(f"Image {idx+1} Results"):
                            col1, col2 = st.columns([1, 2])
                            with col1:
                                st.image(image, caption=f"Image {idx+1}", width=250)
                                st.write(f"Processing time: {result['processing_time']:.2f} seconds")
                            with col2:
                                inner_tabs = st.tabs(["Markdown", "Text", "DocTags", "HTML"])
                                with inner_tabs[0]:
                                    st.markdown(result["markdown"])
                                    st.download_button(
                                        f"Download Markdown",
                                        result["markdown"],
                                        file_name=f"output_{idx+1}.md"
                                    )
                                with inner_tabs[1]:
                                    st.text_area("Plain Text", result["text"], height=200)
                                    st.download_button(
                                        f"Download Text",
                                        result["text"],
                                        file_name=f"output_{idx+1}.txt"
                                    )
                                with inner_tabs[2]:
                                    st.text_area("DocTags", result["doctags"], height=200)
                                    st.download_button(
                                        f"Download DocTags",
                                        result["doctags"],
                                        file_name=f"output_{idx+1}.dt"
                                    )
                                with inner_tabs[3]:
                                    st.code(result["html"], language="html")
                                    st.download_button(
                                        f"Download HTML",
                                        result["html"],
                                        file_name=f"output_{idx+1}.html"
                                    )
                    st.success(f"All images processed successfully")
        except Exception as e:
            st.error(f"Error processing images: {str(e)}")
            logger.error(f"Error processing images: {str(e)}", exc_info=True)
    # Display a welcome message if no image has been uploaded
    if ('uploaded_file' not in locals() or uploaded_file is None) and \
       ('uploaded_files' not in locals() or not uploaded_files):
        st.info("π Upload an image using the sidebar to get started")
# Standard script entry guard: launch the Streamlit app only when run directly.
if __name__ == "__main__":
    main()