alexluna4 commited on
Commit
5d5354d
·
verified ·
1 Parent(s): 2b6d8dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -15
app.py CHANGED
@@ -32,25 +32,22 @@ PATH = "state_dict_model.pth" # PATH where you load the model trained
32
  load_model.load_state_dict(torch.load(PATH))
33
  load_model.eval()
34
  def recognize_digit(image):
35
- if sketch is not None:
36
- # Procesamiento del dibujo
37
- sketch = transforms.ToPILImage()(sketch).convert("L").resize((28, 28))
38
- sketch = PIL.ImageOps.invert(sketch)
39
  transform = transforms.Compose([
40
- transforms.ToTensor(),
41
- transforms.Normalize((0.5,), (0.5,))
42
- ])
43
- sketch = transform(sketch).unsqueeze(0)
44
-
45
- with torch.no_grad():
46
- prediction = load_model(sketch)
47
-
48
  prediction = torch.softmax(prediction, dim=1)
49
  return {str(i): float(prediction[0][i]) for i in range(10)}
50
  else:
51
  return ""
 
52
 
53
  demo = gr.Interface(fn=recognize_digit,
54
- inputs="sketchpad",
55
- outputs=gr.outputs.Label(num_top_classes=3)
56
- demo.launch(share=True)
 
32
  load_model.load_state_dict(torch.load(PATH))
33
  load_model.eval()
34
  def recognize_digit(image):
35
+ if image is not None:
36
+ # Preprocess of the image
 
 
37
  transform = transforms.Compose([
38
+ transforms.ToTensor(),
39
+ transforms.Normalize((0.5,), (0.5,))
40
+ ])
41
+ image = transform(image)
42
+ with torch.inference_mode(): # inference mode of pytoroch
43
+ prediction = load_model(image)
 
 
44
  prediction = torch.softmax(prediction, dim=1)
45
  return {str(i): float(prediction[0][i]) for i in range(10)}
46
  else:
47
  return ""
48
+
49
 
50
  demo = gr.Interface(fn=recognize_digit,
51
+ inputs=gr.Image(shape=(28,28), image_mode="L", invert_colors=True, source="canvas"),
52
+ outputs=gr.Label(num_top_classes=1))
53
+ demo.launch(True)