lamiaaEl committed on
Commit
84c4327
·
verified ·
1 Parent(s): b1ca1bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -94
app.py CHANGED
@@ -1,94 +1,76 @@
1
- import streamlit as st
2
- from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor
3
- from PIL import Image
4
- import torch
5
- import easyocr
6
- import re
7
-
8
- # Load the LayoutLMv3 model and processor
9
- model_name = "your-username/your-model-name" # Replace with your model repository name
10
- model = LayoutLMv3ForTokenClassification.from_pretrained(model_name)
11
- processor = LayoutLMv3Processor.from_pretrained(model_name)
12
-
13
- # Initialize EasyOCR reader for multiple languages
14
- languages = ["ru", "rs_cyrillic", "be", "bg", "uk", "mn", "en"]
15
- reader = easyocr.Reader(languages)
16
-
17
- st.title("LayoutLMv3 and EasyOCR Text Extraction")
18
- st.write("Upload an image to get text predictions using the fine-tuned LayoutLMv3 model and EasyOCR.")
19
-
20
- uploaded_file = st.file_uploader("Choose an image...", type="png")
21
-
22
- if uploaded_file is not None:
23
- image = Image.open(uploaded_file)
24
- st.image(image, caption='Uploaded Image.', use_column_width=True)
25
- st.write("")
26
- st.write("Classifying...")
27
-
28
- # Perform text detection with EasyOCR
29
- ocr_results = reader.readtext(uploaded_file, detail=1)
30
-
31
- words = []
32
- boxes = []
33
-
34
- # Define a regular expression pattern for non-alphabetic characters
35
- non_alphabet_pattern = re.compile(r'[^a-zA-Z]+')
36
-
37
- for result in ocr_results:
38
- bbox, text, _ = result
39
- filtered_text = re.sub(non_alphabet_pattern, '', text)
40
- if filtered_text: # Only append if there are alphabetic characters left
41
- words.append(filtered_text)
42
- boxes.append([
43
- bbox[0][0], bbox[0][1],
44
- bbox[2][0], bbox[2][1]
45
- ])
46
-
47
- # Convert to layoutlmv3 format
48
- encoding = processor(image, words=words, boxes=boxes, return_tensors="pt")
49
-
50
- # Perform inference with LayoutLMv3
51
- with torch.no_grad():
52
- outputs = model(**encoding)
53
-
54
- logits = outputs.logits
55
- predictions = logits.argmax(-1).squeeze().cpu().tolist()
56
- labels = encoding['labels'].squeeze().tolist()
57
-
58
- # Unnormalize bounding boxes
59
- def unnormalize_box(bbox, width, height):
60
- return [
61
- width * (bbox[0] / 1000),
62
- height * (bbox[1] / 1000),
63
- width * (bbox[2] / 1000),
64
- height * (bbox[3] / 1000),
65
- ]
66
-
67
- width, height = image.size
68
- token_boxes = encoding["bbox"].squeeze().tolist()
69
-
70
- true_predictions = [model.config.id2label[pred] for pred, label in zip(predictions, labels) if label != -100]
71
- true_labels = [model.config.id2label[label] for label in labels if label != -100]
72
- true_boxes = [unnormalize_box(box, width, height) for box, label in zip(token_boxes, labels) if label != -100]
73
- true_tokens = words
74
-
75
- # Display results
76
- st.write("Predicted labels:")
77
- for word, box, pred in zip(true_tokens, true_boxes, true_predictions):
78
- st.write(f"Word: {word}, Box: {box}, Prediction: {pred}")
79
-
80
- # Associate languages with their levels
81
- languages_with_levels = {}
82
- current_language = None
83
-
84
- j = 0
85
- for i in range(len(true_labels)):
86
- if true_labels[i] == 'language':
87
- current_language = true_tokens[j]
88
- j += 1
89
- if i + 1 < len(true_labels):
90
- languages_with_levels[current_language] = true_labels[i + 1]
91
-
92
- st.write("Languages and Levels:")
93
- for language, level in languages_with_levels.items():
94
- st.write(f"{language}: {level}")
 
1
+ import os
2
+ import numpy as np
3
+ import streamlit as st
4
+ from transformers import AutoModelForTokenClassification, AutoProcessor
5
+ from PIL import Image, ImageDraw, ImageFont
6
+
7
+ # Load the LayoutLMv3 model and processor
8
+ processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=True)
9
+ model = AutoModelForTokenClassification.from_pretrained("capitaletech/language-levels-LayoutLMv3-v4")
10
+
11
+ labels = ["language", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10"]
12
+
13
+ label2id = {label: idx for idx, label in enumerate(labels)}
14
+ id2label = {v: k for k, v in label2id.items()}
15
+ label2color = {
16
+ 'language': 'blue', '1': 'red', '2': 'red', '3': 'red',
17
+ '4': 'orange', '5': 'orange', '6': 'orange', '7': 'green',
18
+ '8': 'green', '9': 'green', '10': 'green'
19
+ }
20
+
21
+ def unnormalize_box(bbox, width, height):
22
+ return [
23
+ width * (bbox[0] / 1000),
24
+ height * (bbox[1] / 1000),
25
+ width * (bbox[2] / 1000),
26
+ height * (bbox[3] / 1000),
27
+ ]
28
+
29
+ def iob_to_label(label):
30
+ return label
31
+
32
+ def process_image(image):
33
+ width, height = image.size
34
+
35
+ # Encode
36
+ encoding = processor(image, truncation=True, return_offsets_mapping=True, return_tensors="pt")
37
+ offset_mapping = encoding.pop('offset_mapping')
38
+
39
+ # Forward pass
40
+ outputs = model(**encoding)
41
+
42
+ # Get predictions
43
+ predictions = outputs.logits.argmax(-1).squeeze().tolist()
44
+ token_boxes = encoding.bbox.squeeze().tolist()
45
+
46
+ # Only keep non-subword predictions
47
+ is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0
48
+ true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
49
+ true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
50
+
51
+ # Draw predictions over the image
52
+ draw = ImageDraw.Draw(image)
53
+ font = ImageFont.load_default()
54
+ for prediction, box in zip(true_predictions, true_boxes):
55
+ predicted_label = iob_to_label(prediction)
56
+ draw.rectangle(box, outline=label2color[predicted_label])
57
+ draw.text((box[0] + 10, box[1] - 10), text=predicted_label, fill=label2color[predicted_label], font=font)
58
+
59
+ return image
60
+
61
+ # Streamlit UI
62
+ st.title("Language Levels Extraction using LayoutLMv3 Model")
63
+ st.write("Use this application to predict language levels in CVs.")
64
+
65
+ uploaded_file = st.file_uploader("Choose an image...", type="png")
66
+
67
+ if uploaded_file is not None:
68
+ image = Image.open(uploaded_file)
69
+ st.image(image, caption='Uploaded Image', use_column_width=True)
70
+
71
+ if st.button('Predict'):
72
+ annotated_image = process_image(image)
73
+ st.image(annotated_image, caption='Annotated Image', use_column_width=True)
74
+
75
+ # Add your token if required
76
+ # os.environ["YOUR_TOKEN_ENV_VAR"] = "your_token_here"