Spaces:
Runtime error
Runtime error
File size: 3,463 Bytes
7ee0650 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import streamlit as st
from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor
from PIL import Image
import torch
import easyocr
import re
# Load the LayoutLMv3 model and processor
model_name = "your-username/your-model-name"  # Replace with your model repository name
# Language codes passed to the EasyOCR reader
languages = ["ru", "rs_cyrillic", "be", "bg", "uk", "mn", "en"]


@st.cache_resource
def _load_layoutlm(name):
    """Load the fine-tuned LayoutLMv3 model and its processor once.

    Streamlit re-executes the script on every interaction; caching the
    resource keeps the weights in memory instead of reloading them each time.

    Args:
        name: Hugging Face Hub repository id (or local path) of the model.

    Returns:
        A ``(model, processor)`` tuple.
    """
    # apply_ocr=False: words and boxes are supplied by EasyOCR below. With
    # the processor's built-in OCR enabled it rejects caller-supplied boxes.
    loaded_processor = LayoutLMv3Processor.from_pretrained(name, apply_ocr=False)
    loaded_model = LayoutLMv3ForTokenClassification.from_pretrained(name)
    return loaded_model, loaded_processor


@st.cache_resource
def _load_reader(langs):
    """Create (once) an EasyOCR reader for the given tuple of language codes."""
    return easyocr.Reader(list(langs))


model, processor = _load_layoutlm(model_name)
# Initialize EasyOCR reader for multiple languages
reader = _load_reader(tuple(languages))
st.title("LayoutLMv3 and EasyOCR Text Extraction")
st.write("Upload an image to get text predictions using the fine-tuned LayoutLMv3 model and EasyOCR.")
# Accept the common raster formats: both PIL (display / layout features) and
# EasyOCR (text detection) decode JPEG as well as PNG.
uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
# Keep letters of any script (Latin or Cyrillic) and drop digits, punctuation,
# whitespace and underscores. The previous ASCII-only pattern [^a-zA-Z]+ erased
# every Cyrillic word, even though the OCR reader is configured for Cyrillic
# languages.
non_alphabet_pattern = re.compile(r"[\W\d_]+")


def _to_layoutlm_box(bbox, width, height):
    """Convert an EasyOCR corner list into a LayoutLMv3 box on the 0-1000 scale.

    Args:
        bbox: the corner points EasyOCR reports for one detection, each an
            (x, y) pair in pixels of the uploaded image.
        width: image width in pixels.
        height: image height in pixels.

    Returns:
        ``[x0, y0, x1, y1]`` as ints clamped to 0..1000. LayoutLMv3 expects
        normalized boxes (``unnormalize_box`` below relies on the same scale);
        raw pixel coordinates are not valid input. Taking min/max over all
        corners also gives a correct enclosing box for rotated text.
    """
    xs = [float(point[0]) for point in bbox]
    ys = [float(point[1]) for point in bbox]

    def scale(value, size):
        return max(0, min(1000, int(1000 * value / size)))

    return [
        scale(min(xs), width),
        scale(min(ys), height),
        scale(max(xs), width),
        scale(max(ys), height),
    ]


def unnormalize_box(bbox, width, height):
    """Map a 0-1000 normalized ``[x0, y0, x1, y1]`` box back to pixel coordinates."""
    return [
        width * (bbox[0] / 1000),
        height * (bbox[1] / 1000),
        width * (bbox[2] / 1000),
        height * (bbox[3] / 1000),
    ]


def _predict_word_labels(image, words, boxes):
    """Classify every OCR word with LayoutLMv3.

    Args:
        image: the RGB PIL image the words were read from.
        words: OCR words, one entry per kept detection.
        boxes: matching 0-1000 normalized boxes.

    Returns:
        A list of ``(word, pixel_box, label)`` tuples in the order of
        ``words``. Words dropped by truncation get no entry.
    """
    # Words are the processor's second positional (`text`) argument.
    # truncation=True keeps long pages within the model's maximum sequence
    # length instead of failing.
    encoding = processor(image, words, boxes=boxes, truncation=True, return_tensors="pt")
    with torch.no_grad():
        logits = model(**encoding).logits
    predictions = logits.argmax(-1)[0].tolist()
    token_boxes = encoding["bbox"][0].tolist()
    width, height = image.size

    # No word_labels were passed, so the encoding has no "labels" entry to
    # mask sub-tokens with. word_ids() maps each token to its source word
    # (None for special tokens); keep only the first token of each word.
    results = []
    previous_word = None
    for token_index, word_index in enumerate(encoding.word_ids(0)):
        if word_index is not None and word_index != previous_word:
            results.append((
                words[word_index],
                unnormalize_box(token_boxes[token_index], width, height),
                model.config.id2label[predictions[token_index]],
            ))
        previous_word = word_index
    return results


def _pair_languages_with_levels(words, labels):
    """Map each word predicted as 'language' to the label of the following word.

    ``words`` and ``labels`` are parallel per-word lists. A 'language' word in
    the last position has no follower and is skipped; a repeated language
    word keeps the level of its last occurrence.
    """
    pairs = {}
    for index, label in enumerate(labels[:-1]):
        if label == "language":
            # NOTE(review): as in the original code, the level is the *label*
            # predicted for the next word, not that word's text -- confirm
            # this matches the model's label set.
            pairs[words[index]] = labels[index + 1]
    return pairs


if uploaded_file is not None:
    # Convert to RGB: uploads may be RGBA or palette-based PNGs, while the
    # LayoutLMv3 image processor works on 3-channel images.
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption='Uploaded Image.', use_column_width=True)
    st.write("")
    st.write("Classifying...")
    width, height = image.size

    # EasyOCR's readtext takes a path, URL, raw bytes or numpy array -- not
    # Streamlit's UploadedFile object -- so hand it the uploaded bytes.
    ocr_results = reader.readtext(uploaded_file.getvalue(), detail=1)

    words = []
    boxes = []
    for bbox, text, _confidence in ocr_results:
        filtered_text = non_alphabet_pattern.sub("", text)
        if filtered_text:  # Only keep detections that still contain letters
            words.append(filtered_text)
            boxes.append(_to_layoutlm_box(bbox, width, height))

    if not words:
        st.write("No text was detected in the image.")
    else:
        results = _predict_word_labels(image, words, boxes)

        # Display results
        st.write("Predicted labels:")
        for word, box, pred in results:
            st.write(f"Word: {word}, Box: {box}, Prediction: {pred}")

        # Associate languages with their levels (using the word at the same
        # index as the 'language' label, not a separate counter).
        languages_with_levels = _pair_languages_with_levels(
            [word for word, _box, _pred in results],
            [pred for _word, _box, pred in results],
        )
        st.write("Languages and Levels:")
        for language, level in languages_with_levels.items():
            st.write(f"{language}: {level}")
|