Spaces: Runtime error
Update app.py
app.py
CHANGED
@@ -32,15 +32,19 @@ PATH = "state_dict_model.pth" # PATH where you load the model trained
 load_model.load_state_dict(torch.load(PATH))
 load_model.eval()
 def recognize_digit(image):
-    if
-    #
+    if sketch is not None:
+        # Drawing preprocessing
+        sketch = transforms.ToPILImage()(sketch).convert("L").resize((28, 28))
+        sketch = PIL.ImageOps.invert(sketch)
         transform = transforms.Compose([
-
-
-
-
-
-
+            transforms.ToTensor(),
+            transforms.Normalize((0.5,), (0.5,))
+        ])
+        sketch = transform(sketch).unsqueeze(0)
+
+        with torch.no_grad():
+            prediction = load_model(sketch)
+
         prediction = torch.softmax(prediction, dim=1)
         return {str(i): float(prediction[0][i]) for i in range(10)}
     else:
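For context, a minimal, self-contained sketch of what the updated recognize_digit could look like as runnable code is given below. Note that the hunk keeps the parameter name image while the body refers to sketch, so the function as committed would raise a NameError (possibly related to the Space's runtime-error status); the sketch below renames the parameter accordingly. Everything not visible in the hunk (imports, the load_model definition, the else branch, and the Gradio/input wiring) is an assumption, not taken from the diff.

# Hypothetical, self-contained version of recognize_digit; only the body of
# the "if" branch comes from the committed hunk, the rest is assumed.
import PIL.ImageOps
import torch
from torchvision import transforms

def recognize_digit(sketch):
    # Guard against an empty canvas (no drawing submitted)
    if sketch is not None:
        # Convert the drawing to grayscale, resize to MNIST's 28x28,
        # and invert so the digit is light on a dark background
        sketch = transforms.ToPILImage()(sketch).convert("L").resize((28, 28))
        sketch = PIL.ImageOps.invert(sketch)
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])
        sketch = transform(sketch).unsqueeze(0)  # add a batch dimension

        with torch.no_grad():
            prediction = load_model(sketch)  # load_model is defined earlier in app.py

        # Return per-digit probabilities keyed "0".."9"
        prediction = torch.softmax(prediction, dim=1)
        return {str(i): float(prediction[0][i]) for i in range(10)}
    else:
        # Fallback for an empty input; the real else branch is not shown in the hunk
        return {str(i): 0.0 for i in range(10)}

The inversion step likely exists because MNIST digits are light strokes on a dark background, whereas a sketchpad drawing is typically dark on light, so the input is flipped to match the training distribution before normalization.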