lamiaaEl committed · Commit 7ee0650 · verified · 1 Parent(s): ce5dcb6

Upload app.py

Files changed (1)
  1. app.py +94 -0
app.py ADDED
@@ -0,0 +1,94 @@
import streamlit as st
from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor
from PIL import Image
import torch
import easyocr
import numpy as np
import re

# Load the fine-tuned LayoutLMv3 model and processor.
# apply_ocr=False is required because the words and boxes are supplied by EasyOCR below.
model_name = "your-username/your-model-name"  # Replace with your model repository name
model = LayoutLMv3ForTokenClassification.from_pretrained(model_name)
processor = LayoutLMv3Processor.from_pretrained(model_name, apply_ocr=False)

# Initialize EasyOCR reader for multiple languages
languages = ["ru", "rs_cyrillic", "be", "bg", "uk", "mn", "en"]
reader = easyocr.Reader(languages)

st.title("LayoutLMv3 and EasyOCR Text Extraction")
st.write("Upload an image to get text predictions using the fine-tuned LayoutLMv3 model and EasyOCR.")

uploaded_file = st.file_uploader("Choose an image...", type="png")

if uploaded_file is not None:
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption='Uploaded Image.', use_column_width=True)
    st.write("")
    st.write("Classifying...")

    width, height = image.size

    # Perform text detection with EasyOCR (it accepts a file path, raw bytes, or a numpy array)
    ocr_results = reader.readtext(np.array(image), detail=1)

    # LayoutLMv3 expects bounding boxes normalized to a 0-1000 scale
    def normalize_box(bbox, width, height):
        return [
            max(0, min(1000, int(1000 * bbox[0] / width))),
            max(0, min(1000, int(1000 * bbox[1] / height))),
            max(0, min(1000, int(1000 * bbox[2] / width))),
            max(0, min(1000, int(1000 * bbox[3] / height))),
        ]

    # Convert normalized boxes back to pixel coordinates for display
    def unnormalize_box(bbox, width, height):
        return [
            width * (bbox[0] / 1000),
            height * (bbox[1] / 1000),
            width * (bbox[2] / 1000),
            height * (bbox[3] / 1000),
        ]

    words = []
    boxes = []

    # Regular expression that keeps Latin letters only
    non_alphabet_pattern = re.compile(r'[^a-zA-Z]+')

    for bbox, text, _ in ocr_results:
        filtered_text = non_alphabet_pattern.sub('', text)
        if filtered_text:  # Only append if there are alphabetic characters left
            words.append(filtered_text)
            # EasyOCR returns four corner points; keep top-left and bottom-right
            x0, y0 = bbox[0]
            x1, y1 = bbox[2]
            boxes.append(normalize_box([x0, y0, x1, y1], width, height))

    # Encode words and boxes in LayoutLMv3 format
    encoding = processor(image, text=words, boxes=boxes, truncation=True, return_tensors="pt")

    # Perform inference with LayoutLMv3
    with torch.no_grad():
        outputs = model(**encoding)

    logits = outputs.logits
    predictions = logits.argmax(-1).squeeze().cpu().tolist()
    token_boxes = encoding["bbox"].squeeze().tolist()

    # Keep one prediction per word (the first sub-token of each word).
    # word_ids() requires the fast tokenizer, which LayoutLMv3Processor uses by default.
    word_ids = encoding.word_ids(0)
    true_predictions = []
    true_boxes = []
    previous_word_idx = None
    for idx, word_idx in enumerate(word_ids):
        if word_idx is not None and word_idx != previous_word_idx:
            true_predictions.append(model.config.id2label[predictions[idx]])
            true_boxes.append(unnormalize_box(token_boxes[idx], width, height))
        previous_word_idx = word_idx
    true_tokens = words

    # Display results
    st.write("Predicted labels:")
    for word, box, pred in zip(true_tokens, true_boxes, true_predictions):
        st.write(f"Word: {word}, Box: {box}, Prediction: {pred}")

    # Pair each word predicted as a language with the prediction on the following word (its level)
    languages_with_levels = {}
    for i, (word, pred) in enumerate(zip(true_tokens, true_predictions)):
        if pred == 'language' and i + 1 < len(true_predictions):
            languages_with_levels[word] = true_predictions[i + 1]

    st.write("Languages and Levels:")
    for language, level in languages_with_levels.items():
        st.write(f"{language}: {level}")
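
If this app.py is deployed as a Streamlit Space, the imports above imply a dependency file roughly like the following (a sketch only; the file is not part of this commit and versions are deliberately left unpinned):

# requirements.txt (assumed, derived from the imports in app.py)
streamlit
transformers
torch
easyocr
Pillow
numpy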