Spaces:
Running
Running
import time | |
import gradio as gr | |
import torch | |
import torch.nn.functional as F | |
from torchvision.utils import save_image | |
from huggingface_hub import hf_hub_download | |
from digitnet import Model | |
# Inference-only app: disable autograd globally so forward passes build no graph.
torch.set_grad_enabled(False)

# Download the checkpoint from the Hub into the working directory, then load it
# from the path hf_hub_download reports instead of re-typing the filename.
checkpoint_path = hf_hub_download("karanravindra/digitnet", filename="model.ckpt", local_dir=".")
model = Model.load_from_checkpoint(checkpoint_path, map_location="cpu")
model.eval()

# Class labels, zipped positionally with the model's output probabilities in
# predict(). NOTE(review): assumes logit i corresponds to digit i - confirm
# against the training label mapping.
all_classes = list("0123456789")
def predict(inputs: dict) -> dict[str, float]: | |
print(f"[{time.strftime('%H:%M:%S')}] Predicting...") | |
img = torch.from_numpy(inputs["composite"]).unsqueeze(0).unsqueeze(0).float() | |
img = img / img.max() | |
img = F.interpolate(img, size=(32, 32), mode="bilinear", align_corners=False) | |
logits = model(img) | |
probs = torch.softmax(logits, dim=1).squeeze() | |
return {c: p for c, p in zip(all_classes, probs.tolist())} | |
# Drawing canvas. image_mode="L" yields a single-channel composite image,
# which predict() feeds to the model as one input channel. (Named so it does
# not shadow the builtin `input`.)
sketchpad = gr.Sketchpad(
    image_mode="L",
    canvas_size=(600, 600),
    layers=False,
    brush=gr.Brush(colors=["rgb(239, 68, 68)"], color_mode="fixed", default_size=20),
)

# Shows all ten digit classes with their probabilities.
prediction_label = gr.Label(num_top_classes=10, label="Prediction")

# live=True asks Gradio to rerun predict() whenever the sketchpad changes.
demo = gr.Interface(
    fn=predict,
    live=True,
    inputs=sketchpad,
    outputs=prediction_label,
    clear_btn=gr.Button("Clear"),
    title="DigitNet",
    description="""Welcome to **DigitNet**, a tool for recognizing handwritten numbers.
### How to Use DigitNet
1. **Draw a Number**: Use the canvas on the left to draw a number (0-9).
2. **Edit as Needed**: Use the eraser or **Clear** to fix or reset.
3. **See Prediction**: View the predicted number and confidence score.
4. **Try Again**: Click **Clear** to draw again.
### Tips for Best Results
- Draw in the middle of the canvas.
- Experiment with different writing styles.""",
    flagging_mode="never",
)
def main() -> None:
    """Start the Gradio server for the DigitNet demo."""
    demo.launch()


if __name__ == "__main__":
    main()