andrewzamp commited on
Commit
8500d79
·
1 Parent(s): a8713e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -16
app.py CHANGED
@@ -71,36 +71,52 @@ def make_prediction(image, taxonomic_decision, taxonomic_level):
71
  # Aggregate predictions based on the current taxonomic level
72
  aggregated_predictions, aggregated_class_labels = aggregate_predictions(prediction, taxonomic_levels[current_level_index], class_names)
73
 
74
- # Check if the confidence of the top prediction meets the threshold
75
- top_prediction_index = np.argmax(aggregated_predictions)
76
- top_prediction_confidence = aggregated_predictions[0][top_prediction_index]
77
-
78
- # If the user specified a taxonomic level, do not automatically promote to a higher level
79
  if taxonomic_decision == "Yes, I want to specify the taxonomic level":
80
- if current_level_index == 0 and top_prediction_confidence < 0.80:
81
- return "<h1 style='font-weight: bold;'>Confidence too low for specified taxonomic level</h1>"
 
 
 
 
 
 
 
 
 
82
 
83
- # Proceed to the next level if the confidence threshold is not met and the user allows it
 
 
 
 
 
 
 
 
 
 
84
  while current_level_index < len(taxonomic_levels):
85
- # Check if confidence is above the threshold at the current level
 
 
 
 
 
 
86
  if top_prediction_confidence >= 0.80:
87
  break # Confidence threshold met, exit loop
88
 
89
  current_level_index += 1 # Move to the next taxonomic level
90
- if current_level_index < len(taxonomic_levels):
91
- # Aggregate predictions for the next level
92
- aggregated_predictions, aggregated_class_labels = aggregate_predictions(prediction, taxonomic_levels[current_level_index], class_names)
93
- top_prediction_index = np.argmax(aggregated_predictions)
94
- top_prediction_confidence = aggregated_predictions[0][top_prediction_index]
95
 
96
  # Check if a valid prediction was made
97
  if current_level_index == len(taxonomic_levels):
98
  return "<h1 style='font-weight: bold;'>Unknown animal</h1>" # No valid predictions met the confidence criteria
99
-
100
  # Get the predicted class name for the top prediction
101
  predicted_class_index = np.argmax(aggregated_predictions)
102
  predicted_class_name = aggregated_class_labels[predicted_class_index]
103
-
104
  # Check if common name should be displayed (only at species level)
105
  if taxonomic_levels[current_level_index] == "species":
106
  predicted_common_name = taxo_df[taxo_df[taxonomic_levels[current_level_index]] == predicted_class_name]['common_name'].values[0]
 
71
  # Aggregate predictions based on the current taxonomic level
72
  aggregated_predictions, aggregated_class_labels = aggregate_predictions(prediction, taxonomic_levels[current_level_index], class_names)
73
 
74
+ # If the user specified a taxonomic level, simply get the highest prediction at that level
 
 
 
 
75
  if taxonomic_decision == "Yes, I want to specify the taxonomic level":
76
+ # Get the predicted class index for the current level
77
+ predicted_class_index = np.argmax(aggregated_predictions)
78
+ predicted_class_name = aggregated_class_labels[predicted_class_index]
79
+
80
+ # Construct the output message without considering confidence
81
+ output_text = f"<h1 style='font-weight: bold;'>{predicted_class_name}</h1>"
82
+
83
+ # Add the top 5 predictions
84
+ output_text += "<h4 style='font-weight: bold; font-size: 1.2em;'>Top-5 predictions:</h4>"
85
+
86
+ top_indices = np.argsort(aggregated_predictions[0])[-5:][::-1] # Get top 5 predictions
87
 
88
+ for i in top_indices:
89
+ class_name = aggregated_class_labels[i]
90
+ confidence_percentage = aggregated_predictions[0][i] * 100
91
+ output_text += f"<div style='display: flex; justify-content: space-between;'>" \
92
+ f"<span>{class_name}</span>" \
93
+ f"<span style='margin-left: auto;'>{confidence_percentage:.2f}%</span></div>"
94
+
95
+ return output_text
96
+
97
+ # Confidence checking for the automatic model decision
98
+ # Loop through taxonomic levels if the user lets the model decide
99
  while current_level_index < len(taxonomic_levels):
100
+ # Aggregate predictions for the next level
101
+ aggregated_predictions, aggregated_class_labels = aggregate_predictions(prediction, taxonomic_levels[current_level_index], class_names)
102
+
103
+ # Check if the confidence of the top prediction meets the threshold
104
+ top_prediction_index = np.argmax(aggregated_predictions)
105
+ top_prediction_confidence = aggregated_predictions[0][top_prediction_index]
106
+
107
  if top_prediction_confidence >= 0.80:
108
  break # Confidence threshold met, exit loop
109
 
110
  current_level_index += 1 # Move to the next taxonomic level
 
 
 
 
 
111
 
112
  # Check if a valid prediction was made
113
  if current_level_index == len(taxonomic_levels):
114
  return "<h1 style='font-weight: bold;'>Unknown animal</h1>" # No valid predictions met the confidence criteria
115
+
116
  # Get the predicted class name for the top prediction
117
  predicted_class_index = np.argmax(aggregated_predictions)
118
  predicted_class_name = aggregated_class_labels[predicted_class_index]
119
+
120
  # Check if common name should be displayed (only at species level)
121
  if taxonomic_levels[current_level_index] == "species":
122
  predicted_common_name = taxo_df[taxo_df[taxonomic_levels[current_level_index]] == predicted_class_name]['common_name'].values[0]