Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -32,25 +32,22 @@ PATH = "state_dict_model.pth" # PATH where you load the model trained
|
|
32 |
load_model.load_state_dict(torch.load(PATH))
|
33 |
load_model.eval()
|
34 |
def recognize_digit(image):
|
35 |
-
if
|
36 |
-
#
|
37 |
-
sketch = transforms.ToPILImage()(sketch).convert("L").resize((28, 28))
|
38 |
-
sketch = PIL.ImageOps.invert(sketch)
|
39 |
transform = transforms.Compose([
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
prediction = load_model(sketch)
|
47 |
-
|
48 |
prediction = torch.softmax(prediction, dim=1)
|
49 |
return {str(i): float(prediction[0][i]) for i in range(10)}
|
50 |
else:
|
51 |
return ""
|
|
|
52 |
|
53 |
demo = gr.Interface(fn=recognize_digit,
|
54 |
-
inputs="
|
55 |
-
outputs=gr.
|
56 |
-
demo.launch(
|
|
|
# Load the trained weights and switch the model to evaluation mode
# (freezes dropout / batch-norm statistics for inference).
load_model.load_state_dict(torch.load(PATH))
load_model.eval()


def recognize_digit(image):
    """Classify a hand-drawn digit sketched on the Gradio canvas.

    Parameters
    ----------
    image : PIL.Image.Image or numpy.ndarray or None
        28x28 grayscale sketch produced by the ``gr.Image`` canvas input;
        ``None`` when the user submits an empty canvas.

    Returns
    -------
    dict[str, float] or str
        Mapping of digit labels ``"0"``-``"9"`` to their softmax
        probabilities, or an empty string when no image was drawn.
    """
    if image is not None:
        # Preprocess: convert to a float tensor in [0, 1], then normalize
        # each pixel to roughly [-1, 1] (mean 0.5, std 0.5 — must match the
        # normalization used at training time).
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        image = transform(image)
        # NOTE(review): `transform` yields a (1, 28, 28) tensor; the model is
        # assumed to accept it without an explicit batch dimension — confirm
        # (a convolutional net would need `image.unsqueeze(0)` here).
        with torch.inference_mode():  # PyTorch's no-grad fast inference mode
            prediction = load_model(image)
        prediction = torch.softmax(prediction, dim=1)
        return {str(i): float(prediction[0][i]) for i in range(10)}
    else:
        return ""


demo = gr.Interface(
    fn=recognize_digit,
    inputs=gr.Image(shape=(28, 28), image_mode="L", invert_colors=True,
                    source="canvas"),
    outputs=gr.Label(num_top_classes=1),
)
# NOTE(review): the positional True binds to `launch(inline=True)` in Gradio;
# if a public share link was intended, this should be `launch(share=True)` —
# confirm before changing.
demo.launch(True)