import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from pathlib import Path
import pickle

transform = transforms.Compose([
    transforms.ToTensor()
])
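# ToTensor() converts the PIL grayscale image to a [1, H, W] float tensor
# scaled to [0, 1]; no further normalization is applied.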

class TextProcessor:
    def __init__(self, alphabet):
        self.alphabet = alphabet
        self.pad_token = "[PAD]"
        self.stoi = {s: i for i, s in enumerate(self.alphabet, 1)}
        self.stoi[self.pad_token] = 0
        self.itos = {i: s for s, i in self.stoi.items()}
        
    def encode(self, label):
        return [self.stoi[s] for s in label]
    
    def decode(self, ids):
        return ''.join([self.itos[i] for i in ids])
    
    def __len__(self):
        return len(self.alphabet) + 1
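
# Usage sketch (hypothetical two-character alphabet):
#   tp = TextProcessor("ab")
#   tp.encode("ab")    # -> [1, 2]; index 0 is reserved for the [PAD]/CTC blank
#   tp.decode([1, 2])  # -> "ab"
#   len(tp)            # -> 3 (alphabet plus the blank)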

MAX_LENGTH = 32  # not used in this app; presumably the training-time sequence cap
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load tokenizer
@st.cache_resource
def load_tokenizer():
    with open("text_process.cls", "rb") as f:
        tokenizer = pickle.load(f)
    return tokenizer
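
# Note: pickle resolves the object's class by module path, so this file must
# define TextProcessor (as it does above) under the same name recorded when the
# tokenizer was pickled; otherwise loading raises an AttributeError.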

tokenizer = load_tokenizer()
encode = tokenizer.encode
decode = tokenizer.decode

class CRNN(nn.Module):
    def __init__(self, num_channels, hidden_size, num_classes):
        super(CRNN, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(num_channels, 64, kernel_size=(2, 3), padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(2, 3), padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )

        # LSTM input features = 128 channels x 16 rows: the 61-px input height
        # becomes 62 after conv1 (kernel height 2, padding 1), 31 after the
        # first pool, 32 after conv2, and 16 after the second pool.
        self.rnn = nn.LSTM(128 * 16, hidden_size, bidirectional=True, batch_first=True)

        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        # x shape: [batch_size, channels, height, width]

        # CNN feature extraction
        conv = self.conv1(x)
        conv = self.conv2(conv)
        batch, channels, height, width = conv.size()

        conv = conv.permute(0, 3, 1, 2)  # [batch, width, channels, height]
        conv = conv.contiguous().view(batch, width, channels * height)

        rnn, _ = self.rnn(conv)

        output = self.fc(rnn)

        return output


@st.cache_resource
def load_model(selected_model_path):
    model = CRNN(num_channels=1, hidden_size=256, num_classes=len(tokenizer))
    model.load_state_dict(torch.load(selected_model_path, map_location=torch.device('cpu')))
    model.to(DEVICE)  # inference() moves inputs to DEVICE, so keep the model there too
    model.eval()
    return model
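
# Quick shape check (illustrative; "crnn_demo.pt" is a hypothetical checkpoint):
#   m = load_model("crnn_demo.pt")
#   x = torch.randn(1, 1, 61, 200).to(DEVICE)  # one grayscale 61x200 image
#   print(m(x).shape)                          # torch.Size([1, 50, len(tokenizer)])
# Width 200 is preserved by each padded conv (kernel width 3, padding 1) and
# halves under each pool: 200 -> 100 -> 50 time steps.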


def preprocess_image(img):
    # The caller already converts the upload to grayscale ("L"), matching the
    # model's single input channel.
    original_width, original_height = img.size
    new_width = int(61 * original_width / original_height)  # preserve aspect ratio at the fixed 61-px height
    image = img.resize((new_width, 61))
    image = transform(image)
    return image
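
# Example: a 610x122 source image is resized to 305x61, yielding a
# [1, 61, 305] tensor after the transform.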


def post_process(preds):
    # Greedy CTC decoding: collapse consecutive repeats, then drop the blank
    # token (0). Tracking the previous raw prediction keeps genuinely doubled
    # characters that the model separates with a blank.
    encodings = []
    previous = 0
    for pred in preds:
        if pred != 0 and pred != previous:
            encodings.append(pred)
        previous = pred
    return decode(encodings)

    
def inference(model, image):
    with torch.no_grad():
        image = image.to(DEVICE)
        outputs = model(image)
        log_probs = F.log_softmax(outputs, dim=2)
        pred_chars = torch.argmax(log_probs, dim=2)
    return pred_chars.squeeze().cpu().numpy()
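
# log_softmax is monotonic within each time step, so the argmax here matches
# that of the raw logits; it is kept in case per-character confidences are
# wanted later.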

def predict(image):
    image = preprocess_image(image)
    image = image.unsqueeze(0)  # add the batch dimension
    pred_ids = inference(model, image).tolist()
    text = post_process(pred_ids)
    return text

st.title("CRNN Sinhala Printed Text Recognition")
fp = sorted(Path(".").glob("crnn*.pt"))  # stable ordering across Streamlit reruns
selected_model_path = st.selectbox(label="Select Model...", options=fp)
model = load_model(selected_model_path)
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    image = Image.open(uploaded_file).convert("L")
    st.image(image, caption='Uploaded Image', use_column_width=True)
    
    if st.button('Predict'):
        predicted_text = predict(image)
        st.write("Predicted Text:")
        st.write(predicted_text)

st.markdown("---")
st.write("Note: This app uses a pre-trained CRNN model for printed Sinhala text recognition.")