|
import streamlit as st |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torchvision import transforms |
|
from PIL import Image |
|
from pathlib import Path |
|
import pickle |
|
|
|
# Image preprocessing pipeline applied after resizing: a single torchvision
# ToTensor step that turns the PIL image into a torch tensor.
_TENSOR_STEPS = [transforms.ToTensor()]
transform = transforms.Compose(_TENSOR_STEPS)
|
|
|
class TextProcessor:
    """Bidirectional character <-> integer-id mapping for the recognizer.

    Characters of ``alphabet`` get ids starting at 1; id 0 is reserved for the
    ``"[PAD]"`` token. Instances are loaded with ``pickle``, so the attribute
    names ``alphabet``, ``pad_token``, ``stoi`` and ``itos`` must stay stable.
    """

    def __init__(self, alphabet):
        self.alphabet = alphabet
        self.pad_token = "[PAD]"

        # Build symbol -> id first (1-based), then pin the pad token to 0.
        symbol_to_id = {}
        for index, symbol in enumerate(alphabet, start=1):
            symbol_to_id[symbol] = index
        symbol_to_id[self.pad_token] = 0
        self.stoi = symbol_to_id

        # Reverse lookup, derived from the final forward table.
        self.itos = {index: symbol for symbol, index in symbol_to_id.items()}

    def encode(self, label):
        """Return the list of ids for each character of ``label`` (KeyError if unknown)."""
        lookup = self.stoi
        return list(map(lookup.__getitem__, label))

    def decode(self, ids):
        """Join the symbols for ``ids`` into one string (id 0 yields ``"[PAD]"``)."""
        lookup = self.itos
        return "".join(map(lookup.__getitem__, ids))

    def __len__(self):
        """Vocabulary size: one slot per alphabet entry plus the pad slot."""
        return 1 + len(self.alphabet)
|
|
|
# Maximum label length. Not referenced by the code visible here; presumably
# mirrors a training-time setting — TODO confirm before removing.
MAX_LENGTH = 32

# Preferred torch device name: CUDA when available, otherwise CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
|
|
|
|
|
@st.cache_resource
def load_tokenizer(selected_model):
    """Load the pickled text processor that matches the selected checkpoint.

    Checkpoints whose file name contains ``"large"`` use
    ``text_process-large.cls``; all others use ``text_process.cls``. Both
    paths are resolved relative to the current working directory.

    Args:
        selected_model: Path (``pathlib.Path`` or ``str``) of the chosen
            model checkpoint. Only its file name is inspected.

    Returns:
        The unpickled object (presumably a ``TextProcessor`` — verify against
        the training code that produced the ``.cls`` files).
    """
    # Path(...).name accepts str as well as Path and, unlike parts[-1],
    # does not raise IndexError on an empty path.
    model_name = Path(selected_model).name
    if "large" in model_name:
        text_processor_path = "text_process-large.cls"
    else:
        text_processor_path = "text_process.cls"

    # SECURITY: pickle.load can execute arbitrary code. Only ship/load
    # tokenizer files from a trusted source; never point this at user uploads.
    with open(text_processor_path, "rb") as f:
        return pickle.load(f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CRNN(nn.Module):
    """Convolutional-recurrent network that emits class scores per image column.

    Two conv/ReLU/max-pool stages extract features. Each column of the
    resulting feature map becomes one time step of a bidirectional LSTM, and a
    linear layer maps every time step to ``num_classes`` scores.

    Args:
        num_channels: Number of input image channels (1 for grayscale).
        hidden_size: LSTM hidden size per direction.
        num_classes: Size of the output vocabulary.
        input_height: Pixel height of the input images. The LSTM input size
            depends on it. The default 61 reproduces the original hard-coded
            ``128 * 16`` features, so existing checkpoints still load.
    """

    def __init__(self, num_channels, hidden_size, num_classes, input_height=61):
        super().__init__()

        self.conv1 = nn.Sequential(
            # Use num_channels; this was previously hard-coded to 1 and the
            # argument was silently ignored.
            nn.Conv2d(num_channels, 64, kernel_size=(2, 3), padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(2, 3), padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

        # Each conv (kernel height 2, padding 1) grows the height by 1 and each
        # 2x2 max-pool floors it to half: 61 -> 62 -> 31 -> 32 -> 16.
        feature_height = ((input_height + 1) // 2 + 1) // 2
        self.rnn = nn.LSTM(
            128 * feature_height, hidden_size, bidirectional=True, batch_first=True
        )

        # Bidirectional LSTM concatenates both directions -> 2 * hidden_size.
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        """Return per-time-step class scores of shape (batch, width // 4, num_classes).

        ``x`` is expected as (batch, num_channels, input_height, width).
        """
        conv = self.conv1(x)
        conv = self.conv2(conv)
        batch, channels, height, width = conv.size()

        # (batch, channels, height, width) -> (batch, width, channels * height):
        # every feature-map column becomes one LSTM time step.
        conv = conv.permute(0, 3, 1, 2).contiguous().view(batch, width, channels * height)

        rnn_out, _ = self.rnn(conv)
        return self.fc(rnn_out)
|
|
|
|
|
@st.cache_resource
def load_model(selected_model_path, num_classes=None):
    """Build a CRNN, load the checkpoint weights on CPU and switch to eval mode.

    Args:
        selected_model_path: Path of the ``.pt`` state-dict checkpoint.
        num_classes: Output vocabulary size. When omitted, it is taken from the
            tokenizer paired with this checkpoint, instead of relying on a
            module-level ``tokenizer`` global having been assigned first.

    Returns:
        The CRNN in eval mode, with weights on CPU.
    """
    if num_classes is None:
        num_classes = len(load_tokenizer(selected_model_path))

    model = CRNN(num_channels=1, hidden_size=256, num_classes=num_classes)

    # NOTE: torch.load is pickle-based; only load checkpoints from trusted sources.
    state_dict = torch.load(selected_model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return model
|
|
|
|
|
def preprocess_image(img, height=61):
    """Resize ``img`` to a fixed height (keeping aspect ratio) and tensorize it.

    Args:
        img: A PIL image.
        height: Target pixel height. The CRNN's LSTM input size
            (``128 * 16``) corresponds to 61-pixel-high inputs, which is the default.

    Returns:
        The output of the module-level ``transform`` applied to the resized image.
    """
    original_width, original_height = img.size

    # int() truncates. A very tall, narrow image could otherwise yield width 0,
    # which PIL's resize rejects. Note that the CRNN's two 2x pools still need
    # roughly 4+ px of width to produce any time steps.
    new_width = max(1, int(height * original_width / original_height))

    resized = img.resize((new_width, height))
    return transform(resized)
|
|
|
|
|
def post_process(preds, decode_fn=None, *, blank=0):
    """Greedy CTC decoding: collapse repeated ids, drop blanks, then decode.

    Args:
        preds: Iterable of per-time-step predicted ids (argmax over classes).
        decode_fn: Callable turning a list of ids into the final result.
            Defaults to the module-level ``decode`` (the tokenizer's decode).
        blank: Id of the CTC blank symbol (0, the pad id, by default).

    Returns:
        Whatever ``decode_fn`` returns for the collapsed id sequence.
    """
    if decode_fn is None:
        decode_fn = decode

    encodings = []
    previous = None
    for pred in preds:
        # The CTC rule is to merge consecutive duplicates *before* removing
        # blanks. Comparing against the raw previous frame (blank included)
        # keeps doubled characters that are separated by a blank,
        # e.g. [a, blank, a] -> "aa". The old code compared against the last
        # *kept* id and collapsed that to "a".
        if pred != blank and pred != previous:
            encodings.append(pred)
        previous = pred

    return decode_fn(encodings)
|
|
|
|
|
def inference(model, image):
    """Run ``model`` on ``image`` and return the argmax class id per time step.

    The input is moved to the device holding the model's weights, so this
    works whether or not the model was moved off CPU. Models without
    parameters leave the input where it is.

    Returns:
        A NumPy array of predicted ids, with size-1 dimensions squeezed out.
    """
    with torch.no_grad():
        # Previously the input went to the global DEVICE ('cuda' when
        # available) even though the loader keeps weights on CPU, which causes
        # a device-mismatch error on GPU machines.
        first_param = next(model.parameters(), None)
        if first_param is not None:
            image = image.to(first_param.device)

        outputs = model(image)
        log_probs = F.log_softmax(outputs, dim=2)
        pred_chars = torch.argmax(log_probs, dim=2)

    return pred_chars.squeeze().cpu().numpy()
|
|
|
def predict(image):
    """Recognize the text in a PIL ``image`` using the module-level ``model``.

    The image is preprocessed and given a batch dimension. The model runs
    without gradient tracking, and the per-step argmax ids are CTC-decoded by
    ``post_process``.
    """
    batch = preprocess_image(image).unsqueeze(0)

    # Inference only: no_grad avoids building (and holding) an autograd graph
    # on every request.
    with torch.no_grad():
        predictions = model(batch)

    pred_ids = torch.argmax(predictions, dim=-1).flatten().tolist()
    return post_process(pred_ids)
|
|
|
st.title("CRNN Sinhala Printed Text Recognition")

# Path.glob yields files in arbitrary filesystem order. Sort so the option
# list and its default selection are stable across reruns and machines.
fp = sorted(Path(".").glob("crnn*.pt"))

if not fp:
    # With no options, st.selectbox returns None and the loaders below would fail.
    st.error("No model checkpoints matching 'crnn*.pt' were found in the working directory.")
    st.stop()

selected_model_path = st.selectbox(label="Select Model...", options=fp)

tokenizer = load_tokenizer(selected_model_path)

# Module-level aliases; post_process() uses ``decode`` from module scope.
encode = tokenizer.encode
decode = tokenizer.decode

model = load_model(selected_model_path)

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # Grayscale ("L") to match the single-channel model input.
    image = Image.open(uploaded_file).convert("L")
    st.image(image, caption='Uploaded Image', use_column_width=True)

    if st.button('Predict'):
        predicted_text = predict(image)
        st.write("Predicted Text:")
        st.write(predicted_text)

st.markdown("---")

st.write("Note: This app uses a pre-trained CRNN model for printed Sinhala text recognition.")