"""DigitNet Gradio demo: classify a digit sketched on a canvas."""

import time

import gradio as gr
import torch
import torch.nn.functional as F
from torchvision.utils import save_image  # NOTE(review): currently unused in this file
from huggingface_hub import hf_hub_download

from digitnet import Model

# Inference only: disable autograd globally so no graphs are built per request.
torch.set_grad_enabled(False)

# Use the path hf_hub_download actually returns instead of assuming where the
# file landed; local_dir="." keeps the original download location.
ckpt_path = hf_hub_download("karanravindra/digitnet", filename="model.ckpt", local_dir=".")
model = Model.load_from_checkpoint(ckpt_path, map_location="cpu")
model.eval()

# Labels paired positionally with the model's output probabilities in predict().
all_classes = list("0123456789")


def predict(inputs: dict) -> dict[str, float]:
    """Return per-digit probabilities for the current sketch.

    Args:
        inputs: Value produced by the ``gr.Sketchpad`` input; only its
            ``"composite"`` entry (a numpy array of the flattened canvas) is
            used. ``None`` or a missing/``None`` composite is tolerated
            (presumably what arrives when the canvas is cleared — verify).

    Returns:
        Mapping of digit label to softmax probability, or an empty dict when
        there is nothing to classify (no image, or a canvas with no ink).
        ``gr.Label`` renders an empty dict as an empty prediction.
    """
    print(f"[{time.strftime('%H:%M:%S')}] Predicting...")

    composite = inputs.get("composite") if inputs else None
    if composite is None:
        return {}

    # Add batch and channel dims: (H, W) -> (1, 1, H, W).
    # assumes a 2-D array because the sketchpad uses image_mode="L" — TODO confirm
    img = torch.from_numpy(composite).unsqueeze(0).unsqueeze(0).float()

    # Scale so the brightest pixel becomes 1.0. A blank canvas has peak 0,
    # which would otherwise divide by zero and feed NaNs to the model.
    peak = img.max()
    if peak <= 0:
        return {}
    img = img / peak

    # Resize the canvas down to the 32x32 input the model is fed here.
    img = F.interpolate(img, size=(32, 32), mode="bilinear", align_corners=False)

    logits = model(img)
    probs = torch.softmax(logits, dim=1).squeeze()
    return dict(zip(all_classes, probs.tolist()))


DESCRIPTION = """Welcome to **DigitNet**, a tool for recognizing handwritten numbers.

### How to Use DigitNet
1. **Draw a Number**: Use the canvas on the left to draw a number (0-9).
2. **Edit as Needed**: Use the eraser or **Clear** to fix or reset.
3. **See Prediction**: View the predicted number and confidence score.
4. **Try Again**: Click **Clear** to draw again.

### Tips for Best Results
- Draw in the middle of the canvas.
- Experiment with different writing styles."""

# Named `sketchpad` rather than `input` to avoid shadowing the builtin.
sketchpad = gr.Sketchpad(
    image_mode="L",
    canvas_size=(600, 600),
    layers=False,
    brush=gr.Brush(colors=["rgb(239, 68, 68)"], color_mode="fixed", default_size=20),
)
# Show every class so the full probability distribution is visible.
output = gr.Label(num_top_classes=len(all_classes), label="Prediction")

demo = gr.Interface(
    fn=predict,
    live=True,  # re-run predict on every canvas change
    inputs=sketchpad,
    outputs=output,
    clear_btn=gr.Button("Clear"),
    title="DigitNet",
    description=DESCRIPTION,
    flagging_mode="never",
)

if __name__ == "__main__":
    demo.launch()