artificialguybr commited on
Commit
493daef
·
verified ·
1 Parent(s): f51fa47

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -26
app.py CHANGED
@@ -1,45 +1,113 @@
1
  import gradio as gr
 
 
 
 
2
  from PIL import Image
3
- import io
4
  from surya.ocr import run_ocr
 
 
 
5
  from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
6
  from surya.model.recognition.model import load_model as load_rec_model
7
  from surya.model.recognition.processor import load_processor as load_rec_processor
 
 
 
 
 
 
 
8
 
9
  # Load models and processors
 
10
  det_processor, det_model = load_det_processor(), load_det_model()
11
  rec_model, rec_processor = load_rec_model(), load_rec_processor()
 
 
 
 
12
 
13
- def perform_ocr(image, language):
14
- # Convert gradio image to PIL Image
15
- if image is not None:
16
- image = Image.fromarray(image)
17
- else:
18
- return "No image uploaded"
19
-
20
- # Perform OCR
21
- langs = [language] # You can expand this to support multiple languages
22
- predictions = run_ocr([image], [langs], det_model, det_processor, rec_model, rec_processor)
23
 
24
- # Extract text from predictions
25
- result = ""
26
- for page in predictions[0]: # Assuming single image input
27
- for line in page['text_lines']:
28
- result += line['text'] + "\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- return result
 
 
 
 
 
 
 
 
 
31
 
32
- # Define the Gradio interface
33
  iface = gr.Interface(
34
- fn=perform_ocr,
35
  inputs=[
36
- gr.Image(type="numpy", label="Upload an image"),
37
- gr.Dropdown(choices=["en", "fr", "de", "es", "it"], label="Select language", value="en")
38
  ],
39
- outputs=gr.Textbox(label="Extracted Text"),
40
- title="OCR with Surya",
41
- description="Upload an image to extract text using Optical Character Recognition."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  )
43
 
44
- # Launch the app
45
- iface.launch()
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ import logging
4
+ import os
5
+ import json
6
  from PIL import Image
 
7
  from surya.ocr import run_ocr
8
+ from surya.detection import batch_text_detection
9
+ from surya.layout import batch_layout_detection
10
+ from surya.ordering import batch_ordering
11
  from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
12
  from surya.model.recognition.model import load_model as load_rec_model
13
  from surya.model.recognition.processor import load_processor as load_rec_processor
14
+ from surya.model.ordering.model import load_model as load_order_model
15
+ from surya.model.ordering.processor import load_processor as load_order_processor
16
+ from surya.settings import settings
17
+
18
+ # Set up logging
19
+ logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
20
+ logger = logging.getLogger(__name__)
21
 
22
  # Load models and processors
23
+ logger.info("Loading models and processors...")
24
  det_processor, det_model = load_det_processor(), load_det_model()
25
  rec_model, rec_processor = load_rec_model(), load_rec_processor()
26
+ layout_model = load_det_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
27
+ layout_processor = load_det_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
28
+ order_model = load_order_model()
29
+ order_processor = load_order_processor()
30
 
31
+ # Compile the OCR model for better performance
32
+ logger.info("Compiling OCR model...")
33
+ os.environ['RECOGNITION_STATIC_CACHE'] = 'true'
34
+ rec_model.decoder.model = torch.compile(rec_model.decoder.model)
 
 
 
 
 
 
35
 
36
+ def process_image(image_path, langs):
37
+ logger.info(f"Processing image: {image_path}")
38
+ image = Image.open(image_path)
39
+
40
+ # OCR
41
+ logger.info("Performing OCR...")
42
+ ocr_predictions = run_ocr([image], [langs.split(',')], det_model, det_processor, rec_model, rec_processor)
43
+
44
+ # Text line detection
45
+ logger.info("Detecting text lines...")
46
+ line_predictions = batch_text_detection([image], det_model, det_processor)
47
+
48
+ # Layout analysis
49
+ logger.info("Analyzing layout...")
50
+ layout_predictions = batch_layout_detection([image], layout_model, layout_processor, line_predictions)
51
+
52
+ # Reading order
53
+ logger.info("Determining reading order...")
54
+ bboxes = [bbox['bbox'] for bbox in layout_predictions[0]['bboxes']]
55
+ order_predictions = batch_ordering([image], [bboxes], order_model, order_processor)
56
+
57
+ # Combine results
58
+ results = {
59
+ "ocr": ocr_predictions[0],
60
+ "text_lines": line_predictions[0],
61
+ "layout": layout_predictions[0],
62
+ "reading_order": order_predictions[0]
63
+ }
64
+
65
+ logger.info("Processing complete.")
66
+ return json.dumps(results, indent=2)
67
 
68
+ def surya_ui(image, langs):
69
+ if image is None:
70
+ return "Please upload an image."
71
+
72
+ try:
73
+ result = process_image(image, langs)
74
+ return result
75
+ except Exception as e:
76
+ logger.error(f"Error processing image: {str(e)}")
77
+ return f"An error occurred: {str(e)}"
78
 
79
+ # Create Gradio interface
80
  iface = gr.Interface(
81
+ fn=surya_ui,
82
  inputs=[
83
+ gr.Image(type="filepath", label="Upload Image"),
84
+ gr.Textbox(label="Languages (comma-separated, e.g., 'en,fr')", value="en")
85
  ],
86
+ outputs=gr.Textbox(label="Results"),
87
+ title="Surya Document Analysis",
88
+ description="Upload an image to perform OCR, text line detection, layout analysis, and reading order detection.",
89
+ theme="huggingface",
90
+ css="""
91
+ .gradio-container {
92
+ font-family: 'IBM Plex Sans', sans-serif;
93
+ }
94
+ .gr-button {
95
+ color: white;
96
+ border-radius: 8px;
97
+ background: linear-gradient(45deg, #ff9a9e 0%, #fad0c4 99%, #fad0c4 100%);
98
+ }
99
+ .gr-button:hover {
100
+ background: linear-gradient(45deg, #fad0c4 0%, #ff9a9e 99%, #ff9a9e 100%);
101
+ }
102
+ .gr-form {
103
+ border-radius: 12px;
104
+ background-color: #ffffff;
105
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
106
+ }
107
+ """
108
  )
109
 
110
+ # Launch the interface
111
+ if __name__ == "__main__":
112
+ logger.info("Starting Gradio interface...")
113
+ iface.launch()