Spaces:
Sleeping
Sleeping
# Import the libraries
import html

import gradio as gr
import numpy as np
import pandas as pd
from tensorflow.keras.applications.convnext import preprocess_input
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import img_to_array, load_img
# Trained classifier, loaded once when the module is imported
model = load_model('models/ConvNeXtBase_80_tresh_spp.tf')

# Taxonomy table used to roll species up to higher ranks. Underscores in
# species names are replaced with spaces for display and lookup.
taxo_df = pd.read_csv('taxonomy/taxonomy_mapping.csv').assign(
    species=lambda frame: frame['species'].str.replace('_', ' ')
)

# Ranks offered for classification, finest first
taxonomic_levels = ['species', 'genus', 'family', 'order', 'class']
# Function to map predicted class index to class name at the selected taxonomic level | |
def get_class_name(predicted_class, taxonomic_level): | |
unique_labels = sorted(taxo_df[taxonomic_level].unique()) | |
return unique_labels[predicted_class] | |
# Function to aggregate predictions to a higher taxonomic level
def aggregate_predictions(predicted_probs, taxonomic_level, class_names, taxonomy=None):
    """Sum species-level probabilities into the groups of *taxonomic_level*.

    Args:
        predicted_probs: array of shape (batch, n_species); column ``j`` holds
            the probability of ``class_names[j]``.
        taxonomic_level: taxonomy column to aggregate to (e.g. "genus"). With
            "species" the result is the input re-ordered to the sorted species.
        class_names: species names in the column order of *predicted_probs*.
        taxonomy: optional DataFrame with a "species" column and a
            *taxonomic_level* column; defaults to the module-level ``taxo_df``.

    Returns:
        ``(aggregated, labels)``: ``aggregated`` has shape
        (batch, len(labels)); ``labels`` is the sorted list of unique values
        in the *taxonomic_level* column.

    Raises:
        ValueError: if a species in the taxonomy is not in *class_names*.
    """
    if taxonomy is None:
        taxonomy = taxo_df
    predicted_probs = np.asarray(predicted_probs)

    unique_labels = sorted(taxonomy[taxonomic_level].unique())
    # Dict lookups instead of list.index keep this linear in the table size.
    label_pos = {label: i for i, label in enumerate(unique_labels)}
    species_pos = {name: i for i, name in enumerate(class_names)}

    aggregated_predictions = np.zeros((predicted_probs.shape[0], len(unique_labels)))
    # Visit each (species, group) pair once, in table order: duplicate rows
    # would otherwise add the same species' probability to its group twice.
    pairs = dict.fromkeys(zip(taxonomy["species"], taxonomy[taxonomic_level]))
    for species, group in pairs:
        try:
            column = species_pos[species]
        except KeyError:
            raise ValueError(f"{species!r} is not in class_names") from None
        aggregated_predictions[:, label_pos[group]] += predicted_probs[:, column]
    return aggregated_predictions, unique_labels
# Function to load and preprocess the image
def load_and_preprocess_image(image, target_size=(224, 224)):
    """Turn a PIL image into a preprocessed batch of one for the model.

    Args:
        image: PIL image (the Gradio input is configured with ``type="pil"``).
        target_size: ``(width, height)`` passed to ``PIL.Image.resize``.

    Returns:
        numpy array with a leading batch axis of size 1, after
        ``convnext.preprocess_input``. With Keras' default channels_last
        format this is presumably (1, height, width, 3) — confirm if the
        Keras data format is configured differently.
    """
    # Normalise to 3-channel RGB: RGBA, greyscale ("L") and palette ("P")
    # images would otherwise yield 4 channels, 1 channel, or raw palette
    # indices, none of which match an RGB-trained ConvNeXt input.
    rgb_image = image.convert("RGB")
    img_array = img_to_array(rgb_image.resize(target_size))
    # Add the batch dimension the model expects
    img_array = np.expand_dims(img_array, axis=0)
    return preprocess_input(img_array)
# Function to make predictions
def _escape_html(text):
    """HTML-escape a taxonomy value so characters like '&' or '<' render literally."""
    return html.escape(str(text), quote=False)


def _species_common_names():
    """Map each species to the common name on its first taxonomy row, skipping missing (NaN) names."""
    first_rows = taxo_df.drop_duplicates(subset="species")
    return {
        species: common
        for species, common in zip(first_rows["species"], first_rows["common_name"])
        if pd.notna(common)
    }


def make_prediction(image, taxonomic_level, top_k=5):
    """Classify *image* and render the result as an HTML fragment.

    Species-level model probabilities are summed up to *taxonomic_level* via
    ``aggregate_predictions``. The best match is shown as a headline, followed
    by the *top_k* highest-scoring labels with their percentages. At species
    level, names are italicised and followed by their common name when the
    taxonomy has one.

    Args:
        image: PIL image from the Gradio input, or ``None`` when the form is
            submitted without an image.
        taxonomic_level: one of ``taxonomic_levels`` (e.g. "species").
        top_k: number of ranked predictions to list (default 5).

    Returns:
        str: HTML for the Gradio ``html`` output.
    """
    if image is None:
        return "<p>Please upload an image to classify.</p>"

    img_array = load_and_preprocess_image(image)
    # aggregate_predictions maps model output columns to this sorted species
    # list, so it must match the label order the model was trained with.
    class_names = sorted(taxo_df["species"].unique())
    prediction = model.predict(img_array)
    aggregated_predictions, aggregated_class_labels = aggregate_predictions(
        prediction, taxonomic_level, class_names
    )
    scores = aggregated_predictions[0]  # a single image in, a single row out

    is_species = taxonomic_level == "species"
    common_names = _species_common_names() if is_species else {}

    # Headline: the single best match.
    best_name = aggregated_class_labels[int(np.argmax(scores))]
    if is_species:
        headline = f"<span style='font-style: italic;'>{_escape_html(best_name)}</span>"
        if best_name in common_names:
            headline += f" ({_escape_html(common_names[best_name])})"
    else:
        headline = _escape_html(best_name)
    output_text = f"<h1 style='font-weight: bold;'>{headline}</h1>"

    output_text += f"<h4 style='font-weight: bold; font-size: 1.2em;'>Top {top_k} Predictions:</h4>"
    # Highest score first. A stable sort of the negated scores breaks ties by
    # lowest index, the same rule np.argmax uses, so the list agrees with the
    # headline.
    for i in np.argsort(-scores, kind="stable")[:top_k]:
        class_name = aggregated_class_labels[i]
        if is_species:
            # Common names only exist at species level; the name is italic.
            label = f"<span style='font-style: italic;'>{_escape_html(class_name)}</span>"
            if class_name in common_names:
                label += f" (<span>{_escape_html(common_names[class_name])}</span>)"
        else:
            label = f"<span>{_escape_html(class_name)}</span>"
        confidence_percentage = scores[i] * 100
        output_text += (
            "<div style='display: flex; justify-content: space-between;'>"
            f"{label}"
            f"<span style='margin-left: auto;'>{confidence_percentage:.2f}%</span></div>"
        )
    return output_text
# Define a function to update the welcome message based on the logged-in user
def update_message(request: gr.Request):
    """Build the greeting shown in the welcome component on page load.

    The ``gr.Request`` annotation is what tells Gradio to supply the request
    object; its ``username`` is the logged-in user's name when ``auth`` is
    enabled.

    Returns:
        str: Markdown greeting.
    """
    username = getattr(request, "username", None)
    if not username:
        # No authenticated user (e.g. the app launched without ``auth``):
        # avoid rendering "Dr. None".
        return "Welcome to the demo!"
    return f"Welcome to the demo, Dr. {username}!"
# Define the Gradio interface
with gr.Blocks() as demo:
    # Greeting at the top of the page, filled in when the page loads
    welcome_message = gr.Markdown()
    demo.load(update_message, None, welcome_message)

    # Main prediction form: an image plus the taxonomic level to report at
    interface = gr.Interface(
        fn=make_prediction,  # called as make_prediction(image, taxonomic_level)
        inputs=[
            gr.Image(type="pil"),  # delivered to make_prediction as a PIL image
            gr.Dropdown(choices=taxonomic_levels, label="Taxonomic level", value="species"),
        ],
        outputs="html",  # make_prediction returns an HTML fragment
        title="Amazon arboreal species classification",
        description="Upload an image and select the taxonomic level to classify the species.",
    )
    # Embed the Interface inside the surrounding Blocks layout
    interface.render()

# Launch the Gradio interface with username/password authentication.
# NOTE(review): these credentials are hard-coded in plain text in a file that
# may be publicly readable; consider loading them from deployment secrets
# (e.g. environment variables) instead.
demo.launch(auth=[
    ("Luca Santini", "lucasantini"),
    ("Ana Benítez López", "anabenítezlópez"),
])