VanLinLin commited on
Commit
3d70783
·
verified ·
1 Parent(s): 8d7c63f

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### 1. Imports and class names app ###
2
+ import gradio as gr
3
+ import os
4
+ import torch
5
+
6
+ from model import create_effnetb2_model
7
+ from timeit import default_timer as timer
8
+ from typing import Tuple, Dict
9
+
10
+ # Setup class names
11
+ with open('class_names.txt', 'r') as f:
12
+ class_names = [food_name.strip() for food_name in f.readlines()]
13
+
14
+ ### 2. Model and transforms preparation ###
15
+ # Create model and transforms
16
+ effnetb2, effnetb2_transforms = create_effnetb2_model()
17
+
18
+ # Load saved weight
19
+ effnetb2.load_state_dict(torch.load(f='09_pretrained_effnetb2_feature_extractor_food101_20_percent.pth',
20
+ map_location=torch.device('cpu'))) # load to cpu
21
+
22
+ ### 3. Predict function
23
+ def predict(img) -> Tuple[Dict, float]:
24
+ # Start a timer
25
+ start_time = timer()
26
+
27
+ # Transform the input image for use with EffNetB2
28
+ transformed_img = effnetb2_transforms(img).unsqueeze(dim=0) # unsqueeze = add batch dimension on 0th
29
+
30
+ # Put model into eval mode, make prediction
31
+ with torch.inference_mode():
32
+ effnetb2.eval()
33
+
34
+ # Pass the transformed image through the model and turn the prediction logits into probabilities
35
+ pred_prob = effnetb2(transformed_img).softmax(dim=1)
36
+
37
+ # Create a prediction label and prediction probability dictionary
38
+ pred_labels_and_probs = {class_names[i]: pred_prob[0][i].item() for i in range(len(class_names))}
39
+
40
+ # Calcualte pred time
41
+ end_time = timer()
42
+ inference_time = round(end_time - start_time, 4)
43
+
44
+ # Return pred dict and pred time
45
+ return pred_labels_and_probs, inference_time
46
+
47
+ ### 4. Gradio app ###
48
+ # Create title, description and aritcle
49
+ title = 'FoodVision Big 🍔👁️💪'
50
+ description = 'An [EfficientNetB2 feature extractor](https://pytorch.org/vision/0.16/models/generated/torchvision.models.efficientnet_b2.html#efficientnet-b2) computer vision model to classify images 101 classes of food from the Food101 dataset.'
51
+ article = 'Created at [11. Turning our FoodVision Big model into a deployable app](https://www.learnpytorch.io/09_pytorch_model_deployment/#11-turning-our-foodvision-big-model-into-a-deployable-app).'
52
+
53
+ # Create example list
54
+ example_list = [['examples/' + example] for example in os.listdir('examples')]
55
+
56
+ # Create the Gradio demo
57
+ demo = gr.Interface(fn=predict, # maps inputs to outputs
58
+ inputs=gr.Image(type='pil'),
59
+ outputs=[gr.Label(num_top_classes=5, label='Predictions'),
60
+ gr.Number(label='Prediction time (s)')],
61
+ examples=example_list,
62
+ title=title,
63
+ description=description,
64
+ article=article)
65
+
66
+ # Launch the demo!
67
+ demo.launch(debug=False, # print errors locally?
68
+ share=True) # generate a publically shareable URL