# ocr_captcha / app.py
import os
import cv2
import torch
from PIL import Image
import torchvision.transforms as T
from huggingface_hub import hf_hub_download
import streamlit as st

# keep model weights and input tensors on the same device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
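# Hub read token from the environment and the model repository to download checkpoints from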
TOKEN = os.getenv('hf_read_token')
repo_id = "Nischay103/captcha_recognition"
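# TorchScript checkpoint filenames on the Hub, one per captcha variant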
model_files = {
"v1": "captcha_model_v1.pt", "v2": "captcha_model_v2.pt",
"v3": "captcha_model_v3.pt", "v4": "captcha_model_v4.pt",
"v5": "captcha_model_v5.pt", "v6": "captcha_model_v6.pt",
"v7": "captcha_model_v7.pt", "v8": "captcha_model_v8.pt",
}
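# example captcha images referenced by relative path, one per model variant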
example_captchas = {
"v1": "v1/v1_MYCbSs.jpg", "v2": "v2/v2_032891.png", "v3": "v3/v3_5Bg5m.png",
"v4": "v4/v4_e78d97.jpg", "v5": "v5/v5_DPMTZ.png", "v6": "v6/v6_WD7A.png",
"v7": "v7/v7_897375.png", "v8": "v8/v8_qxv6x.png",
}
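# number of classes per character position for each character-set key; a model's flat
# output holds max_captcha_len * cls_dim logits (e.g. v8: 5 positions * 37 classes = 185)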
_decode_cls_dims = {
"63": 63, "37_uppercase": 37,
"11": 11, "37_lowercase": 37
}
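# (maximum captcha length, character-set key) for each model variant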
len_dim_pair = {
"v1": (6, "63"), "v2": (6, "11"), "v3": (5, "63"),
"v4": (6, "37_lowercase"), "v5": (5, "37_uppercase"),
"v6": (4, "63"), "v7": (6, "11"), "v8": (5, "37_uppercase")
}
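# decoding alphabets; '$' presumably acts as padding for captchas shorter than the maximum length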
char_sets = {
"63": "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789$",
"37_uppercase": "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789$",
"37_lowercase": "abcdefghijklmnopqrstuvwxyz0123456789$",
"11": "0123456789$"
}
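# download every checkpoint from the Hub and load it once at startup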
models = {}
for key, model_file in model_files.items():
    model_path = hf_hub_download(repo_id=repo_id, filename=model_file, token=TOKEN)
    # map_location keeps the loaded weights on the same device as the input tensors
    models[key] = torch.jit.load(model_path, map_location=device)
def transform_image(image_path):
    """Read a captcha as grayscale, resize it to 200x50 and return a (1, 1, 50, 200) float tensor."""
    transform = T.Compose([T.ToTensor()])
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    image = Image.fromarray(image)
    image = image.resize((200, 50))
    image = transform(image)    # (1, 50, 200), scaled to [0, 1]
    image = image.to(device)    # torch.autograd.Variable is deprecated; .to(device) suffices
    image = image.unsqueeze(1)  # add a batch dimension -> (1, 1, 50, 200)
    return image
def get_label(model_prediction, model_version):
    """Decode a flat logit vector into the predicted character sequence."""
    max_captcha_len, cls_dim_encoded = len_dim_pair[model_version]
    _cls = char_sets[cls_dim_encoded]
    cls_dim = _decode_cls_dims[cls_dim_encoded]
    lab = ""
    for idx in range(max_captcha_len):
        # logits for character position idx occupy the slice [idx*cls_dim, (idx+1)*cls_dim)
        start = cls_dim * idx
        end = cls_dim * (idx + 1)
        get_char = _cls[torch.argmax(model_prediction[0, start:end])]
        lab += get_char
    return lab
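# Streamlit UI: upload a captcha, pick a model variant, and recognize its character sequence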
st.title("char-seq recognition from captcha")
st.write("recognize captchas using different models")
uploaded_file = st.file_uploader("choose a captcha image...", type=["jpg", "png"])
model_version = st.selectbox("model variant", list(model_files.keys()), index=0)
if uploaded_file is not None:
    # persist the upload to a temporary file so cv2.imread can read it
    with open("temp_captcha_image.png", "wb") as f:
        f.write(uploaded_file.getbuffer())
    input_image_path = "temp_captcha_image.png"
    st.image(input_image_path, caption='uploaded captcha image', use_column_width=True)
    if st.button('recognize'):
        input_tensor = transform_image(input_image_path)  # avoid shadowing the built-in input()
        model = models[model_version]
        with torch.no_grad():
            model_prediction = model(input_tensor)
        output = get_label(model_prediction, model_version)
        st.write(f"Recognized Character Sequence: {output}")
st.write("## examples")
__placeholder__ = st.empty()
cols = st.columns(4)
for idx, (model_variant, captcha_path) in enumerate(example_captchas.items()):
    with cols[idx % 4]:
        st.image(captcha_path, caption=captcha_path.split("/")[-1], width=160)
        if st.button(f"Recognize with {model_variant}", key=f"button_{idx}"):
            input_tensor = transform_image(captcha_path)
            model = models[model_variant]
            with torch.no_grad():
                model_prediction = model(input_tensor)
            output = get_label(model_prediction, model_variant)
            __placeholder__.write(f"Model {model_variant} recognized: {output}")
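# minimal headless-use sketch (assumes the example file v8/v8_qxv6x.png is present locally):
#   tensor = transform_image("v8/v8_qxv6x.png")
#   with torch.no_grad():
#       prediction = models["v8"](tensor)
#   print(get_label(prediction, "v8"))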