Jfink09 commited on
Commit
9fe98d4
·
verified ·
1 Parent(s): 1378a1a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -23
app.py CHANGED
@@ -66,33 +66,45 @@ resnet50.load_state_dict(
66
  # # Return the prediction dictionary and prediction time
67
  # return pred_labels_and_probs, pred_time
68
 
69
- def predict(img) -> Tuple[Dict, float]:
70
  """Transforms and performs a prediction on img and returns prediction and time taken."""
71
- # Start the timer
72
- start_time = timer()
73
- # Transform the target image and add a batch dimension
74
- img = resnet50_transforms(img).unsqueeze(0)
75
- # Put model into evaluation mode and turn on inference mode
76
- resnet50.eval()
77
- with torch.inference_mode():
78
- # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
79
- pred_probs = torch.softmax(resnet50(img), dim=1)
80
 
81
- # Calculate entropy for OOD detection
82
- entropy = -torch.sum(pred_probs * torch.log(pred_probs + 1e-8)).item()
83
- max_prob = torch.max(pred_probs).item()
84
 
85
- # Create a prediction label and prediction probability dictionary for each prediction class
86
- pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
87
 
88
- # OOD Detection: Flag suspicious predictions
89
- if (max_prob > 0.95 and entropy < 0.2) or entropy > 2.0:
90
- pred_labels_and_probs["May not be a retina scan"] = 0.99
91
-
92
- # Calculate the prediction time
93
- pred_time = round(timer() - start_time, 5)
94
- # Return the prediction dictionary and prediction time
95
- return pred_labels_and_probs, pred_time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  ### 4. Gradio app ###
98
 
 
66
  # # Return the prediction dictionary and prediction time
67
  # return pred_labels_and_probs, pred_time
68
 
69
+ def predict(img):
70
  """Transforms and performs a prediction on img and returns prediction and time taken."""
71
+ try:
72
+ # Start the timer
73
+ start_time = timer()
 
 
 
 
 
 
74
 
75
+ # Ensure img is valid
76
+ if img is None:
77
+ return {"Error": "No image provided"}, 0.0
78
 
79
+ # Transform the target image and add a batch dimension
80
+ img_tensor = resnet50_transforms(img).unsqueeze(0)
81
 
82
+ # Put model into evaluation mode and turn on inference mode
83
+ resnet50.eval()
84
+ with torch.inference_mode():
85
+ # Pass the transformed image through the model
86
+ pred_probs = torch.softmax(resnet50(img_tensor), dim=1)
87
+
88
+ # Calculate entropy for OOD detection
89
+ entropy = -torch.sum(pred_probs * torch.log(pred_probs + 1e-8)).item()
90
+ max_prob = torch.max(pred_probs).item()
91
+
92
+ # Create prediction dictionary
93
+ pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
94
+
95
+ # OOD Detection
96
+ if (max_prob > 0.95 and entropy < 0.2) or entropy > 2.0:
97
+ pred_labels_and_probs["May not be retina scan"] = 0.99
98
+
99
+ # Calculate prediction time
100
+ pred_time = round(timer() - start_time, 5)
101
+
102
+ return pred_labels_and_probs, pred_time
103
+
104
+ except Exception as e:
105
+ # Return error information in expected format
106
+ error_dict = {"Processing Error": 1.0}
107
+ return error_dict, 0.0
108
 
109
  ### 4. Gradio app ###
110