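"""Gradio demo for Surya document analysis: OCR, text line detection,
layout analysis, and reading order prediction."""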
import json
import logging
import os

import gradio as gr
import torch
from PIL import Image

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Set performance-related environment variables before importing surya,
# since surya reads them when its settings module is first imported.
os.environ["RECOGNITION_BATCH_SIZE"] = "512"
os.environ["DETECTOR_BATCH_SIZE"] = "36"
os.environ["ORDER_BATCH_SIZE"] = "32"
os.environ["RECOGNITION_STATIC_CACHE"] = "true"

from surya.ocr import run_ocr
from surya.detection import batch_text_detection
from surya.layout import batch_layout_detection
from surya.ordering import batch_ordering
from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor
from surya.model.ordering.model import load_model as load_order_model
from surya.model.ordering.processor import load_processor as load_order_processor
from surya.settings import settings

# Load models
logger.info("Loading models...")
det_processor, det_model = load_det_processor(), load_det_model()
rec_model, rec_processor = load_rec_model(), load_rec_processor()
layout_model = load_det_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
layout_processor = load_det_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
order_model = load_order_model()
order_processor = load_order_processor()
# Compile recognition model
logger.info("Compiling recognition model...")
rec_model.decoder.model = torch.compile(rec_model.decoder.model)
def ocr_workflow(image, langs):
    logger.info(f"Starting OCR workflow with languages: {langs}")
    image = Image.open(image.name)
    langs = [lang.strip() for lang in langs.split(',')]
    predictions = run_ocr([image], [langs], det_model, det_processor, rec_model, rec_processor)
    logger.info("OCR workflow completed")
    # Results are pydantic models; convert them to plain dicts so they can be JSON-serialized
    return json.dumps([p.model_dump() for p in predictions], indent=2)
def text_detection_workflow(image):
    logger.info("Starting text detection workflow")
    image = Image.open(image.name)
    predictions = batch_text_detection([image], det_model, det_processor)
    logger.info("Text detection workflow completed")
    # Convert pydantic result objects to plain dicts before serializing
    return json.dumps([p.model_dump() for p in predictions], indent=2)
def layout_analysis_workflow(image):
    logger.info("Starting layout analysis workflow")
    image = Image.open(image.name)
    line_predictions = batch_text_detection([image], det_model, det_processor)
    layout_predictions = batch_layout_detection([image], layout_model, layout_processor, line_predictions)
    logger.info("Layout analysis workflow completed")
    return json.dumps([p.model_dump() for p in layout_predictions], indent=2)
def reading_order_workflow(image):
    logger.info("Starting reading order workflow")
    image = Image.open(image.name)
    line_predictions = batch_text_detection([image], det_model, det_processor)
    layout_predictions = batch_layout_detection([image], layout_model, layout_processor, line_predictions)
    # Layout results are pydantic models, so bboxes are read as attributes rather than dict keys
    bboxes = [pred.bbox for pred in layout_predictions[0].bboxes]
    order_predictions = batch_ordering([image], [bboxes], order_model, order_processor)
    logger.info("Reading order workflow completed")
    return json.dumps([p.model_dump() for p in order_predictions], indent=2)
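# Build the Gradio UI: one tab per workflow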
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Surya Document Analysis")

    with gr.Tab("OCR"):
        gr.Markdown("## Optical Character Recognition")
        with gr.Row():
            ocr_input = gr.File(label="Upload Image or PDF")
            ocr_langs = gr.Textbox(label="Languages (comma-separated)", value="en")
        ocr_button = gr.Button("Run OCR")
        ocr_output = gr.JSON(label="OCR Results")
        ocr_button.click(ocr_workflow, inputs=[ocr_input, ocr_langs], outputs=ocr_output)

    with gr.Tab("Text Detection"):
        gr.Markdown("## Text Line Detection")
        det_input = gr.File(label="Upload Image or PDF")
        det_button = gr.Button("Run Text Detection")
        det_output = gr.JSON(label="Text Detection Results")
        det_button.click(text_detection_workflow, inputs=det_input, outputs=det_output)

    with gr.Tab("Layout Analysis"):
        gr.Markdown("## Layout Analysis and Reading Order")
        layout_input = gr.File(label="Upload Image or PDF")
        layout_button = gr.Button("Run Layout Analysis")
        order_button = gr.Button("Determine Reading Order")
        layout_output = gr.JSON(label="Layout Analysis Results")
        order_output = gr.JSON(label="Reading Order Results")
        layout_button.click(layout_analysis_workflow, inputs=layout_input, outputs=layout_output)
        order_button.click(reading_order_workflow, inputs=layout_input, outputs=order_output)
if __name__ == "__main__":
    logger.info("Starting Gradio app...")
    demo.launch()
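# Run locally with `python app.py`; Gradio serves the app at http://127.0.0.1:7860 by default.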