Spaces:
Sleeping
Sleeping
File size: 7,016 Bytes
9f198ef d33d3aa 9f198ef ad3d1fe c80713a 68bd012 c80713a 68bd012 c80713a e014372 22a6037 e493bfa 22a6037 c80713a 098b334 c80713a e014372 c80713a e014372 c80713a d33d3aa c80713a 9f198ef 1364165 e93a56e 96a0181 e93a56e 68bd012 e93a56e 1364165 d33d3aa 1364165 f147a22 1364165 9f198ef 50d52be f147a22 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
# Import the libraries
import numpy as np
import pandas as pd
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.applications.convnext import preprocess_input
import gradio as gr
# Load the saved Keras model once, at import time
model = load_model('models/ConvNeXtBase_80_tresh_spp.tf')

# Read the taxonomy table and replace underscores in species names with spaces
taxo_df = pd.read_csv('taxonomy/taxonomy_mapping.csv').assign(
    species=lambda table: table['species'].str.replace('_', ' ')
)

# Taxonomy columns a prediction can be reported at; make_prediction walks
# this list from left to right when confidence is too low
taxonomic_levels = ['species', 'genus', 'family', 'order', 'class']
# Map a predicted class index to its label at the selected taxonomic level
def get_class_name(predicted_class, taxonomic_level):
    """Return the label found at position `predicted_class` among the sorted
    unique values of the `taxonomic_level` column of the taxonomy table."""
    labels = list(taxo_df[taxonomic_level].unique())
    labels.sort()
    return labels[predicted_class]
# Function to aggregate species-level predictions to a given taxonomic level
def aggregate_predictions(predicted_probs, taxonomic_level, class_names, taxonomy=None):
    """Sum species probabilities into the labels of `taxonomic_level`.

    Args:
        predicted_probs: 2-D array indexed as [sample, species]; column i holds
            the probability assigned to `class_names[i]`.
        taxonomic_level: Name of the taxonomy column to aggregate to.
        class_names: Species names, in the column order of `predicted_probs`.
        taxonomy: Optional DataFrame with a 'species' column and a
            `taxonomic_level` column. Defaults to the module-level `taxo_df`.

    Returns:
        A tuple `(aggregated, labels)` where `labels` is the sorted list of
        unique values in the `taxonomic_level` column and `aggregated` has one
        row per sample and one column per label.

    Raises:
        ValueError: If the taxonomy lists a species absent from `class_names`.
    """
    if taxonomy is None:
        taxonomy = taxo_df

    unique_labels = sorted(taxonomy[taxonomic_level].unique())
    aggregated_predictions = np.zeros((predicted_probs.shape[0], len(unique_labels)))

    # Position lookups built once instead of calling list.index() per row.
    # setdefault keeps the first position, as list.index() would.
    species_position = {}
    for position, name in enumerate(class_names):
        species_position.setdefault(name, position)
    label_position = {label: position for position, label in enumerate(unique_labels)}

    # dict.fromkeys drops repeated (species, label) pairs while keeping order,
    # so a species listed on several taxonomy rows is only counted once.
    pairs = dict.fromkeys(zip(taxonomy['species'], taxonomy[taxonomic_level]))
    for species, higher_level in pairs:
        if species not in species_position:
            raise ValueError(f"{species!r} is not in class_names")
        column = label_position[higher_level]
        aggregated_predictions[:, column] += predicted_probs[:, species_position[species]]

    return aggregated_predictions, unique_labels
# Function to load and preprocess the image
def load_and_preprocess_image(image, target_size=(224, 224)):
    """Turn a PIL image into a single-image batch ready for the model.

    Args:
        image: PIL image to classify.
        target_size: (width, height) the image is resized to.

    Returns:
        The array produced by `preprocess_input`, with a leading batch axis
        of length 1.
    """
    # Force three colour channels so greyscale, palette or RGBA images give
    # the same array layout as RGB ones
    rgb_image = image.convert('RGB')
    # Resize the image and convert it to an array
    img_array = img_to_array(rgb_image.resize(target_size))
    # Add the batch axis expected by model.predict
    img_array = np.expand_dims(img_array, axis=0)
    # Apply the ConvNeXt preprocessing function
    return preprocess_input(img_array)
# Look up the common name recorded for a species in the taxonomy table
def _common_name_for(species_name):
    """Return the first 'common_name' value whose 'species' equals `species_name`."""
    return taxo_df.loc[taxo_df['species'] == species_name, 'common_name'].values[0]


# Function to make predictions
def make_prediction(image, taxonomic_decision, taxonomic_level):
    """Classify an image and return the result as an HTML snippet.

    Starting at the chosen taxonomic level, species probabilities are
    aggregated to that level; if the best label scores below the confidence
    threshold, the next level in `taxonomic_levels` is tried.

    Args:
        image: PIL image to classify, or None.
        taxonomic_decision: Radio choice. "No, I will let the model decide"
            starts at the species level; any other value starts at
            `taxonomic_level`.
        taxonomic_level: Entry of `taxonomic_levels` to start from when the
            user specifies the level.

    Returns:
        HTML with the top label and up to five ranked labels with their
        confidence, "Unknown animal" if no level reaches the threshold, or
        "No image provided" if `image` is None.
    """
    # Guard before any model work: the image may be None (e.g. nothing uploaded)
    if image is None:
        return "<h1 style='font-weight: bold;'>No image provided</h1>"

    # Minimum confidence the best label must reach to be reported
    confidence_threshold = 0.80

    img_array = load_and_preprocess_image(image)
    # Species names in sorted order, used as the column order of the prediction
    class_names = sorted(taxo_df['species'].unique())
    prediction = model.predict(img_array)

    # Choose the level to start from
    if taxonomic_decision == "No, I will let the model decide":
        start_index = 0
    else:
        start_index = taxonomic_levels.index(taxonomic_level)

    # Move through the remaining levels until one is confident enough
    for level in taxonomic_levels[start_index:]:
        aggregated_predictions, aggregated_class_labels = aggregate_predictions(
            prediction, level, class_names
        )
        # Only one image is predicted, so work on the first row
        scores = aggregated_predictions[0]
        best_index = int(np.argmax(scores))
        if scores[best_index] >= confidence_threshold:
            break
    else:
        # No level reached the confidence threshold
        return "<h1 style='font-weight: bold;'>Unknown animal</h1>"

    # Common names are shown only at the species level
    at_species_level = level == "species"
    best_name = aggregated_class_labels[best_index]
    if at_species_level:
        headline = (
            f"<h1 style='font-weight: bold;'><span style='font-style: italic;'>{best_name}</span>"
            f" ({_common_name_for(best_name)})</h1>"
        )
    else:
        headline = f"<h1 style='font-weight: bold;'>{best_name}</h1>"

    parts = [
        headline,
        "<h4 style='font-weight: bold; font-size: 1.2em;'>Top-5 predictions:</h4>",
    ]
    # Indices of the five highest scores, best first
    for i in np.argsort(scores)[-5:][::-1]:
        class_name = aggregated_class_labels[i]
        confidence_percentage = scores[i] * 100
        if at_species_level:
            label_html = (
                f"<span style='font-style: italic;'>{class_name}</span>"
                f" (<span>{_common_name_for(class_name)}</span>)"
            )
        else:
            label_html = f"<span>{class_name}</span>"
        parts.append(
            "<div style='display: flex; justify-content: space-between;'>"
            f"{label_html}"
            f"<span style='margin-left: auto;'>{confidence_percentage:.2f}%</span></div>"
        )
    return "".join(parts)
# Input components, listed in the order make_prediction receives them
prediction_inputs = [
    # Image delivered to the function as a PIL object
    gr.Image(type="pil", label="Upload Image"),
    # Whether the user wants to choose the starting taxonomic level
    gr.Radio(
        choices=["Yes, I want to specify the taxonomic level", "No, I will let the model decide"],
        label="Do you want to specify the taxonomic resolution for predictions? If you select 'No', the 'Taxonomic level' drop-down menu will be bypassed.",
        value="No, I will let the model decide",
    ),
    # Starting level, used only when the user chose to specify it
    gr.Dropdown(choices=taxonomic_levels, label="Taxonomic level:", value="species"),
]

# Define the Gradio interface; the function returns HTML
interface = gr.Interface(
    fn=make_prediction,
    inputs=prediction_inputs,
    outputs="html",
    title="Amazon arboreal species classification",
    description="Upload an image and select the taxonomic level (optional) to classify the species.",
)

# NOTE(review): these (username, password) pairs are hard-coded in the source
# file - consider loading them from configuration kept outside the code.
authorized_users = [
    ("Andrea Zampetti", "andreazampetti"),
    ("Luca Santini", "lucasantini"),
    ("Ana Benítez", "anabenítez"),
]

# Launch the Gradio interface with authentication for the listed users
interface.launch(auth=authorized_users)