Jfink09 commited on
Commit
c1df69a
·
verified ·
1 Parent(s): 9fe98d4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -21
app.py CHANGED
@@ -68,43 +68,37 @@ resnet50.load_state_dict(
68
 
69
  def predict(img):
70
  """Transforms and performs a prediction on img and returns prediction and time taken."""
 
 
71
  try:
72
- # Start the timer
73
- start_time = timer()
74
-
75
- # Ensure img is valid
76
- if img is None:
77
- return {"Error": "No image provided"}, 0.0
78
-
79
- # Transform the target image and add a batch dimension
80
- img_tensor = resnet50_transforms(img).unsqueeze(0)
81
-
82
- # Put model into evaluation mode and turn on inference mode
83
  resnet50.eval()
 
84
  with torch.inference_mode():
85
- # Pass the transformed image through the model
86
- pred_probs = torch.softmax(resnet50(img_tensor), dim=1)
87
 
88
  # Calculate entropy for OOD detection
89
  entropy = -torch.sum(pred_probs * torch.log(pred_probs + 1e-8)).item()
90
  max_prob = torch.max(pred_probs).item()
91
 
92
- # Create prediction dictionary
93
  pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
94
 
95
- # OOD Detection
96
  if (max_prob > 0.95 and entropy < 0.2) or entropy > 2.0:
97
- pred_labels_and_probs["May not be retina scan"] = 0.99
 
 
 
98
 
99
- # Calculate prediction time
100
  pred_time = round(timer() - start_time, 5)
101
-
102
  return pred_labels_and_probs, pred_time
103
 
104
  except Exception as e:
105
- # Return error information in expected format
106
- error_dict = {"Processing Error": 1.0}
107
- return error_dict, 0.0
 
108
 
109
  ### 4. Gradio app ###
110
 
 
68
 
69
  def predict(img):
70
  """Transforms and performs a prediction on img and returns prediction and time taken."""
71
+ start_time = timer()
72
+
73
  try:
74
+ img = resnet50_transforms(img).unsqueeze(0)
 
 
 
 
 
 
 
 
 
 
75
  resnet50.eval()
76
+
77
  with torch.inference_mode():
78
+ pred_probs = torch.softmax(resnet50(img), dim=1)
 
79
 
80
  # Calculate entropy for OOD detection
81
  entropy = -torch.sum(pred_probs * torch.log(pred_probs + 1e-8)).item()
82
  max_prob = torch.max(pred_probs).item()
83
 
84
+ # Create base prediction dictionary
85
  pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
86
 
87
+ # OOD Detection - modify existing probabilities instead of adding new keys
88
  if (max_prob > 0.95 and entropy < 0.2) or entropy > 2.0:
89
+ # Boost the probability of the first class and add a marker
90
+ pred_labels_and_probs[class_names[0]] = 0.99 # Use existing class
91
+ # You could also just print a warning or log it
92
+ print("May not be retina scan")
93
 
 
94
  pred_time = round(timer() - start_time, 5)
 
95
  return pred_labels_and_probs, pred_time
96
 
97
  except Exception as e:
98
+ # Return dictionary with same structure as normal case
99
+ pred_labels_and_probs = {class_names[i]: 0.0 for i in range(len(class_names))}
100
+ pred_labels_and_probs[class_names[0]] = 1.0 # Show error in first class
101
+ return pred_labels_and_probs, 0.0
102
 
103
  ### 4. Gradio app ###
104