Spaces:
Sleeping
Sleeping
Commit
·
d33d3aa
1
Parent(s):
9669b45
Update app.py
Browse files
app.py
CHANGED
@@ -30,19 +30,10 @@ def aggregate_predictions(predicted_probs, taxonomic_level, class_names):
|
|
30 |
species = row['species']
|
31 |
higher_level = row[taxonomic_level]
|
32 |
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
higher_level_index = unique_labels.index(higher_level)
|
38 |
-
|
39 |
-
# Only update if indices are valid
|
40 |
-
if species_index < predicted_probs.shape[1] and higher_level_index < aggregated_predictions.shape[1]:
|
41 |
-
if predicted_probs[:, species_index].max() >= 0.80: # Check confidence level
|
42 |
-
aggregated_predictions[:, higher_level_index] += predicted_probs[:, species_index]
|
43 |
-
else:
|
44 |
-
# Stop aggregation at the current level if confidence is below 0.80
|
45 |
-
break
|
46 |
|
47 |
return aggregated_predictions, unique_labels
|
48 |
|
@@ -57,51 +48,26 @@ def load_and_preprocess_image(image, target_size=(224, 224)):
|
|
57 |
return img_array
|
58 |
|
59 |
# Function to make predictions
|
60 |
-
def make_prediction(image,
|
61 |
# Preprocess the image
|
62 |
img_array = load_and_preprocess_image(image)
|
63 |
-
|
64 |
# Get the class names from the 'species' column
|
65 |
-
class_names = sorted(taxo_df['species'].unique())
|
66 |
-
|
67 |
# Make a prediction
|
68 |
prediction = model.predict(img_array)
|
69 |
-
|
70 |
-
#
|
71 |
-
|
72 |
-
|
73 |
-
# If the user chose to let the model decide, check the confidence levels
|
74 |
-
if taxonomic_resolution_choice == "No, I will let the model decide":
|
75 |
-
aggregated_predictions = prediction
|
76 |
-
while current_taxonomic_level_index < len(taxonomic_levels):
|
77 |
-
# Aggregate predictions based on the current taxonomic level
|
78 |
-
aggregated_predictions, aggregated_class_labels = aggregate_predictions(
|
79 |
-
aggregated_predictions, taxonomic_levels[current_taxonomic_level_index], class_names
|
80 |
-
)
|
81 |
-
|
82 |
-
# Check if the max confidence in the aggregated predictions is >= 0.80
|
83 |
-
if np.max(aggregated_predictions) >= 0.80:
|
84 |
-
break
|
85 |
-
|
86 |
-
# Move to the next higher taxonomic level
|
87 |
-
current_taxonomic_level_index += 1
|
88 |
-
|
89 |
-
# Ensure we don't go out of bounds
|
90 |
-
if current_taxonomic_level_index < len(taxonomic_levels):
|
91 |
-
taxonomic_level = taxonomic_levels[current_taxonomic_level_index]
|
92 |
-
else:
|
93 |
-
taxonomic_level = taxonomic_levels[-1] # fallback to the highest taxonomic level
|
94 |
-
else:
|
95 |
-
# Aggregate predictions based on the selected taxonomic level
|
96 |
-
aggregated_predictions, aggregated_class_labels = aggregate_predictions(prediction, taxonomic_level, class_names)
|
97 |
-
|
98 |
# Get the top 5 predictions
|
99 |
top_indices = np.argsort(aggregated_predictions[0])[-5:][::-1]
|
100 |
-
|
101 |
# Get predicted class for the top prediction
|
102 |
predicted_class_index = np.argmax(aggregated_predictions)
|
103 |
predicted_class_name = aggregated_class_labels[predicted_class_index]
|
104 |
-
|
105 |
# Check if common name should be displayed (only at species level)
|
106 |
if taxonomic_level == "species":
|
107 |
predicted_common_name = taxo_df[taxo_df[taxonomic_level] == predicted_class_name]['common_name'].values[0]
|
@@ -111,10 +77,10 @@ def make_prediction(image, taxonomic_resolution_choice, taxonomic_level):
|
|
111 |
|
112 |
# Add the top 5 predictions
|
113 |
output_text += "<h4 style='font-weight: bold; font-size: 1.2em;'>Top 5 Predictions:</h4>"
|
114 |
-
|
115 |
for i in top_indices:
|
116 |
class_name = aggregated_class_labels[i]
|
117 |
-
|
118 |
if taxonomic_level == "species":
|
119 |
# Display common names only at species level and make it italic
|
120 |
common_name = taxo_df[taxo_df[taxonomic_level] == class_name]['common_name'].values[0]
|
@@ -128,7 +94,7 @@ def make_prediction(image, taxonomic_resolution_choice, taxonomic_level):
|
|
128 |
output_text += f"<div style='display: flex; justify-content: space-between;'>" \
|
129 |
f"<span>{class_name}</span>" \
|
130 |
f"<span style='margin-left: auto;'>{confidence_percentage:.2f}%</span></div>"
|
131 |
-
|
132 |
return output_text
|
133 |
|
134 |
# Define the Gradio interface
|
@@ -143,7 +109,7 @@ interface = gr.Interface(
|
|
143 |
],
|
144 |
outputs="html", # Output type: HTML for formatting
|
145 |
title="Amazon arboreal species classification",
|
146 |
-
description="Upload an image and select the taxonomic level to classify the species."
|
147 |
)
|
148 |
|
149 |
# Launch the Gradio interface with authentication for the specified users
|
|
|
30 |
species = row['species']
|
31 |
higher_level = row[taxonomic_level]
|
32 |
|
33 |
+
species_index = class_names.index(species) # Index of the species in the prediction array
|
34 |
+
higher_level_index = unique_labels.index(higher_level)
|
35 |
+
|
36 |
+
aggregated_predictions[:, higher_level_index] += predicted_probs[:, species_index]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
return aggregated_predictions, unique_labels
|
39 |
|
|
|
48 |
return img_array
|
49 |
|
50 |
# Function to make predictions
|
51 |
+
def make_prediction(image, taxonomic_level):
|
52 |
# Preprocess the image
|
53 |
img_array = load_and_preprocess_image(image)
|
54 |
+
|
55 |
# Get the class names from the 'species' column
|
56 |
+
class_names = sorted(taxo_df['species'].unique()) # Add this line to define class_names
|
57 |
+
|
58 |
# Make a prediction
|
59 |
prediction = model.predict(img_array)
|
60 |
+
|
61 |
+
# Aggregate predictions based on the selected taxonomic level
|
62 |
+
aggregated_predictions, aggregated_class_labels = aggregate_predictions(prediction, taxonomic_level, class_names)
|
63 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
# Get the top 5 predictions
|
65 |
top_indices = np.argsort(aggregated_predictions[0])[-5:][::-1]
|
66 |
+
|
67 |
# Get predicted class for the top prediction
|
68 |
predicted_class_index = np.argmax(aggregated_predictions)
|
69 |
predicted_class_name = aggregated_class_labels[predicted_class_index]
|
70 |
+
|
71 |
# Check if common name should be displayed (only at species level)
|
72 |
if taxonomic_level == "species":
|
73 |
predicted_common_name = taxo_df[taxo_df[taxonomic_level] == predicted_class_name]['common_name'].values[0]
|
|
|
77 |
|
78 |
# Add the top 5 predictions
|
79 |
output_text += "<h4 style='font-weight: bold; font-size: 1.2em;'>Top 5 Predictions:</h4>"
|
80 |
+
|
81 |
for i in top_indices:
|
82 |
class_name = aggregated_class_labels[i]
|
83 |
+
|
84 |
if taxonomic_level == "species":
|
85 |
# Display common names only at species level and make it italic
|
86 |
common_name = taxo_df[taxo_df[taxonomic_level] == class_name]['common_name'].values[0]
|
|
|
94 |
output_text += f"<div style='display: flex; justify-content: space-between;'>" \
|
95 |
f"<span>{class_name}</span>" \
|
96 |
f"<span style='margin-left: auto;'>{confidence_percentage:.2f}%</span></div>"
|
97 |
+
|
98 |
return output_text
|
99 |
|
100 |
# Define the Gradio interface
|
|
|
109 |
],
|
110 |
outputs="html", # Output type: HTML for formatting
|
111 |
title="Amazon arboreal species classification",
|
112 |
+
description="Upload an image and select the taxonomic level (optional) to classify the species."
|
113 |
)
|
114 |
|
115 |
# Launch the Gradio interface with authentication for the specified users
|