Jfink09 commited on
Commit
a9d8111
·
verified ·
1 Parent(s): 168004e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -54
app.py CHANGED
@@ -8,8 +8,6 @@ from model import create_resnet50_model
8
  from timeit import default_timer as timer
9
  from typing import Tuple, Dict
10
 
11
- import torch.nn.functional as F
12
-
13
  # Setup class names
14
  class_names = ['CRVO',
15
  'Choroidal Nevus',
@@ -42,63 +40,29 @@ resnet50.load_state_dict(
42
  ### 3. Predict function ###
43
 
44
  # Create predict function
45
- # def predict(img) -> Tuple[Dict, float]:
46
- # """Transforms and performs a prediction on img and returns prediction and time taken.
47
- # """
48
- # # Start the timer
49
- # start_time = timer()
50
-
51
- # # Transform the target image and add a batch dimension
52
- # img = resnet50_transforms(img).unsqueeze(0)
53
 
54
- # # Put model into evaluation mode and turn on inference mode
55
- # resnet50.eval()
56
- # with torch.inference_mode():
57
- # # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
58
- # pred_probs = torch.softmax(resnet50(img), dim=1)
59
 
60
- # # Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
61
- # pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
 
 
 
62
 
63
- # # Calculate the prediction time
64
- # pred_time = round(timer() - start_time, 5)
65
 
66
- # # Return the prediction dictionary and prediction time
67
- # return pred_labels_and_probs, pred_time
68
-
69
- def predict(img):
70
- """Transforms and performs a prediction on img and returns prediction and time taken."""
71
- start_time = timer()
72
 
73
- try:
74
- img = resnet50_transforms(img).unsqueeze(0)
75
- resnet50.eval()
76
-
77
- with torch.inference_mode():
78
- pred_probs = torch.softmax(resnet50(img), dim=1)
79
-
80
- # Calculate entropy for OOD detection
81
- entropy = -torch.sum(pred_probs * torch.log(pred_probs + 1e-8)).item()
82
- max_prob = torch.max(pred_probs).item()
83
-
84
- # Create base prediction dictionary
85
- pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
86
-
87
- # OOD Detection - modify existing probabilities instead of adding new keys
88
- if (max_prob > 0.95 and entropy < 0.05) or entropy > 0.1:
89
- # Boost the probability of the first class and add a marker
90
- pred_labels_and_probs[class_names[0]] = 0.99 # Use existing class
91
- # You could also just print a warning or log it
92
- print("May not be retina scan")
93
-
94
- pred_time = round(timer() - start_time, 5)
95
- return pred_labels_and_probs, pred_time
96
-
97
- except Exception as e:
98
- # Return dictionary with same structure as normal case
99
- pred_labels_and_probs = {class_names[i]: 0.0 for i in range(len(class_names))}
100
- pred_labels_and_probs[class_names[0]] = 1.0 # Show error in first class
101
- return pred_labels_and_probs, 0.0
102
 
103
  ### 4. Gradio app ###
104
 
 
8
  from timeit import default_timer as timer
9
  from typing import Tuple, Dict
10
 
 
 
11
  # Setup class names
12
  class_names = ['CRVO',
13
  'Choroidal Nevus',
 
40
  ### 3. Predict function ###
41
 
42
  # Create predict function
43
+ def predict(img) -> Tuple[Dict, float]:
44
+ """Transforms and performs a prediction on img and returns prediction and time taken.
45
+ """
46
+ # Start the timer
47
+ start_time = timer()
 
 
 
48
 
49
+ # Transform the target image and add a batch dimension
50
+ img = resnet50_transforms(img).unsqueeze(0)
 
 
 
51
 
52
+ # Put model into evaluation mode and turn on inference mode
53
+ resnet50.eval()
54
+ with torch.inference_mode():
55
+ # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
56
+ pred_probs = torch.softmax(resnet50(img), dim=1)
57
 
58
+ # Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
59
+ pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
60
 
61
+ # Calculate the prediction time
62
+ pred_time = round(timer() - start_time, 5)
 
 
 
 
63
 
64
+ # Return the prediction dictionary and prediction time
65
+ return pred_labels_and_probs, pred_time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  ### 4. Gradio app ###
68