File size: 5,130 Bytes
b9dea2c
493daef
 
 
 
ee94965
f51fa47
493daef
 
 
f51fa47
 
 
493daef
 
 
 
 
 
 
f51fa47
 
493daef
f51fa47
 
493daef
 
 
 
f51fa47
493daef
 
 
 
f51fa47
5d72698
 
 
 
 
 
 
 
 
 
493daef
 
 
 
9a82062
493daef
9a82062
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
493daef
 
5d72698
f51fa47
493daef
 
 
 
 
 
 
 
9a82062
493daef
f51fa47
493daef
f51fa47
493daef
f51fa47
493daef
 
f51fa47
493daef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f51fa47
 
493daef
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import gradio as gr
import torch
import logging
import os
import json
from PIL import Image
from surya.ocr import run_ocr
from surya.detection import batch_text_detection
from surya.layout import batch_layout_detection
from surya.ordering import batch_ordering
from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor
from surya.model.ordering.model import load_model as load_order_model
from surya.model.ordering.processor import load_processor as load_order_processor
from surya.settings import settings

# Set up logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Enable Surya's RECOGNITION_STATIC_CACHE option before any model is loaded;
# Surya's docs pair it with torch.compile of the recognition decoder (done below).
# `settings` is imported from `surya.settings` at the top of the file, i.e. the
# object already exists here, so exporting the env var at this point is not
# guaranteed to reach it. Set the attribute on the settings object directly
# (guarded in case the installed Surya version has no such setting) and still
# export the env var for anything that reads the environment later.
# NOTE(review): assumes `settings` accepts attribute assignment — confirm.
os.environ['RECOGNITION_STATIC_CACHE'] = 'true'
if hasattr(settings, 'RECOGNITION_STATIC_CACHE'):
    settings.RECOGNITION_STATIC_CACHE = True

# Load models and processors
logger.info("Loading models and processors...")
det_processor, det_model = load_det_processor(), load_det_model()
rec_model, rec_processor = load_rec_model(), load_rec_processor()
# The layout model reuses the detection architecture with a different checkpoint
layout_model = load_det_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
layout_processor = load_det_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
order_model = load_order_model()
order_processor = load_order_processor()

# Compile the OCR model's decoder for better performance
logger.info("Compiling OCR model...")
rec_model.decoder.model = torch.compile(rec_model.decoder.model)

class SuryaJSONEncoder(json.JSONEncoder):
    """JSON encoder that can serialize Surya prediction objects.

    ``json.JSONEncoder`` only calls :meth:`default` for values it cannot
    encode natively. It then encodes whatever ``default`` returns, calling
    ``default`` again for any nested value that still needs conversion.
    ``default`` therefore only has to convert one level at a time and must
    not recurse into primitives itself.
    """

    def default(self, obj):
        """Return a JSON-encodable stand-in for ``obj``.

        Conversion order:

        1. PIL images -> the placeholder string ``"PIL.Image.Image object"``.
           Checked first because image instances also have a ``__dict__``
           and would otherwise be dumped attribute by attribute.
        2. Lists/tuples -> a plain list. ``json`` encodes these natively,
           so this is only reached when ``default`` is called directly.
        3. Any other object with a ``__dict__`` -> a shallow copy of its
           attribute dict; ``json`` then encodes the values, calling
           ``default`` again for any that need it.

        Raises:
            TypeError: for any other type (via ``JSONEncoder.default``).
        """
        if isinstance(obj, Image.Image):
            return "PIL.Image.Image object"
        if isinstance(obj, (list, tuple)):
            return list(obj)
        if hasattr(obj, '__dict__'):
            return dict(vars(obj))
        return super().default(obj)

def _parse_langs(langs):
    """Turn a comma-separated language string into a list of codes.

    Surrounding whitespace is stripped and empty entries are dropped, so
    ``"en, fr,"`` becomes ``["en", "fr"]``; ``None`` or ``""`` gives ``[]``.
    """
    if not langs:
        return []
    return [code.strip() for code in langs.split(',') if code.strip()]


def _layout_bboxes(layout_result):
    """Extract the ``bbox`` of every region in one layout prediction.

    Accepts either an object exposing ``.bboxes`` whose items expose
    ``.bbox`` (the attribute-style shape current Surya releases appear to
    return — verify against the installed version) or the dict shape
    ``{"bboxes": [{"bbox": ...}, ...]}``. Returns ``None`` when
    ``layout_result`` has neither shape.
    """
    if isinstance(layout_result, dict):
        boxes = layout_result.get('bboxes')
    else:
        boxes = getattr(layout_result, 'bboxes', None)
    if boxes is None:
        return None
    return [box['bbox'] if isinstance(box, dict) else box.bbox for box in boxes]


def process_image(image_path, langs):
    """Run OCR, text-line detection, layout analysis and reading order on one image.

    Args:
        image_path: Path to the image file (the Gradio ``type="filepath"`` input).
        langs: Comma-separated language codes, e.g. ``"en,fr"``. Whitespace
            around codes and empty entries are ignored.

    Returns:
        A pretty-printed JSON string with keys ``ocr``, ``text_lines``,
        ``layout`` and ``reading_order``. If a stage fails, the results of
        the stages that already completed are kept and an ``error`` key holds
        the exception message. Non-ASCII text is emitted as-is, not escaped.

    Errors raised while opening the image file are not caught here and
    propagate to the caller.
    """
    logger.info(f"Processing image: {image_path}")
    results = {}

    # The context manager closes the underlying file once every stage has run.
    with Image.open(image_path) as image:
        try:
            lang_list = _parse_langs(langs)
            if not lang_list:
                raise ValueError("No OCR languages given; enter at least one language code, e.g. 'en'.")

            # OCR
            logger.info("Performing OCR...")
            ocr_predictions = run_ocr([image], [lang_list], det_model, det_processor, rec_model, rec_processor)
            results["ocr"] = ocr_predictions[0]

            # Text line detection (also fed into layout analysis below)
            logger.info("Detecting text lines...")
            line_predictions = batch_text_detection([image], det_model, det_processor)
            results["text_lines"] = line_predictions[0]

            # Layout analysis
            logger.info("Analyzing layout...")
            layout_predictions = batch_layout_detection([image], layout_model, layout_processor, line_predictions)
            results["layout"] = layout_predictions[0]

            # Reading order over the detected layout regions
            logger.info("Determining reading order...")
            logger.debug("Layout predictions: %s", layout_predictions)

            bboxes = _layout_bboxes(layout_predictions[0])
            if bboxes is None:
                logger.warning("Layout predictions do not have the expected structure. Skipping reading order detection.")
                results["reading_order"] = "Reading order detection skipped due to unexpected layout prediction structure."
            else:
                order_predictions = batch_ordering([image], [bboxes], order_model, order_processor)
                results["reading_order"] = order_predictions[0]

        except Exception as e:
            # Pipeline boundary: keep whatever stages finished and report the
            # failure inside the JSON instead of raising.
            logger.exception("Error processing image: %s", e)
            results["error"] = str(e)

    logger.info("Processing complete.")
    # ensure_ascii=False keeps recognized non-Latin text readable instead of \uXXXX escapes.
    return json.dumps(results, indent=2, ensure_ascii=False, cls=SuryaJSONEncoder)

def surya_ui(image, langs):
    """Gradio callback wrapping :func:`process_image` for the web UI.

    Args:
        image: File path from the ``gr.Image(type="filepath")`` input, or
            ``None`` when nothing has been uploaded.
        langs: Comma-separated language codes from the textbox.

    Returns:
        The JSON string produced by ``process_image``; a prompt to upload an
        image when ``image`` is ``None``; or a human-readable error message
        if ``process_image`` raises.
    """
    if image is not None:
        try:
            return process_image(image, langs)
        except Exception as exc:
            logger.error(f"Error in UI processing: {str(exc)}", exc_info=True)
            return f"An error occurred: {str(exc)}"
    return "Please upload an image."

# Custom CSS applied to the Gradio page: font, gradient buttons, rounded form card.
_IFACE_CSS = """
    .gradio-container {
        font-family: 'IBM Plex Sans', sans-serif;
    }
    .gr-button {
        color: white;
        border-radius: 8px;
        background: linear-gradient(45deg, #ff9a9e 0%, #fad0c4 99%, #fad0c4 100%);
    }
    .gr-button:hover {
        background: linear-gradient(45deg, #fad0c4 0%, #ff9a9e 99%, #ff9a9e 100%);
    }
    .gr-form {
        border-radius: 12px;
        background-color: #ffffff;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    }
    """

# Inputs, in the order `surya_ui` receives them: the uploaded image as a file
# path, then the comma-separated language list (defaults to English).
_image_input = gr.Image(type="filepath", label="Upload Image")
_langs_input = gr.Textbox(label="Languages (comma-separated, e.g., 'en,fr')", value="en")

# Create Gradio interface
# NOTE(review): confirm "huggingface" is a theme name the installed Gradio version recognizes.
iface = gr.Interface(
    fn=surya_ui,
    inputs=[_image_input, _langs_input],
    outputs=gr.Textbox(label="Results"),
    title="Surya Document Analysis",
    description="Upload an image to perform OCR, text line detection, layout analysis, and reading order detection.",
    theme="huggingface",
    css=_IFACE_CSS,
)

def _main():
    """Log a startup message and launch the Gradio ``iface`` web server."""
    logger.info("Starting Gradio interface...")
    iface.launch()


# Launch the interface only when run as a script, not on import
if __name__ == "__main__":
    _main()