andrewzamp commited on
Commit
f147a22
1 Parent(s): ec9383f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -77
app.py CHANGED
@@ -16,6 +16,11 @@ taxo_df['species'] = taxo_df['species'].str.replace('_', ' ')
16
  # Available taxonomic levels
17
  taxonomic_levels = ['species', 'genus', 'family', 'order', 'class']
18
 
 
 
 
 
 
19
  # Function to aggregate predictions to a higher taxonomic level
20
  def aggregate_predictions(predicted_probs, taxonomic_level, class_names):
21
  unique_labels = sorted(taxo_df[taxonomic_level].unique())
@@ -42,8 +47,8 @@ def load_and_preprocess_image(image, target_size=(224, 224)):
42
  img_array = preprocess_input(img_array)
43
  return img_array
44
 
45
- # Function to make predictions when taxonomic level is specified
46
- def make_prediction_with_taxonomic_level(image, taxonomic_level):
47
  # Preprocess the image
48
  img_array = load_and_preprocess_image(image)
49
 
@@ -78,12 +83,13 @@ def make_prediction_with_taxonomic_level(image, taxonomic_level):
78
 
79
  if taxonomic_level == "species":
80
  # Display common names only at species level and make it italic
81
- common_name = taxo_df[taxo_level == class_name]['common_name'].values[0]
82
  confidence_percentage = aggregated_predictions[0][i] * 100
83
  output_text += f"<div style='display: flex; justify-content: space-between;'>" \
84
  f"<span style='font-style: italic;'>{class_name}</span>&nbsp;(<span>{common_name}</span>)" \
85
  f"<span style='margin-left: auto;'>{confidence_percentage:.2f}%</span></div>"
86
  else:
 
87
  confidence_percentage = aggregated_predictions[0][i] * 100
88
  output_text += f"<div style='display: flex; justify-content: space-between;'>" \
89
  f"<span>{class_name}</span>" \
@@ -91,79 +97,33 @@ def make_prediction_with_taxonomic_level(image, taxonomic_level):
91
 
92
  return output_text
93
 
94
- # Function to make predictions with automatic taxonomic resolution
95
- def make_prediction_auto(image):
96
- # Preprocess the image
97
- img_array = load_and_preprocess_image(image)
98
-
99
- # Get the class names from the 'species' column
100
- class_names = sorted(taxo_df['species'].unique())
101
-
102
- # Make a prediction
103
- prediction = model.predict(img_array)
104
-
105
- # Start with species-level predictions
106
- taxonomic_level = 'species'
107
- aggregated_predictions, aggregated_class_labels = aggregate_predictions(prediction, taxonomic_level, class_names)
108
-
109
- # Check confidence and move to higher taxonomic levels if necessary
110
- predicted_class_index = np.argmax(aggregated_predictions)
111
- confidence = aggregated_predictions[0][predicted_class_index]
112
-
113
- while confidence < 0.80 and taxonomic_levels.index(taxonomic_level) < len(taxonomic_levels) - 1:
114
- # Move to the next higher taxonomic level
115
- taxonomic_level = taxonomic_levels[taxonomic_levels.index(taxonomic_level) + 1]
116
- aggregated_predictions, aggregated_class_labels = aggregate_predictions(prediction, taxonomic_level, class_names)
117
- predicted_class_index = np.argmax(aggregated_predictions)
118
- confidence = aggregated_predictions[0][predicted_class_index]
119
-
120
- predicted_class_name = aggregated_class_labels[predicted_class_index]
121
-
122
- if taxonomic_level == "species":
123
- predicted_common_name = taxo_df[taxo_df[taxonomic_level] == predicted_class_name]['common_name'].values[0]
124
- output_text = f"<h1 style='font-weight: bold;'><span style='font-style: italic;'>{predicted_class_name}</span> ({predicted_common_name})</h1>"
125
- else:
126
- output_text = f"<h1 style='font-weight: bold;'>{predicted_class_name}</h1>"
127
-
128
- # Return the final prediction text
129
- return output_text
130
-
131
- # Gradio function to handle the flag logic
132
- def make_prediction(image, choose_resolution, taxonomic_level):
133
- if choose_resolution == "Yes, I want to specify the taxonomic level":
134
- return make_prediction_with_taxonomic_level(image, taxonomic_level)
135
- else:
136
- return make_prediction_auto(image)
137
-
138
- # Function to dynamically disable/enable the taxonomic level dropdown
139
- def update_taxonomic_level_interface(choose_resolution):
140
- if choose_resolution == "No, I will let the model decide":
141
- return gr.Dropdown.update(interactive=False)
142
- else:
143
- return gr.Dropdown.update(interactive=True)
144
 
145
  # Define the Gradio interface
146
- with gr.Blocks() as interface:
147
- with gr.Row():
148
- image_input = gr.Image(type="pil")
149
- choose_resolution = gr.Radio(
150
- choices=["Yes, I want to specify the taxonomic level", "No, I will let the model decide"],
151
- label="Do you want to choose the taxonomic resolution for predictions?", value="No, I will let the model decide"
152
- )
153
- taxonomic_level = gr.Dropdown(
154
- choices=taxonomic_levels, label="Taxonomic level", value="species", interactive=True
155
- )
156
-
157
- result_output = gr.HTML()
158
-
159
- # Add a button to submit the image for prediction
160
- predict_button = gr.Button("Predict")
161
-
162
- # Bind the dynamic dropdown control
163
- choose_resolution.change(fn=update_taxonomic_level_interface, inputs=choose_resolution, outputs=taxonomic_level)
164
-
165
- # Bind the button click to the prediction function
166
- predict_button.click(fn=make_prediction, inputs=[image_input, choose_resolution, taxonomic_level], outputs=result_output)
167
-
168
- # Launch the Gradio interface
169
- interface.launch()
 
 
16
  # Available taxonomic levels
17
  taxonomic_levels = ['species', 'genus', 'family', 'order', 'class']
18
 
19
+ # Function to map predicted class index to class name at the selected taxonomic level
20
+ def get_class_name(predicted_class, taxonomic_level):
21
+ unique_labels = sorted(taxo_df[taxonomic_level].unique())
22
+ return unique_labels[predicted_class]
23
+
24
  # Function to aggregate predictions to a higher taxonomic level
25
  def aggregate_predictions(predicted_probs, taxonomic_level, class_names):
26
  unique_labels = sorted(taxo_df[taxonomic_level].unique())
 
47
  img_array = preprocess_input(img_array)
48
  return img_array
49
 
50
+ # Function to make predictions
51
+ def make_prediction(image, taxonomic_level):
52
  # Preprocess the image
53
  img_array = load_and_preprocess_image(image)
54
 
 
83
 
84
  if taxonomic_level == "species":
85
  # Display common names only at species level and make it italic
86
+ common_name = taxo_df[taxo_df[taxonomic_level] == class_name]['common_name'].values[0]
87
  confidence_percentage = aggregated_predictions[0][i] * 100
88
  output_text += f"<div style='display: flex; justify-content: space-between;'>" \
89
  f"<span style='font-style: italic;'>{class_name}</span>&nbsp;(<span>{common_name}</span>)" \
90
  f"<span style='margin-left: auto;'>{confidence_percentage:.2f}%</span></div>"
91
  else:
92
+ # No common names at higher taxonomic levels
93
  confidence_percentage = aggregated_predictions[0][i] * 100
94
  output_text += f"<div style='display: flex; justify-content: space-between;'>" \
95
  f"<span>{class_name}</span>" \
 
97
 
98
  return output_text
99
 
100
+ # Define a function to update the welcome message based on the logged-in user
101
+ def update_message(request: gr.Request):
102
+ return f"Welcome to the demo, Dr. {request.username}!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
  # Define the Gradio interface
105
+ with gr.Blocks() as demo:
106
+ # Add a Markdown component for displaying the welcome message
107
+ welcome_message = gr.Markdown()
108
+
109
+ # Load the update_message function to display the welcome message
110
+ demo.load(update_message, None, welcome_message)
111
+
112
+ # Define the main interface for predictions
113
+ interface = gr.Interface(
114
+ fn=make_prediction, # Function to be called for predictions
115
+ inputs=[gr.Image(type="pil"), # Input type: Image (PIL format)
116
+ gr.Dropdown(choices=taxonomic_levels, label="Taxonomic level", value="species")], # Use 'value' instead of 'default'
117
+ outputs="html", # Output type: HTML for formatting
118
+ title="Amazon arboreal species classification",
119
+ description="Upload an image and select the taxonomic level to classify the species."
120
+ )
121
+
122
+ # Add the prediction interface to the main demo
123
+ interface.render()
124
+
125
+ # Launch the Gradio interface with authentication for the specified users
126
+ demo.launch(auth=[
127
+ ("Luca Santini", "lucasantini"),
128
+ ("Ana Ben铆tez L贸pez", "anaben铆tezl贸pez")
129
+ ])