Jfink09 committed
Commit ff07067 · 1 Parent(s): 2b919c3

Upload 5 files

Files changed (5)
  1. app.py +86 -0
  2. examples/cr44.jpg +0 -0
  3. examples/dr240.jpg +0 -0
  4. examples/mdegen228.jpg +0 -0
  5. model.py +36 -0
app.py ADDED
@@ -0,0 +1,86 @@
+ ### 1. Imports and class names setup ###
+ import gradio as gr
+ import os
+ import torch
+
+ from model import create_efficientnet_b0_model
+ from timeit import default_timer as timer
+ from typing import Tuple, Dict
+
+ # Setup class names
+ class_names = ['CRVO',
+                'Choroidal Nevus',
+                'Diabetic Retinopathy',
+                'Laser Spots',
+                'Macular Degeneration',
+                'Macular Hole',
+                'Myelinated Nerve Fiber',
+                'Normal',
+                'Pathological Myopia',
+                'Retinitis Pigmentosa']
+
+ ### 2. Model and transforms preparation ###
+
+ # Create EfficientNet_B0 model
+ efficientnet_b0, efficientnet_b0_transforms = create_efficientnet_b0_model(
+     num_classes=len(class_names), # equivalent to passing 10 directly
+ )
+
+ # Load saved weights
+ efficientnet_b0.load_state_dict(
+     torch.load(
+         f="pretrained_efficientnet_b0_feature_extractor_drappcompressed.pth",
+         map_location=torch.device("cpu"), # load to CPU
+     )
+ )
+
+ ### 3. Predict function ###
+
+ # Create predict function
+ def predict(img) -> Tuple[Dict, float]:
+     """Transforms and performs a prediction on img and returns the prediction and time taken.
+     """
+     # Start the timer
+     start_time = timer()
+
+     # Transform the target image and add a batch dimension
+     img = efficientnet_b0_transforms(img).unsqueeze(0)
+
+     # Put model into evaluation mode and turn on inference mode
+     efficientnet_b0.eval()
+     with torch.inference_mode():
+         # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
+         pred_probs = torch.softmax(efficientnet_b0(img), dim=1)
+
+     # Create a {prediction label: prediction probability} dictionary for each class (the format Gradio's Label output expects)
+     pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
+
+     # Calculate the prediction time
+     pred_time = round(timer() - start_time, 5)
+
+     # Return the prediction dictionary and prediction time
+     return pred_labels_and_probs, pred_time
+
+ ### 4. Gradio app ###
+
+ # Create title, description and article strings
+ #title = "DeepFundus 👀"
+ #description = "An EfficientNet_B0 feature extractor computer vision model to classify funduscopic images."
+ #article = "Created with help from [09. PyTorch Model Deployment](https://www.learnpytorch.io/09_pytorch_model_deployment/)."
+
+ # Create examples list from "examples/" directory
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
+
+ # Create the Gradio demo
+ demo = gr.Interface(fn=predict, # mapping function from input to output
+                     inputs=gr.Image(type="pil"), # what are the inputs?
+                     outputs=[gr.Label(num_top_classes=3, label="Predictions"), # what are the outputs?
+                              gr.Number(label="Prediction time (s)")], # our fn has two outputs, therefore we have two outputs
+                     # example inputs drawn from the "examples/" directory
+                     examples=example_list)
+                     #title=title,
+                     #description=description,
+                     #article=article)
+
+ # Launch the demo!
+ demo.launch()
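Note: app.py calls demo.launch() at import time, so the quickest way to sanity-check the inference path outside of Gradio is to go through model.py directly. Below is a minimal smoke-test sketch; the file name local_check.py is illustrative and not part of this upload, and it assumes the checkpoint and the examples/ images above sit in the working directory.

# local_check.py -- illustrative smoke test of the inference path (not part of this commit)
import torch
from PIL import Image

from model import create_efficientnet_b0_model

# Rebuild the architecture and load the uploaded feature-extractor weights
model, transforms = create_efficientnet_b0_model(num_classes=10)
model.load_state_dict(
    torch.load(f="pretrained_efficientnet_b0_feature_extractor_drappcompressed.pth",
               map_location=torch.device("cpu"))
)
model.eval()

# Transform one of the bundled example images and add a batch dimension
img = transforms(Image.open("examples/dr240.jpg")).unsqueeze(0)

with torch.inference_mode():
    probs = torch.softmax(model(img), dim=1)

# Index of the most likely class and its probability
print(probs.argmax(dim=1).item(), probs.max().item())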
examples/cr44.jpg ADDED
examples/dr240.jpg ADDED
examples/mdegen228.jpg ADDED
model.py ADDED
@@ -0,0 +1,36 @@
+ import torch
+ import torchvision
+
+ from torch import nn
+
+ def create_efficientnet_b0_model(num_classes:int=10,
+                                  seed:int=42):
+     """Creates an EfficientNet_B0 feature extractor model and transforms.
+
+     Args:
+         num_classes (int, optional): number of classes in the classifier head.
+             Defaults to 10.
+         seed (int, optional): random seed value. Defaults to 42.
+
+     Returns:
+         model (torch.nn.Module): EfficientNet_B0 feature extractor model.
+         transforms (torchvision.transforms): EfficientNet_B0 image transforms.
+     """
+     # 1, 2, 3. Create EfficientNet_B0 pretrained weights, transforms and model
+     weights = torchvision.models.EfficientNet_B0_Weights.DEFAULT
+     transforms = weights.transforms()
+     model = torchvision.models.efficientnet_b0(weights=weights)
+
+     # 4. Freeze all layers in the base model
+     for param in model.parameters():
+         param.requires_grad = False # freeze the backbone so only the new classifier head is trainable
+
+     # 5. Change classifier head with random seed for reproducibility
+     torch.manual_seed(seed)
+     model.classifier = nn.Sequential(
+         nn.Dropout(p=0.3, inplace=True),
+         nn.Linear(in_features=1280, # EffNetB0 in_features = 1280 (EffNetB2 = 1408, ResNet50 = 2048)
+                   out_features=num_classes),
+     )
+
+     return model, transforms
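Since app.py restores the weights with load_state_dict, the .pth file in this Space is expected to hold a plain state_dict for this exact architecture. A hedged sketch of how such a checkpoint would typically be produced after fine-tuning the classifier head (the training loop itself is not part of this upload):

# Illustrative only -- the training code is not included in this commit
import torch
from model import create_efficientnet_b0_model

model, transforms = create_efficientnet_b0_model(num_classes=10, seed=42)

# ... fine-tune model.classifier on the fundus dataset here ...

# Save only the state_dict so app.py can reload it with load_state_dict()
torch.save(obj=model.state_dict(),
           f="pretrained_efficientnet_b0_feature_extractor_drappcompressed.pth")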