Spaces:
Running
on
T4
Running
on
T4
File size: 5,130 Bytes
b9dea2c 493daef ee94965 f51fa47 493daef f51fa47 493daef f51fa47 493daef f51fa47 493daef f51fa47 493daef f51fa47 5d72698 493daef 9a82062 493daef 9a82062 493daef 5d72698 f51fa47 493daef 9a82062 493daef f51fa47 493daef f51fa47 493daef f51fa47 493daef f51fa47 493daef f51fa47 493daef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import gradio as gr
import torch
import logging
import os
import json
from PIL import Image
from surya.ocr import run_ocr
from surya.detection import batch_text_detection
from surya.layout import batch_layout_detection
from surya.ordering import batch_ordering
from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor
from surya.model.ordering.model import load_model as load_order_model
from surya.model.ordering.processor import load_processor as load_order_processor
from surya.settings import settings
# Set up logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Enable the recognition static cache BEFORE any model is loaded or compiled.
# `settings` was imported from `surya.settings` above as an already-built
# object, so changing the environment variable at this point presumably does
# not reach it (NOTE(review): assumes the settings object reads environment
# variables only when it is constructed -- confirm for the pinned surya
# version). Set the attribute on the object directly as well; if the installed
# surya version rejects it, log and continue without the static cache.
os.environ['RECOGNITION_STATIC_CACHE'] = 'true'
try:
    settings.RECOGNITION_STATIC_CACHE = True
except (AttributeError, TypeError, ValueError) as exc:
    logger.warning("Could not enable RECOGNITION_STATIC_CACHE on surya settings: %s", exc)

# Load models and processors (text detection, recognition, layout, ordering).
# The layout model reuses the detection loaders with the layout checkpoint.
logger.info("Loading models and processors...")
det_processor, det_model = load_det_processor(), load_det_model()
rec_model, rec_processor = load_rec_model(), load_rec_processor()
layout_model = load_det_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
layout_processor = load_det_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
order_model = load_order_model()
order_processor = load_order_processor()

# Compile the recognition decoder for faster inference. torch.compile is lazy:
# the actual compilation happens on the first forward pass.
logger.info("Compiling OCR model...")
rec_model.decoder.model = torch.compile(rec_model.decoder.model)
class SuryaJSONEncoder(json.JSONEncoder):
    """JSON encoder that can serialize Surya prediction objects.

    ``json`` only calls :meth:`default` for values it cannot serialize
    natively; str/int/float/bool/None/list/tuple/dict never reach it.
    Whatever ``default`` returns is then encoded by ``json`` itself, which
    recurses into it and calls ``default`` again for any nested value it
    cannot handle. ``default`` therefore only has to convert one level and
    must not recurse manually (doing so would hand primitives to the base
    class, which raises ``TypeError`` for them).
    """

    def default(self, obj):
        """Return a JSON-serializable stand-in for *obj*.

        - PIL images become the placeholder string ``"PIL.Image.Image object"``.
        - Any other object with a ``__dict__`` (e.g. result model instances)
          becomes its attribute mapping, which ``json`` then encodes.

        Raises:
            TypeError: for objects that match neither case (from the base class).
        """
        # Check images first: PIL Image instances are ordinary objects with a
        # ``__dict__``, so the generic branch below would otherwise catch them.
        if isinstance(obj, Image.Image):
            return "PIL.Image.Image object"
        if hasattr(obj, '__dict__'):
            # Hand the attribute dict back to ``json``; it encodes the values
            # and calls ``default`` again for nested custom objects.
            return vars(obj)
        return super().default(obj)
def _parse_langs(langs):
    """Split a comma-separated language string into a list of codes.

    Whitespace around each code is stripped and empty entries are dropped,
    so ``"en, fr,"`` becomes ``["en", "fr"]``.
    """
    return [code.strip() for code in langs.split(',') if code.strip()]


def _layout_bboxes(layout):
    """Return the ``bbox`` of every region in one layout prediction, or None.

    Accepts either an object exposing ``.bboxes`` whose items expose ``.bbox``,
    or a dict with a ``'bboxes'`` list whose items are dicts with a ``'bbox'``
    key. Returns None when the prediction has neither structure.
    """
    if isinstance(layout, dict):
        boxes = layout.get('bboxes')
    else:
        boxes = getattr(layout, 'bboxes', None)
    if boxes is None:
        return None
    return [box['bbox'] if isinstance(box, dict) else box.bbox for box in boxes]


def process_image(image_path, langs):
    """Run OCR, text-line detection, layout analysis and reading order on one image.

    Args:
        image_path: Path to the image file (the Gradio input uses ``type="filepath"``).
        langs: Comma-separated OCR language codes, e.g. ``"en,fr"``; surrounding
            whitespace around each code is ignored.

    Returns:
        A JSON string (indent=2, non-ASCII characters kept as-is) with the keys
        ``"ocr"``, ``"text_lines"``, ``"layout"`` and ``"reading_order"``, each
        filled in as its stage completes. If a stage raises, the exception
        message is stored under ``"error"`` and the remaining stages are skipped.

    Raises:
        Errors from opening/decoding the image file propagate to the caller,
        because that happens before the per-stage ``try``.
    """
    logger.info("Processing image: %s", image_path)
    # Decode the file fully and close its handle. Normalise to RGB so RGBA,
    # palette and grayscale uploads all reach the models in one consistent mode.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    results = {}
    try:
        # OCR
        logger.info("Performing OCR...")
        ocr_predictions = run_ocr(
            [image], [_parse_langs(langs)], det_model, det_processor, rec_model, rec_processor
        )
        results["ocr"] = ocr_predictions[0]

        # Text line detection (its output also feeds the layout model)
        logger.info("Detecting text lines...")
        line_predictions = batch_text_detection([image], det_model, det_processor)
        results["text_lines"] = line_predictions[0]

        # Layout analysis
        logger.info("Analyzing layout...")
        layout_predictions = batch_layout_detection([image], layout_model, layout_processor, line_predictions)
        results["layout"] = layout_predictions[0]

        # Reading order over the detected layout regions. Layout predictions
        # may be result objects (with `.bboxes`) or dicts (with 'bboxes');
        # both are accepted.
        logger.info("Determining reading order...")
        logger.debug("Layout predictions: %s", layout_predictions)
        bboxes = _layout_bboxes(layout_predictions[0])
        if bboxes is None:
            logger.warning("Layout predictions do not have the expected structure. Skipping reading order detection.")
            results["reading_order"] = "Reading order detection skipped due to unexpected layout prediction structure."
        elif not bboxes:
            # Nothing to order; avoid calling the ordering model with no boxes.
            logger.info("No layout regions detected. Skipping reading order detection.")
            results["reading_order"] = "Reading order detection skipped because no layout regions were detected."
        else:
            order_predictions = batch_ordering([image], [bboxes], order_model, order_processor)
            results["reading_order"] = order_predictions[0]
    except Exception as e:
        logger.exception("Error processing image: %s", e)
        results["error"] = str(e)
    logger.info("Processing complete.")
    # ensure_ascii=False keeps OCR text in non-Latin scripts human-readable.
    return json.dumps(results, indent=2, ensure_ascii=False, cls=SuryaJSONEncoder)
def surya_ui(image, langs):
    """Gradio callback: analyse the uploaded image and return the result text.

    Args:
        image: File path of the uploaded image, or None when nothing was uploaded.
        langs: Comma-separated language codes, passed through to ``process_image``.

    Returns:
        The JSON string from ``process_image``; a prompt asking for an upload
        when *image* is None; or an "An error occurred: ..." message when
        processing raises (the exception is also logged with its traceback).
    """
    if image is not None:
        try:
            return process_image(image, langs)
        except Exception as err:
            logger.error(f"Error in UI processing: {str(err)}", exc_info=True)
            return f"An error occurred: {str(err)}"
    return "Please upload an image."
def _build_interface():
    """Assemble the Gradio Interface that wraps ``surya_ui``.

    The components are created in the same order the interface lists them:
    the image upload, then the language textbox, then the results textbox.
    """
    upload = gr.Image(type="filepath", label="Upload Image")
    languages = gr.Textbox(label="Languages (comma-separated, e.g., 'en,fr')", value="en")
    results_box = gr.Textbox(label="Results")
    return gr.Interface(
        fn=surya_ui,
        inputs=[upload, languages],
        outputs=results_box,
        title="Surya Document Analysis",
        description="Upload an image to perform OCR, text line detection, layout analysis, and reading order detection.",
        theme="huggingface",
        css="""
.gradio-container {
font-family: 'IBM Plex Sans', sans-serif;
}
.gr-button {
color: white;
border-radius: 8px;
background: linear-gradient(45deg, #ff9a9e 0%, #fad0c4 99%, #fad0c4 100%);
}
.gr-button:hover {
background: linear-gradient(45deg, #fad0c4 0%, #ff9a9e 99%, #ff9a9e 100%);
}
.gr-form {
border-radius: 12px;
background-color: #ffffff;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
""",
    )


# Create Gradio interface (built at import time)
iface = _build_interface()

# Launch the interface only when run as a script
if __name__ == "__main__":
    logger.info("Starting Gradio interface...")
    iface.launch()
|