File size: 1,463 Bytes
b664581
 
 
 
 
beb2ed5
b664581
 
 
 
 
 
beb2ed5
b664581
beb2ed5
b664581
 
beb2ed5
b664581
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
beb2ed5
b664581
 
beb2ed5
b664581
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import time

import gradio as gr
import torch
import torch.nn.functional as F
from torchvision.utils import save_image
from huggingface_hub import hf_hub_download

from digitnet import Model

# The app only runs inference, so autograd is switched off for the whole process.
torch.set_grad_enabled(False)

# Download the checkpoint from the Hugging Face Hub into the working directory
# so it can be loaded by its bare file name below.
hf_hub_download(
    "karanravindra/digitnet",
    filename="model.ckpt",
    local_dir=".",
)

# Restore the network on CPU and put it in evaluation mode.
model = Model.load_from_checkpoint("model.ckpt", map_location="cpu")
model.eval()

# Label shown for each output index of the model: the ten digit characters.
all_classes = list("0123456789")


def predict(inputs: dict) -> dict[str, float]:
    """Classify the current sketch and return a probability per class label.

    Args:
        inputs: Value emitted by the sketchpad component. Only the
            ``"composite"`` entry is read; it is converted with
            ``torch.from_numpy``, so it must be a NumPy array.
            NOTE(review): presumably a 2-D (height, width) grayscale array,
            since exactly two leading dimensions are added before the 4-D
            bilinear resize -- confirm against the Gradio version in use.

    Returns:
        A mapping from each label in ``all_classes`` to its softmax
        probability, or an empty dict when there is no drawing to classify
        (``inputs`` is empty/``None`` or its composite is ``None``).
    """
    print(f"[{time.strftime('%H:%M:%S')}] Predicting...")

    # A cleared sketchpad can deliver no value at all; there is nothing to
    # classify, so return an empty result instead of crashing in from_numpy.
    composite = inputs["composite"] if inputs else None
    if composite is None:
        return {}

    # Add two leading singleton dimensions (F.interpolate in bilinear mode
    # needs a 4-D tensor).
    img = torch.from_numpy(composite).unsqueeze(0).unsqueeze(0).float()

    # Scale so the brightest pixel is 1. A blank canvas has a maximum of 0;
    # dividing by it would fill the tensor with NaNs, so leave it as zeros.
    peak = img.max()
    if peak > 0:
        img = img / peak

    # Resize to the 32x32 resolution fed to the model.
    img = F.interpolate(img, size=(32, 32), mode="bilinear", align_corners=False)

    logits = model(img)

    # Softmax over the class dimension, then drop the singleton batch axis.
    probs = torch.softmax(logits, dim=1).squeeze()
    return dict(zip(all_classes, probs.tolist()))


# Drawing surface: single-layer grayscale canvas with one fixed brush.
# Named ``sketchpad`` rather than ``input`` so the builtin ``input()`` is not
# shadowed at module level.
sketchpad = gr.Sketchpad(
    image_mode="L",
    canvas_size=(600, 600),
    layers=False,
    brush=gr.Brush(colors=["rgb(239, 68, 68)"], color_mode="fixed", default_size=20),
)

# Label widget configured to show up to ten classes.
output = gr.Label(num_top_classes=10)

# NOTE(review): the description mentions letters, but ``all_classes`` in this
# file only contains the digits 0-9 -- confirm the wording with the model owner.
demo = gr.Interface(
    fn=predict,
    live=True,
    inputs=sketchpad,
    # Pass the configured component; the "label" shortcut would build a
    # separate default Label and leave ``output`` unused.
    outputs=output,
    submit_btn=gr.Button("Predict"),
    title="DigitNet",
    description="A simple handwritten number and letter classifier.\n\nDraw a digit or letter in the box below and see the model's predictions.",
    flagging_mode="never",
)

if __name__ == "__main__":
    demo.launch()