alexluna4 commited on
Commit
795d005
·
verified ·
1 Parent(s): 594904d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -8
app.py CHANGED
@@ -32,15 +32,19 @@ PATH = "state_dict_model.pth" # PATH where you load the model trained
32
  load_model.load_state_dict(torch.load(PATH))
33
  load_model.eval()
34
  def recognize_digit(image):
35
- if image is not None:
36
- # Preprocess of the image
 
 
37
  transform = transforms.Compose([
38
- transforms.ToTensor(),
39
- transforms.Normalize((0.5,), (0.5,))
40
- ])
41
- image = transform(image)
42
- with torch.inference_mode(): # inference mode of pytoroch
43
- prediction = load_model(image)
 
 
44
  prediction = torch.softmax(prediction, dim=1)
45
  return {str(i): float(prediction[0][i]) for i in range(10)}
46
  else:
 
32
  load_model.load_state_dict(torch.load(PATH))
33
  load_model.eval()
34
  def recognize_digit(image):
35
+ if sketch is not None:
36
+ # Procesamiento del dibujo
37
+ sketch = transforms.ToPILImage()(sketch).convert("L").resize((28, 28))
38
+ sketch = PIL.ImageOps.invert(sketch)
39
  transform = transforms.Compose([
40
+ transforms.ToTensor(),
41
+ transforms.Normalize((0.5,), (0.5,))
42
+ ])
43
+ sketch = transform(sketch).unsqueeze(0)
44
+
45
+ with torch.no_grad():
46
+ prediction = load_model(sketch)
47
+
48
  prediction = torch.softmax(prediction, dim=1)
49
  return {str(i): float(prediction[0][i]) for i in range(10)}
50
  else: