File size: 4,753 Bytes
a0a2a37
 
 
 
 
5490f7a
a0a2a37
 
 
 
dbfeb08
 
a0a2a37
 
1569e58
a0a2a37
 
 
1f48633
a0a2a37
 
16c3a47
a9f2449
a0a2a37
 
 
 
 
 
 
0983653
a0a2a37
16c3a47
a0a2a37
 
 
 
 
 
 
6584f0f
a0a2a37
 
 
dbfeb08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9fe98d4
dbfeb08
c1df69a
 
9fe98d4
c1df69a
9fe98d4
c1df69a
9fe98d4
c1df69a
9fe98d4
 
 
 
 
c1df69a
9fe98d4
 
c1df69a
9fe98d4
c1df69a
 
 
 
9fe98d4
 
 
 
 
c1df69a
 
 
 
a0a2a37
 
 
 
db44c43
 
 
a0a2a37
 
 
 
 
 
 
5beb3d6
a0a2a37
 
16c3a47
 
 
 
a0a2a37
 
6584f0f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
### 1. Imports and class names setup ### 
import gradio as gr
import os
import torch

from torch import nn
from model import create_resnet50_model
from timeit import default_timer as timer
from typing import Tuple, Dict

import torch.nn.functional as F

# Class labels shown to the user. Order matters: predict() pairs
# class_names[i] with the model's i-th output probability.
class_names = [
    'CRVO',
    'Choroidal Nevus',
    'Diabetic Retinopathy',
    'Laser Spots',
    'Macular Degeneration',
    'Macular Hole',
    'Myelinated Nerve Fiber',
    'Normal',
    'Pathological Myopia',  # was misspelled "Mypoia" in the displayed label
    'Retinitis Pigmentosa',
]

### 2. Model and transforms preparation ###

# Create the ResNet50 model and its matching preprocessing transforms
resnet50, resnet50_transforms = create_resnet50_model(
    num_classes=len(class_names),
)
# Replace the classifier head: 2048 input features, one output per class.
# The output size is derived from class_names (instead of a hard-coded 10)
# so the head cannot silently drift out of sync with the label list.
resnet50.fc = nn.Linear(2048, len(class_names))

# Load saved weights
# NOTE(security): torch.load unpickles the file, so only load checkpoints
# from a trusted source. Consider passing weights_only=True if the installed
# torch version supports it -- TODO confirm against the deployed version.
resnet50.load_state_dict(
    torch.load(
        f="pretrained_resnet50_feature_extractor_drappcompressed.pth",
        map_location=torch.device("cpu"),  # load to CPU
    )
)


### 3. Predict function ###

# Create predict function
# def predict(img) -> Tuple[Dict, float]:
#     """Transforms and performs a prediction on img and returns prediction and time taken.
#     """
#     # Start the timer
#     start_time = timer()
    
#     # Transform the target image and add a batch dimension
#     img = resnet50_transforms(img).unsqueeze(0)
    
#     # Put model into evaluation mode and turn on inference mode
#     resnet50.eval()
#     with torch.inference_mode():
#         # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
#         pred_probs = torch.softmax(resnet50(img), dim=1)
    
#     # Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
#     pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
    
#     # Calculate the prediction time
#     pred_time = round(timer() - start_time, 5)
    
#     # Return the prediction dictionary and prediction time 
#     return pred_labels_and_probs, pred_time

def predict(img):
    """Classify an image with the ResNet50 model and time the prediction.

    Args:
        img: Image to classify, in whatever form ``resnet50_transforms``
            accepts.

    Returns:
        A ``(scores, seconds)`` tuple. ``scores`` maps every entry of
        ``class_names`` to its softmax probability and ``seconds`` is the
        elapsed wall-clock time rounded to 5 decimals. If anything fails,
        the error is printed and ``scores`` maps every class to ``0.0``
        with ``seconds`` set to ``0.0``, so no class is ever reported as a
        confident prediction because of an error.
    """
    # Heuristic out-of-distribution thresholds on the softmax output.
    confident_prob = 0.95
    low_entropy = 0.2
    high_entropy = 2.0

    start_time = timer()

    try:
        # Transform the image and add a batch dimension.
        batch = resnet50_transforms(img).unsqueeze(0)
        resnet50.eval()

        with torch.inference_mode():
            probs = torch.softmax(resnet50(batch), dim=1)[0]

            # Entropy of the predicted distribution; the small epsilon
            # keeps log() finite when a probability is exactly zero.
            entropy = -torch.sum(probs * torch.log(probs + 1e-8)).item()
            max_prob = torch.max(probs).item()

            # Label -> probability dictionary (the format gr.Label expects).
            pred_labels_and_probs = {
                name: float(probs[i]) for i, name in enumerate(class_names)
            }

            # Only warn when the heuristic fires. The returned probabilities
            # are left untouched: overwriting the first class with 0.99
            # would present a fabricated diagnosis to the user.
            suspiciously_peaked = max_prob > confident_prob and entropy < low_entropy
            if suspiciously_peaked or entropy > high_entropy:
                print("May not be retina scan")

        pred_time = round(timer() - start_time, 5)
        return pred_labels_and_probs, pred_time

    except Exception as e:
        # Best-effort boundary for the UI: report the failure in the logs and
        # return the usual structure with all-zero scores rather than
        # claiming 100% for the first class.
        print(f"Prediction failed: {e!r}")
        return {name: 0.0 for name in class_names}, 0.0

### 4. Gradio app ###

# Create title, description and article strings
#title = "DeepFundus 👀"
#description = "A ResNet50 feature extractor computer vision model to classify funduscopic images."
#article = "Created with the help from [09. PyTorch Model Deployment](https://www.learnpytorch.io/09_pytorch_model_deployment/)."

# Create examples list from the "examples/" directory.
# The listing is sorted because os.listdir() returns entries in arbitrary
# order, and a missing directory yields no examples instead of crashing
# the app at start-up.
examples_dir = "examples"
if os.path.isdir(examples_dir):
    example_list = [
        [os.path.join(examples_dir, example)]
        for example in sorted(os.listdir(examples_dir))
    ]
else:
    example_list = []

# Create the Gradio demo
demo = gr.Interface(
    fn=predict,  # mapping function from input to output
    inputs=gr.Image(type="pil"),  # predict() receives a PIL image
    # predict() returns two values, so there are two output components
    outputs=[
        gr.Label(num_top_classes=3, label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=example_list or None,  # None means "no examples section"
    # title=title,
    # description=description,
    # article=article,
)

# Launch the demo!
demo.launch()