andrewzamp commited on
Commit
fd81c83
·
1 Parent(s): 7735f05

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -14
app.py CHANGED
@@ -16,11 +16,6 @@ taxo_df['species'] = taxo_df['species'].str.replace('_', ' ')
16
  # Available taxonomic levels
17
  taxonomic_levels = ['species', 'genus', 'family', 'order', 'class']
18
 
19
- # Function to map predicted class index to class name at the selected taxonomic level
20
- def get_class_name(predicted_class, taxonomic_level):
21
- unique_labels = sorted(taxo_df[taxonomic_level].unique())
22
- return unique_labels[predicted_class]
23
-
24
  # Function to aggregate predictions to a higher taxonomic level
25
  def aggregate_predictions(predicted_probs, taxonomic_level, class_names):
26
  unique_labels = sorted(taxo_df[taxonomic_level].unique())
@@ -47,8 +42,8 @@ def load_and_preprocess_image(image, target_size=(224, 224)):
47
  img_array = preprocess_input(img_array)
48
  return img_array
49
 
50
- # Function to make predictions
51
- def make_prediction(image, taxonomic_level):
52
  # Preprocess the image
53
  img_array = load_and_preprocess_image(image)
54
 
@@ -83,13 +78,12 @@ def make_prediction(image, taxonomic_level):
83
 
84
  if taxonomic_level == "species":
85
  # Display common names only at species level and make it italic
86
- common_name = taxo_df[taxo_df[taxonomic_level] == class_name]['common_name'].values[0]
87
  confidence_percentage = aggregated_predictions[0][i] * 100
88
  output_text += f"<div style='display: flex; justify-content: space-between;'>" \
89
  f"<span style='font-style: italic;'>{class_name}</span>&nbsp;(<span>{common_name}</span>)" \
90
  f"<span style='margin-left: auto;'>{confidence_percentage:.2f}%</span></div>"
91
  else:
92
- # No common names at higher taxonomic levels
93
  confidence_percentage = aggregated_predictions[0][i] * 100
94
  output_text += f"<div style='display: flex; justify-content: space-between;'>" \
95
  f"<span>{class_name}</span>" \
@@ -97,15 +91,72 @@ def make_prediction(image, taxonomic_level):
97
 
98
  return output_text
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  # Define the Gradio interface
101
  interface = gr.Interface(
102
- fn=make_prediction, # Function to be called for predictions
103
- inputs=[gr.Image(type="pil"), # Input type: Image (PIL format)
104
- gr.Dropdown(choices=taxonomic_levels, label="Taxonomic level", value="species")], # Use 'value' instead of 'default'
105
- outputs="html", # Output type: HTML for formatting
 
 
 
106
  title="Amazon arboreal species classification",
107
- description="Upload an image and select the taxonomic level to classify the species."
108
  )
109
 
 
 
 
 
 
 
 
 
 
 
110
  # Launch the Gradio interface
111
  interface.launch()
 
16
  # Available taxonomic levels
17
  taxonomic_levels = ['species', 'genus', 'family', 'order', 'class']
18
 
 
 
 
 
 
19
  # Function to aggregate predictions to a higher taxonomic level
20
  def aggregate_predictions(predicted_probs, taxonomic_level, class_names):
21
  unique_labels = sorted(taxo_df[taxonomic_level].unique())
 
42
  img_array = preprocess_input(img_array)
43
  return img_array
44
 
45
+ # Function to make predictions when taxonomic level is specified
46
+ def make_prediction_with_taxonomic_level(image, taxonomic_level):
47
  # Preprocess the image
48
  img_array = load_and_preprocess_image(image)
49
 
 
78
 
79
  if taxonomic_level == "species":
80
  # Display common names only at species level and make it italic
81
+ common_name = taxo_df[taxo_level == class_name]['common_name'].values[0]
82
  confidence_percentage = aggregated_predictions[0][i] * 100
83
  output_text += f"<div style='display: flex; justify-content: space-between;'>" \
84
  f"<span style='font-style: italic;'>{class_name}</span>&nbsp;(<span>{common_name}</span>)" \
85
  f"<span style='margin-left: auto;'>{confidence_percentage:.2f}%</span></div>"
86
  else:
 
87
  confidence_percentage = aggregated_predictions[0][i] * 100
88
  output_text += f"<div style='display: flex; justify-content: space-between;'>" \
89
  f"<span>{class_name}</span>" \
 
91
 
92
  return output_text
93
 
94
+ # Function to make predictions with automatic taxonomic resolution
95
+ def make_prediction_auto(image):
96
+ # Preprocess the image
97
+ img_array = load_and_preprocess_image(image)
98
+
99
+ # Get the class names from the 'species' column
100
+ class_names = sorted(taxo_df['species'].unique())
101
+
102
+ # Make a prediction
103
+ prediction = model.predict(img_array)
104
+
105
+ # Start with species-level predictions
106
+ taxonomic_level = 'species'
107
+ aggregated_predictions, aggregated_class_labels = aggregate_predictions(prediction, taxonomic_level, class_names)
108
+
109
+ # Check confidence and move to higher taxonomic levels if necessary
110
+ predicted_class_index = np.argmax(aggregated_predictions)
111
+ confidence = aggregated_predictions[0][predicted_class_index]
112
+
113
+ while confidence < 0.80 and taxonomic_levels.index(taxonomic_level) < len(taxonomic_levels) - 1:
114
+ # Move to the next higher taxonomic level
115
+ taxonomic_level = taxonomic_levels[taxonomic_levels.index(taxonomic_level) + 1]
116
+ aggregated_predictions, aggregated_class_labels = aggregate_predictions(prediction, taxonomic_level, class_names)
117
+ predicted_class_index = np.argmax(aggregated_predictions)
118
+ confidence = aggregated_predictions[0][predicted_class_index]
119
+
120
+ predicted_class_name = aggregated_class_labels[predicted_class_index]
121
+
122
+ if taxonomic_level == "species":
123
+ predicted_common_name = taxo_df[taxo_df[taxonomic_level] == predicted_class_name]['common_name'].values[0]
124
+ output_text = f"<h1 style='font-weight: bold;'><span style='font-style: italic;'>{predicted_class_name}</span> ({predicted_common_name})</h1>"
125
+ else:
126
+ output_text = f"<h1 style='font-weight: bold;'>{predicted_class_name}</h1>"
127
+
128
+ # Return the final prediction text
129
+ return output_text
130
+
131
+ # Gradio function to handle the flag logic
132
+ def make_prediction(image, choose_resolution, taxonomic_level):
133
+ if choose_resolution == "Yes, I want to specify the taxonomic level":
134
+ return make_prediction_with_taxonomic_level(image, taxonomic_level)
135
+ else:
136
+ return make_prediction_auto(image)
137
+
138
  # Define the Gradio interface
139
  interface = gr.Interface(
140
+ fn=make_prediction,
141
+ inputs=[
142
+ gr.Image(type="pil"),
143
+ gr.Radio(choices=["Yes, I want to specify the taxonomic level", "No, I will let the model decide"],
144
+ label="Do you want to choose the taxonomic resolution for predictions?", value="No, I will let the model decide"),
145
+ gr.Dropdown(choices=taxonomic_levels, label="Taxonomic level", value="species", interactive=True)],
146
+ outputs="html",
147
  title="Amazon arboreal species classification",
148
+ description="Upload an image and select the taxonomic level or let the model decide the resolution."
149
  )
150
 
151
+ # Add custom logic to disable the "Taxonomic level" dropdown when "No, I will let the model decide" is selected
152
+ def update_taxonomic_level_interface(choose_resolution):
153
+ if choose_resolution == "No, I will let the model decide":
154
+ return gr.Dropdown.update(interactive=False)
155
+ else:
156
+ return gr.Dropdown.update(interactive=True)
157
+
158
+ # Set up dynamic behavior for the interface
159
+ interface.update(inputs=["Do you want to choose the taxonomic resolution for predictions?"], fn=update_taxonomic_level_interface)
160
+
161
  # Launch the Gradio interface
162
  interface.launch()