import functools
import os
import random
import warnings

import gradio as gr
import torch
from modelscope import snapshot_download
from PIL import Image
from torchvision import transforms

from model import Model


# Local directory of the "Genius-Society/svhn" ModelScope repo, fetched via
# modelscope's snapshot_download when this module is imported. The app reads
# *.pth checkpoints and examples/*.png sample images from this directory.
# NOTE(review): cache_dir="./__pycache__" is resolved against the current
# working directory and reuses Python's bytecode-cache folder name — confirm
# this location is intentional.
MODEL_DIR = snapshot_download("Genius-Society/svhn", cache_dir="./__pycache__")


# Preprocessing pipeline applied to every uploaded image. It is stateless, so
# it is built once here instead of on every request.
_TRANSFORM = transforms.Compose(
    [
        transforms.Resize([64, 64]),
        transforms.CenterCrop([54, 54]),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)


@functools.lru_cache(maxsize=8)
def _load_model(checkpoint_file: str):
    """Build a ``Model``, restore ``MODEL_DIR/checkpoint_file`` and cache it.

    Caching means each checkpoint is read from disk once per process rather
    than on every request. The cache is bounded so an unexpected variety of
    checkpoint names cannot grow memory without limit. Failed loads raise and
    are therefore not cached.
    """
    model = Model()
    model.restore(f"{MODEL_DIR}/{checkpoint_file}")
    # Return exactly what the per-request code previously called: model.eval().
    return model.eval()


def infer(input_img: str, checkpoint_file: str) -> str:
    """Recognise the digit sequence in ``input_img`` using a chosen checkpoint.

    Args:
        input_img: Path to the image file (the Gradio image input uses
            ``type="filepath"``). It may be ``None`` when nothing is uploaded.
        checkpoint_file: File name of a checkpoint inside ``MODEL_DIR``.

    Returns:
        The predicted digits as a string. On failure, a human-readable error
        message is returned instead; the function never raises, so errors
        surface in the result textbox.
    """
    if not input_img:
        return "Please upload an image."

    try:
        model = _load_model(checkpoint_file)

        # The context manager closes the underlying file. convert() loads
        # the pixel data, so the result stays valid after the file is closed.
        with Image.open(input_img) as img:
            image = _TRANSFORM(img.convert("RGB"))
        images = image.unsqueeze(dim=0)  # add a batch dimension of size 1

        with torch.no_grad():
            # First output: sequence-length logits; the rest: one set of
            # logits per digit position.
            length_logits, *digit_logits = model(images)

        length = length_logits.max(1)[1].item()
        digits = [logits.max(1)[1].item() for logits in digit_logits]

        # Slicing caps the length at the number of digit heads, so a length
        # prediction larger than the number of heads cannot raise IndexError.
        # NOTE(review): SVHN multi-digit models often reserve an extra class
        # (e.g. 10) for "no digit"; if this Model does, it would be emitted as
        # "10" here — confirm against model.py.
        return "".join(str(digit) for digit in digits[:length])

    except Exception as e:
        # Deliberate best-effort boundary: report the error in the UI.
        return f"{e}"


def get_files(dir_path=MODEL_DIR, ext=".pth"):
    """Return the names of entries in ``dir_path`` whose names end with ``ext``.

    Args:
        dir_path: Directory to scan. Defaults to the downloaded model
            directory.
        ext: Suffix to match. As with ``str.endswith``, a tuple of suffixes
            is also accepted. Matching is case-sensitive.

    Returns:
        Matching entry names (not full paths), sorted. ``os.listdir`` yields
        entries in arbitrary order. Sorting makes the result deterministic,
        which matters because callers use the first element as a default
        selection.

    Raises:
        OSError: e.g. FileNotFoundError if ``dir_path`` does not exist.
    """
    return sorted(name for name in os.listdir(dir_path) if name.endswith(ext))


if __name__ == "__main__":
    # Suppress every warning category for the whole process.
    warnings.filterwarnings("ignore")

    models = get_files()
    if not models:
        # Without this check the script dies later with an opaque error
        # from random selection or models[0].
        raise SystemExit(f"No .pth checkpoints found in {MODEL_DIR}")

    images = get_files(f"{MODEL_DIR}/examples", ".png")
    # Pair each bundled example image with a randomly chosen checkpoint.
    samples = [
        [f"{MODEL_DIR}/examples/{img}", random.choice(models)] for img in images
    ]

    gr.Interface(
        fn=infer,
        inputs=[
            gr.Image(label="Upload an image", type="filepath"),
            gr.Dropdown(
                label="Select a model",
                choices=models,
                value=models[0],
            ),
        ],
        outputs=gr.Textbox(label="Recognition result", show_copy_button=True),
        examples=samples,
        title="Door Number Recognition",
        flagging_mode="never",
        cache_examples=False,
    ).launch()