Jfink09 commited on
Commit
dbfeb08
·
verified ·
1 Parent(s): daf866e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -9
app.py CHANGED
@@ -8,6 +8,8 @@ from model import create_resnet50_model
8
  from timeit import default_timer as timer
9
  from typing import Tuple, Dict
10
 
 
 
11
  # Setup class names
12
  class_names = ['CRVO',
13
  'Choroidal Nevus',
@@ -40,28 +42,56 @@ resnet50.load_state_dict(
40
  ### 3. Predict function ###
41
 
42
  # Create predict function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  def predict(img) -> Tuple[Dict, float]:
44
- """Transforms and performs a prediction on img and returns prediction and time taken.
45
- """
46
  # Start the timer
47
  start_time = timer()
48
-
49
  # Transform the target image and add a batch dimension
50
  img = resnet50_transforms(img).unsqueeze(0)
51
-
52
  # Put model into evaluation mode and turn on inference mode
53
  resnet50.eval()
54
  with torch.inference_mode():
55
  # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
56
  pred_probs = torch.softmax(resnet50(img), dim=1)
57
-
58
- # Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
59
- pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
 
 
 
 
 
 
 
 
60
 
61
  # Calculate the prediction time
62
  pred_time = round(timer() - start_time, 5)
63
-
64
- # Return the prediction dictionary and prediction time
65
  return pred_labels_and_probs, pred_time
66
 
67
  ### 4. Gradio app ###
 
8
  from timeit import default_timer as timer
9
  from typing import Tuple, Dict
10
 
11
+ import torch.nn.functional as F
12
+
13
  # Setup class names
14
  class_names = ['CRVO',
15
  'Choroidal Nevus',
 
42
  ### 3. Predict function ###
43
 
44
  # Create predict function
45
+ # def predict(img) -> Tuple[Dict, float]:
46
+ # """Transforms and performs a prediction on img and returns prediction and time taken.
47
+ # """
48
+ # # Start the timer
49
+ # start_time = timer()
50
+
51
+ # # Transform the target image and add a batch dimension
52
+ # img = resnet50_transforms(img).unsqueeze(0)
53
+
54
+ # # Put model into evaluation mode and turn on inference mode
55
+ # resnet50.eval()
56
+ # with torch.inference_mode():
57
+ # # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
58
+ # pred_probs = torch.softmax(resnet50(img), dim=1)
59
+
60
+ # # Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
61
+ # pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
62
+
63
+ # # Calculate the prediction time
64
+ # pred_time = round(timer() - start_time, 5)
65
+
66
+ # # Return the prediction dictionary and prediction time
67
+ # return pred_labels_and_probs, pred_time
68
+
69
  def predict(img) -> Tuple[Dict, float]:
70
+ """Transforms and performs a prediction on img and returns prediction and time taken."""
 
71
  # Start the timer
72
  start_time = timer()
 
73
  # Transform the target image and add a batch dimension
74
  img = resnet50_transforms(img).unsqueeze(0)
 
75
  # Put model into evaluation mode and turn on inference mode
76
  resnet50.eval()
77
  with torch.inference_mode():
78
  # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
79
  pred_probs = torch.softmax(resnet50(img), dim=1)
80
+
81
+ # Calculate entropy for OOD detection
82
+ entropy = -torch.sum(pred_probs * torch.log(pred_probs + 1e-8)).item()
83
+ max_prob = torch.max(pred_probs).item()
84
+
85
+ # Create a prediction label and prediction probability dictionary for each prediction class
86
+ pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
87
+
88
+ # OOD Detection: Flag suspicious predictions
89
+ if (max_prob > 0.95 and entropy < 0.2) or entropy > 2.0:
90
+ pred_labels_and_probs["May not be a retina scan"] = 0.99
91
 
92
  # Calculate the prediction time
93
  pred_time = round(timer() - start_time, 5)
94
+ # Return the prediction dictionary and prediction time
 
95
  return pred_labels_and_probs, pred_time
96
 
97
  ### 4. Gradio app ###