# OCR-CRNN / app.py
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from pathlib import Path
import pickle

transform = transforms.Compose([
    transforms.ToTensor()
])

class TextProcessor:
    def __init__(self, alphabet):
        self.alphabet = alphabet
        self.pad_token = "[PAD]"
        # Index 0 is reserved for the padding/CTC-blank token; characters start at 1.
        self.stoi = {s: i for i, s in enumerate(self.alphabet, 1)}
        self.stoi[self.pad_token] = 0
        self.itos = {i: s for s, i in self.stoi.items()}

    def encode(self, label):
        return [self.stoi[s] for s in label]

    def decode(self, ids):
        return ''.join([self.itos[i] for i in ids])

    def __len__(self):
        return len(self.alphabet) + 1
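
# Quick illustration (hypothetical two-letter alphabet, not part of the app flow):
#   tp = TextProcessor("ab")   # stoi becomes {'a': 1, 'b': 2, '[PAD]': 0}
#   tp.encode("ba")            # -> [2, 1]
#   tp.decode([2, 1])          # -> "ba"
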
MAX_LENGTH = 32
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the pickled TextProcessor that matches the selected checkpoint.
@st.cache_resource
def load_tokenizer(selected_model):
    if "large" in selected_model.name:
        text_processor_path = "text_process-large.cls"
    else:
        text_processor_path = "text_process.cls"
    with open(text_processor_path, 'rb') as f:
        tokenizer = pickle.load(f)
    return tokenizer

# An earlier, deeper CRNN variant, kept commented out for reference:
# class CRNN(nn.Module):
#     def __init__(self, num_chars):
#         super(CRNN, self).__init__()
#         self.cnn = nn.Sequential(
#             nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
#             nn.ReLU(),
#             nn.MaxPool2d(kernel_size=2, stride=2),
#             nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
#             nn.ReLU(),
#             nn.MaxPool2d(kernel_size=2, stride=2),
#             nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
#             nn.BatchNorm2d(256),
#             nn.ReLU(),
#             nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
#             nn.ReLU(),
#             nn.MaxPool2d(kernel_size=(2, 1)),
#             nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
#             nn.BatchNorm2d(512),
#             nn.ReLU(),
#             nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
#             nn.ReLU(),
#             nn.MaxPool2d(kernel_size=(2, 1)),
#             nn.Conv2d(512, 512, kernel_size=2, stride=1),
#             nn.BatchNorm2d(512),
#             nn.ReLU()
#         )
#         # RNN layers
#         self.rnn = nn.GRU(512 * 7, 256, bidirectional=True, batch_first=True, num_layers=2)
#         self.linear = nn.Linear(512, num_chars)
#
#     def forward(self, x):
#         conv = self.cnn(x)
#         batch, channel, height, width = conv.size()
#         conv = conv.permute(0, 3, 1, 2)
#         conv = conv.contiguous().view(batch, width, channel * height)
#         output, _ = self.rnn(conv)
#         output = self.linear(output)
#         return output

class CRNN(nn.Module):
    def __init__(self, num_channels, hidden_size, num_classes):
        super(CRNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(num_channels, 64, kernel_size=(2, 3), padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(2, 3), padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        # 128 * 16 = channels x feature-map height for 61-px-high inputs (see below).
        self.rnn = nn.LSTM(128 * 16, hidden_size, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        # x shape: [batch_size, channels, height, width]
        # CNN feature extraction
        conv = self.conv1(x)
        conv = self.conv2(conv)
        batch, channels, height, width = conv.size()
        # Treat each horizontal position as one RNN time step.
        conv = conv.permute(0, 3, 1, 2)  # [batch, width, channels, height]
        conv = conv.contiguous().view(batch, width, channels * height)
        rnn, _ = self.rnn(conv)
        output = self.fc(rnn)
        return output
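
# Shape walk-through (a sketch, assuming the 61-px-high resize in preprocess_image):
# height 61 -> conv1 (kernel (2, 3), pad 1) -> 62 -> pool/2 -> 31
#           -> conv2                        -> 32 -> pool/2 -> 16,
# which is where the LSTM input size of 128 * 16 comes from. Width is preserved
# by each conv and halved by each pool, so there is one time step per 4 pixel columns.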

@st.cache_resource
def load_model(selected_model_path):
    # Relies on the module-level tokenizer (loaded below) for the output size.
    model = CRNN(num_channels=1, hidden_size=256, num_classes=len(tokenizer))
    model.load_state_dict(torch.load(selected_model_path, map_location=torch.device('cpu')))
    model.eval()
    return model

def preprocess_image(img):
    # The model expects 61-px-high grayscale input; scale width to keep the aspect ratio.
    original_width, original_height = img.size
    new_width = int(61 * original_width / original_height)
    image = img.resize((new_width, 61))
    image = transform(image)
    return image
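
# For example, a 400 x 100 upload is resized to 244 x 61 (int(61 * 400 / 100) = 244),
# so wider lines of text simply produce more RNN time steps.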

def post_process(preds):
    # Greedy CTC decoding: collapse consecutive repeats, then drop blank (0) tokens.
    encodings = []
    previous = None
    for pred in preds:
        if pred != 0 and pred != previous:
            encodings.append(pred)
        previous = pred
    return decode(encodings)
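
# Worked example (hypothetical token ids): frame predictions [0, 5, 5, 0, 5, 7]
# decode to [5, 5, 7] -- the adjacent duplicate 5s merge, while the repeat
# separated by a blank survives as two characters.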

def inference(model, image):
    # Standalone decoding helper; the Streamlit flow below uses predict() instead.
    with torch.no_grad():
        # Follow the model's device rather than assuming CUDA, since the
        # checkpoint is loaded onto the CPU above.
        image = image.to(next(model.parameters()).device)
        outputs = model(image)
        log_probs = F.log_softmax(outputs, dim=2)
        pred_chars = torch.argmax(log_probs, dim=2)
    return pred_chars.squeeze().cpu().numpy()

def predict(image):
    image = preprocess_image(image)
    image = image.unsqueeze(0)  # add the batch dimension
    with torch.no_grad():
        predictions = model(image)
    pred_ids = torch.argmax(predictions, dim=-1).flatten().tolist()
    text = post_process(pred_ids)
    return text

st.title("CRNN Sinhala Printed Text Recognition")

# Offer every CRNN checkpoint in the working directory, in a stable order.
model_paths = sorted(Path(".").glob("crnn*.pt"))
selected_model_path = st.selectbox(label="Select Model...", options=model_paths)

tokenizer = load_tokenizer(selected_model_path)
encode = tokenizer.encode
decode = tokenizer.decode
model = load_model(selected_model_path)

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    image = Image.open(uploaded_file).convert("L")  # model expects grayscale
    st.image(image, caption='Uploaded Image', use_column_width=True)
    if st.button('Predict'):
        predicted_text = predict(image)
        st.write("Predicted Text:")
        st.write(predicted_text)

st.markdown("---")
st.write("Note: This app uses a pre-trained CRNN model for printed Sinhala text recognition.")