lneduchal commited on
Commit
02c950a
·
1 Parent(s): 0ded0c3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -0
app.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import models, transforms, datasets
3
+
4
+ from PIL import Image
5
+
6
+ import gradio as gr
7
+
8
+
9
+ model_ft = models.resnet18(pretrained = True)
10
+ num_ftrs = model_ft.fc.in_features
11
+ model_ft.fc = nn.Linear(num_ftrs, 2)
12
+
13
+ state_dict = torch.load("up500Model.pt", map_location = "cpu")
14
+
15
+ model_ft.load_state_dict(state_dict)
16
+ model_ft.eval()
17
+
18
+ img_transforms = transforms.Compose(
19
+ [
20
+ transforms.Resize(256),
21
+ transforms.CenterCrop(224),
22
+ transforms.ToTensor(),
23
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
24
+ ]
25
+ )
26
+
27
+ labels = ["fiat500", "VW Up!"]
28
+ def predict(img):
29
+ inp = img.fromarray(inp.astype("unit8"), "RGB")
30
+ inp = img_transforms(inp).unsqueeze(0)
31
+
32
+ # We don't want to compute gradients
33
+ with torch.no_grad():
34
+ preds = torch.np.functional.softmax(model_ft(inp)[0])
35
+
36
+ return {labels[i]: preds[i] for i in range(2)}
37
+
38
+ interface = gr.Interface(
39
+ predict,
40
+ inputs = "image",
41
+ output = "label",
42
+ title = "Car classification"
43
+ )
44
+ interface.launch()