stanimirovb commited on
Commit
bc314c6
·
verified ·
1 Parent(s): d44690a
Files changed (1) hide show
  1. app.py +48 -4
app.py CHANGED
@@ -2,16 +2,60 @@ import gradio as gr
2
  import numpy as np
3
  import torch
4
  import torchvision
 
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  def on_submit(img):
7
- img = img['composite'].astype(np.float32)
8
- img = torch.from_numpy(img)
9
- return f"Sum: {img.shape}"
 
 
 
 
 
 
 
 
10
 
11
  iface = gr.Interface(
12
  title = "LeNet",
13
  fn = on_submit,
14
  inputs=gr.Sketchpad(image_mode='P'),
15
- outputs=gr.Label(),
16
  )
17
  iface.launch()
 
2
  import numpy as np
3
  import torch
4
  import torchvision
5
+ from torch import nn
6
 
7
+ class LeNet(nn.Module):
8
+ def __init__(self):
9
+ super(LeNet, self).__init__()
10
+
11
+ self.convs = nn.Sequential(
12
+ nn.Conv2d(in_channels=1, out_channels=4, kernel_size=(5, 5)),
13
+ nn.Tanh(),
14
+ nn.AvgPool2d(2, 2),
15
+
16
+ nn.Conv2d(in_channels=4, out_channels=12, kernel_size=(5, 5)),
17
+ nn.Tanh(),
18
+ nn.AvgPool2d(2, 2)
19
+ )
20
+
21
+ self.linear = nn.Sequential(
22
+ nn.Linear(4*4*12,10)
23
+ )
24
+
25
+ def forward(self, x):
26
+ x = self.convs(x)
27
+ x = torch.flatten(x, 1)
28
+
29
+ return self.linear(x)
30
+
31
+ @torch.no_grad()
32
+ def predict(self, input):
33
+ input = input.reshape(1, 1, 28, 28)
34
+ out = self(input)
35
+ return nn.functional.softmax(out[0], dim = 0)
36
+
37
+ lenet = LeNet()
38
+ lenet.load_state_dict(torch.load('../ibob-lenet-v1/lenet-v1.pth', map_location='cpu'))
39
+
40
+
41
+ resize = torchvision.transforms.Resize((28, 28), antialias=True)
42
  def on_submit(img):
43
+ with torch.no_grad():
44
+ img = img['composite'].astype(np.float32)
45
+ img = torch.from_numpy(img)
46
+ img = resize(img.unsqueeze(0))
47
+
48
+ result = lenet.predict(img)
49
+
50
+ sorted = [[i, e] for i, e in enumerate(result.numpy())]
51
+ sorted.sort(key = lambda a : -a[1])
52
+
53
+ return "\n".join(map(str, sorted))
54
 
55
  iface = gr.Interface(
56
  title = "LeNet",
57
  fn = on_submit,
58
  inputs=gr.Sketchpad(image_mode='P'),
59
+ outputs=gr.Text(),
60
  )
61
  iface.launch()