andrewzamp commited on
Commit
9f8bab0
·
1 Parent(s): 4fe087a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -13
app.py CHANGED
@@ -51,23 +51,45 @@ def load_and_preprocess_image(image, target_size=(224, 224)):
51
  def make_prediction(image, taxonomic_level):
52
  # Preprocess the image
53
  img_array = load_and_preprocess_image(image)
54
-
55
  # Get the class names from the 'species' column
56
- class_names = sorted(taxo_df['species'].unique()) # Add this line to define class_names
57
-
58
  # Make a prediction
59
  prediction = model.predict(img_array)
60
-
61
- # Aggregate predictions based on the selected taxonomic level
62
- aggregated_predictions, aggregated_class_labels = aggregate_predictions(prediction, taxonomic_level, class_names)
63
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  # Get the top 5 predictions
65
  top_indices = np.argsort(aggregated_predictions[0])[-5:][::-1]
66
-
67
  # Get predicted class for the top prediction
68
  predicted_class_index = np.argmax(aggregated_predictions)
69
  predicted_class_name = aggregated_class_labels[predicted_class_index]
70
-
71
  # Check if common name should be displayed (only at species level)
72
  if taxonomic_level == "species":
73
  predicted_common_name = taxo_df[taxo_df[taxonomic_level] == predicted_class_name]['common_name'].values[0]
@@ -77,10 +99,10 @@ def make_prediction(image, taxonomic_level):
77
 
78
  # Add the top 5 predictions
79
  output_text += "<h4 style='font-weight: bold; font-size: 1.2em;'>Top 5 Predictions:</h4>"
80
-
81
  for i in top_indices:
82
  class_name = aggregated_class_labels[i]
83
-
84
  if taxonomic_level == "species":
85
  # Display common names only at species level and make it italic
86
  common_name = taxo_df[taxo_df[taxonomic_level] == class_name]['common_name'].values[0]
@@ -94,7 +116,7 @@ def make_prediction(image, taxonomic_level):
94
  output_text += f"<div style='display: flex; justify-content: space-between;'>" \
95
  f"<span>{class_name}</span>" \
96
  f"<span style='margin-left: auto;'>{confidence_percentage:.2f}%</span></div>"
97
-
98
  return output_text
99
 
100
  # Define the Gradio interface
@@ -103,7 +125,7 @@ interface = gr.Interface(
103
  inputs=[
104
  gr.Image(type="pil", label="Upload Image"), # Input type: Image (PIL format)
105
  gr.Radio(choices=["Yes, I want to specify the taxonomic level", "No, I will let the model decide"],
106
- label="Do you want to specify the taxonomic resolution for predictions?<br><br>NOTE: if you select 'No', the next drop-down menu will be bypassed",
107
  value="No, I will let the model decide"), # Radio button for taxonomic resolution choice
108
  gr.Dropdown(choices=taxonomic_levels, label="Taxonomic level", value="species") # Dropdown for taxonomic level
109
  ],
 
51
  def make_prediction(image, taxonomic_level):
52
  # Preprocess the image
53
  img_array = load_and_preprocess_image(image)
54
+
55
  # Get the class names from the 'species' column
56
+ class_names = sorted(taxo_df['species'].unique())
57
+
58
  # Make a prediction
59
  prediction = model.predict(img_array)
60
+
61
+ # Initialize the current taxonomic level index based on the user selection
62
+ current_taxonomic_level_index = taxonomic_levels.index(taxonomic_level)
63
+
64
+ # If the user chose to let the model decide, check the confidence levels
65
+ if taxonomic_level == "No, I will let the model decide":
66
+ aggregated_predictions = prediction
67
+ while current_taxonomic_level_index < len(taxonomic_levels):
68
+ # Aggregate predictions based on the current taxonomic level
69
+ aggregated_predictions, aggregated_class_labels = aggregate_predictions(
70
+ aggregated_predictions, taxonomic_levels[current_taxonomic_level_index], class_names
71
+ )
72
+
73
+ # Check if the max confidence in the aggregated predictions is >= 0.80
74
+ if np.max(aggregated_predictions) >= 0.80:
75
+ break
76
+
77
+ # Move to the next higher taxonomic level
78
+ current_taxonomic_level_index += 1
79
+
80
+ # Update the taxonomic level for output
81
+ taxonomic_level = taxonomic_levels[current_taxonomic_level_index]
82
+ else:
83
+ # Aggregate predictions based on the selected taxonomic level
84
+ aggregated_predictions, aggregated_class_labels = aggregate_predictions(prediction, taxonomic_level, class_names)
85
+
86
  # Get the top 5 predictions
87
  top_indices = np.argsort(aggregated_predictions[0])[-5:][::-1]
88
+
89
  # Get predicted class for the top prediction
90
  predicted_class_index = np.argmax(aggregated_predictions)
91
  predicted_class_name = aggregated_class_labels[predicted_class_index]
92
+
93
  # Check if common name should be displayed (only at species level)
94
  if taxonomic_level == "species":
95
  predicted_common_name = taxo_df[taxo_df[taxonomic_level] == predicted_class_name]['common_name'].values[0]
 
99
 
100
  # Add the top 5 predictions
101
  output_text += "<h4 style='font-weight: bold; font-size: 1.2em;'>Top 5 Predictions:</h4>"
102
+
103
  for i in top_indices:
104
  class_name = aggregated_class_labels[i]
105
+
106
  if taxonomic_level == "species":
107
  # Display common names only at species level and make it italic
108
  common_name = taxo_df[taxo_df[taxonomic_level] == class_name]['common_name'].values[0]
 
116
  output_text += f"<div style='display: flex; justify-content: space-between;'>" \
117
  f"<span>{class_name}</span>" \
118
  f"<span style='margin-left: auto;'>{confidence_percentage:.2f}%</span></div>"
119
+
120
  return output_text
121
 
122
  # Define the Gradio interface
 
125
  inputs=[
126
  gr.Image(type="pil", label="Upload Image"), # Input type: Image (PIL format)
127
  gr.Radio(choices=["Yes, I want to specify the taxonomic level", "No, I will let the model decide"],
128
+ label="Do you want to specify the taxonomic resolution for predictions? If you select 'No', the next drop-down menu will be bypassed.",
129
  value="No, I will let the model decide"), # Radio button for taxonomic resolution choice
130
  gr.Dropdown(choices=taxonomic_levels, label="Taxonomic level", value="species") # Dropdown for taxonomic level
131
  ],