Spaces:
Running
Running
File size: 1,819 Bytes
b664581 beb2ed5 b664581 beb2ed5 b664581 beb2ed5 b664581 beb2ed5 b664581 6800ab4 b664581 6800ab4 b664581 6800ab4 beb2ed5 b664581 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import time
import gradio as gr
import torch
import torch.nn.functional as F
from torchvision.utils import save_image
from huggingface_hub import hf_hub_download
from digitnet import Model
# This process only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

# Fetch the checkpoint into the working directory, then load it onto the CPU.
hf_hub_download(
    repo_id="karanravindra/digitnet",
    filename="model.ckpt",
    local_dir=".",
)
model = Model.load_from_checkpoint("model.ckpt", map_location="cpu")
model.eval()

# Display labels "0".."9"; `predict` zips them with the model's output
# probabilities, so their order must match the model's class order.
all_classes = [str(digit) for digit in range(10)]
def predict(inputs: dict) -> dict[str, float]:
    """Classify the sketch held in ``inputs["composite"]``.

    Args:
        inputs: Editor payload whose ``"composite"`` entry is a NumPy array.
            Two batch/channel axes are prepended before bilinear resizing,
            so the array is expected to be 2-D (height, width).

    Returns:
        A mapping from each label in ``all_classes`` to its softmax
        probability. An empty dict is returned when there is nothing to
        classify (no payload, no composite, or an all-zero image).
        NOTE(review): an empty dict is expected to make ``gr.Label`` show no
        prediction -- confirm against the installed Gradio version.
    """
    print(f"[{time.strftime('%H:%M:%S')}] Predicting...")

    # With a live interface the callback may run before anything is drawn,
    # so tolerate a missing payload instead of raising.
    composite = None if inputs is None else inputs.get("composite")
    if composite is None:
        return {}

    img = torch.from_numpy(composite).unsqueeze(0).unsqueeze(0).float()

    # An all-zero image would make the normalisation below compute 0 / 0 and
    # push NaNs through the model; treat it as "no prediction" instead.
    if img.numel() == 0:
        return {}
    peak = img.max()
    if peak.item() == 0:
        return {}

    img = img / peak
    img = F.interpolate(img, size=(32, 32), mode="bilinear", align_corners=False)

    logits = model(img)
    probs = torch.softmax(logits, dim=1).squeeze()
    return dict(zip(all_classes, probs.tolist()))
# Drawing canvas passed to `predict`: single-channel ("L") image mode,
# 600x600 canvas, one fixed brush colour, no layers.
# NOTE: the name `input` shadows the builtin; it is kept because the
# `gr.Interface` call below refers to it.
input = gr.Sketchpad(
    brush=gr.Brush(
        default_size=20,
        color_mode="fixed",
        colors=["rgb(239, 68, 68)"],
    ),
    layers=False,
    canvas_size=(600, 600),
    image_mode="L",
)

# Label widget that displays the per-class scores returned by `predict`.
output = gr.Label(label="Prediction", num_top_classes=10)
# Wire the sketchpad, the classifier callback and the label widget into one
# live interface with a custom "Clear" button and flagging disabled.
demo = gr.Interface(
    title="DigitNet",
    description="""Welcome to **DigitNet**, a tool for recognizing handwritten numbers.
### How to Use DigitNet
1. **Draw a Number**: Use the canvas on the left to draw a number (0-9).
2. **Edit as Needed**: Use the eraser or **Clear** to fix or reset.
3. **See Prediction**: View the predicted number and confidence score.
4. **Try Again**: Click **Clear** to draw again.
### Tips for Best Results
- Draw in the middle of the canvas.
- Experiment with different writing styles.""",
    fn=predict,
    inputs=input,
    outputs=output,
    live=True,
    clear_btn=gr.Button("Clear"),
    flagging_mode="never",
)
# Start the app only when this file is executed as a script, so importing
# the module does not launch it.
if __name__ == "__main__":
    demo.launch()
|