Spaces:
Sleeping
Sleeping
File size: 5,834 Bytes
64d14f7 49a2c57 64d14f7 0c48d27 34ecf7a 0c48d27 fb6d3d4 0c48d27 49a2c57 0c48d27 f147a22 2759d0e 49a2c57 2759d0e 49a2c57 0c48d27 49a2c57 0c48d27 49a2c57 0c48d27 49a2c57 0c48d27 f147a22 0c48d27 d7e6465 2759d0e dad1fea 0c48d27 2759d0e 49a2c57 0c48d27 2759d0e 49a2c57 7735f05 49a2c57 0c48d27 7735f05 49a2c57 0c48d27 49a2c57 2759d0e 7735f05 f147a22 7735f05 f147a22 7735f05 0c48d27 f147a22 fd81c83 5c16577 f147a22 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
# Import the libraries
import numpy as np
import pandas as pd
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.applications.convnext import preprocess_input
import gradio as gr
# Load the trained classifier once at import time so every request reuses it.
model = load_model('models/ConvNeXtBase_80_tresh_spp.tf')

# Taxonomy lookup table (one or more rows per species with its higher ranks).
# Underscores in species names are turned into spaces so names read naturally
# and match the labels shown in the UI.
taxo_df = pd.read_csv('taxonomy/taxonomy_mapping.csv')
taxo_df['species'] = taxo_df['species'].str.replace('_', ' ', regex=False)

# Taxonomic ranks offered in the dropdown, from finest to coarsest.
taxonomic_levels = ['species', 'genus', 'family', 'order', 'class']
def get_class_name(predicted_class, taxonomic_level, taxonomy=None):
    """Map a class index to its label at the given taxonomic level.

    Labels are ordered as the sorted unique values of the *taxonomic_level*
    column, which is the same ordering ``aggregate_predictions`` uses for its
    output columns.

    Args:
        predicted_class: Integer position into the sorted unique labels.
        taxonomic_level: Name of the taxonomy column (e.g. ``'genus'``).
        taxonomy: Taxonomy table to read labels from. Defaults to the
            module-level ``taxo_df``, so existing callers are unaffected.

    Returns:
        The label at position *predicted_class*.

    Raises:
        IndexError: If *predicted_class* is past the end of the label list.
    """
    if taxonomy is None:
        taxonomy = taxo_df
    unique_labels = sorted(taxonomy[taxonomic_level].unique())
    return unique_labels[predicted_class]
def aggregate_predictions(predicted_probs, taxonomic_level, class_names, taxonomy=None):
    """Sum species-level scores into the labels of a (possibly coarser) rank.

    Args:
        predicted_probs: 2-D array of per-species scores with one row per
            image; column ``k`` is taken to belong to ``class_names[k]``.
        taxonomic_level: Taxonomy column to aggregate into (e.g. ``'genus'``).
            With ``'species'`` the scores are simply re-ordered by sorted
            species name.
        class_names: Species names in the same order as the columns of
            *predicted_probs*.
        taxonomy: Taxonomy table with a ``'species'`` column and a
            *taxonomic_level* column. Defaults to the module-level ``taxo_df``.

    Returns:
        ``(aggregated, labels)``. ``labels`` is the sorted list of unique values
        of *taxonomic_level*. ``aggregated[:, j]`` is the summed score of every
        species mapped to ``labels[j]``.

    Raises:
        ValueError: If a species in the taxonomy is not present in *class_names*.
    """
    if taxonomy is None:
        taxonomy = taxo_df
    predicted_probs = np.asarray(predicted_probs)
    unique_labels = sorted(taxonomy[taxonomic_level].unique())

    # Dict lookups are O(1); the previous list.index() calls inside the loop
    # made the aggregation quadratic in the number of species.
    species_pos = {name: i for i, name in enumerate(class_names)}
    label_pos = {label: j for j, label in enumerate(unique_labels)}

    aggregated_predictions = np.zeros((predicted_probs.shape[0], len(unique_labels)))
    seen = set()
    for species, higher_level in zip(taxonomy['species'], taxonomy[taxonomic_level]):
        # Count each species once; if the table repeats a species, adding its
        # score again would over-count that label.
        if species in seen:
            continue
        seen.add(species)
        if species not in species_pos:
            raise ValueError(f"{species!r} is not in class_names")
        aggregated_predictions[:, label_pos[higher_level]] += predicted_probs[:, species_pos[species]]
    return aggregated_predictions, unique_labels
def load_and_preprocess_image(image, target_size=(224, 224)):
    """Turn a PIL image into a preprocessed batch containing that one image.

    Args:
        image: PIL image, as delivered by ``gr.Image(type="pil")``.
        target_size: ``(width, height)`` passed to ``PIL.Image.resize``.

    Returns:
        Array with a leading batch axis of size 1, passed through the ConvNeXt
        ``preprocess_input`` (presumably ``(1, height, width, 3)`` with the
        default channels-last data format — confirm against the model input).
    """
    # img_to_array keeps whatever channels the image has: greyscale/palette
    # images yield 1 channel and RGBA yields 4, which would not match a
    # 3-channel model input. Normalise to RGB first.
    if image.mode != "RGB":
        image = image.convert("RGB")
    img_array = img_to_array(image.resize(target_size))
    # Add the batch dimension expected by model.predict.
    img_array = np.expand_dims(img_array, axis=0)
    return preprocess_input(img_array)
def _format_prediction_row(class_name, confidence_percentage, common_name=None):
    """Render one ranked prediction as a flex row of label and confidence.

    When *common_name* is given (species level), the scientific name is
    italicised and followed by the common name in parentheses.
    """
    if common_name is None:
        label_html = f"<span>{class_name}</span>"
    else:
        label_html = (
            f"<span style='font-style: italic;'>{class_name}</span>"
            f" (<span>{common_name}</span>)"
        )
    return (
        "<div style='display: flex; justify-content: space-between;'>"
        f"{label_html}"
        f"<span style='margin-left: auto;'>{confidence_percentage:.2f}%</span></div>"
    )


def make_prediction(image, taxonomic_level):
    """Classify *image* and render the result as an HTML snippet.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.
        taxonomic_level: One of ``taxonomic_levels``. Species scores are summed
            up to this rank before ranking.

    Returns:
        HTML: the top label as a heading, then up to five ranked labels with
        their confidence percentages. Common names are shown at species level.

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable message.
    """
    if image is None:
        # Without an upload, load_and_preprocess_image would fail on
        # image.resize(); report it to the user instead.
        raise gr.Error("Please upload an image before submitting.")

    img_array = load_and_preprocess_image(image)
    # Species names in sorted order. aggregate_predictions treats column k of
    # the model output as class_names[k] (assumes the model was trained with
    # this alphabetical class order — TODO confirm).
    class_names = sorted(taxo_df['species'].unique())
    prediction = model.predict(img_array)
    aggregated_predictions, aggregated_class_labels = aggregate_predictions(
        prediction, taxonomic_level, class_names
    )
    # A single image is predicted, so only the first row matters.
    scores = aggregated_predictions[0]
    # Indices of up to five highest scores, best first.
    top_indices = np.argsort(scores)[-5:][::-1]

    show_common_names = taxonomic_level == "species"
    common_names = {}
    if show_common_names:
        # Build the species -> common name lookup once instead of filtering the
        # DataFrame per row; drop_duplicates keeps the first row per species,
        # as the previous ".values[0]" lookup did.
        first_rows = taxo_df.drop_duplicates(subset='species')
        common_names = dict(zip(first_rows['species'], first_rows['common_name']))

    # Use the first ranked entry as the headline so it always matches the top
    # row of the list (np.argmax and argsort may pick different indices on ties).
    best_label = aggregated_class_labels[top_indices[0]]
    if show_common_names:
        output_text = (
            "<h1 style='font-weight: bold;'>"
            f"<span style='font-style: italic;'>{best_label}</span> ({common_names[best_label]})</h1>"
        )
    else:
        output_text = f"<h1 style='font-weight: bold;'>{best_label}</h1>"

    output_text += "<h4 style='font-weight: bold; font-size: 1.2em;'>Top 5 Predictions:</h4>"
    output_text += "".join(
        _format_prediction_row(
            aggregated_class_labels[i],
            scores[i] * 100,
            common_names[aggregated_class_labels[i]] if show_common_names else None,
        )
        for i in top_indices
    )
    return output_text
def update_message(request: gr.Request):
    """Return the Markdown greeting shown when the page loads.

    Gradio supplies *request* because of the ``gr.Request`` annotation, which
    must therefore stay in place. ``request.username`` is ``None`` when nobody
    is logged in (e.g. the app is launched without ``auth``). In that case a
    generic greeting is returned instead of "Dr. None!".
    """
    username = getattr(request, "username", None)
    if not username:
        return "Welcome to the demo!"
    return f"Welcome to the demo, Dr. {username}!"
# Build the Gradio UI. Components are laid out in the order they are created
# inside this Blocks context, so statement order here determines page layout.
with gr.Blocks() as demo:
    # Empty Markdown placeholder that receives the welcome message.
    welcome_message = gr.Markdown()
    # On page load, call update_message (no input components) and write its
    # return value into welcome_message.
    demo.load(update_message, None, welcome_message)
    # Prediction UI: an image plus a taxonomic-level dropdown feed
    # make_prediction, whose HTML string is shown in an HTML output.
    interface = gr.Interface(
        fn=make_prediction,  # Called with (image, taxonomic_level)
        inputs=[gr.Image(type="pil"),  # Delivers the upload as a PIL image
                gr.Dropdown(choices=taxonomic_levels, label="Taxonomic level", value="species")],  # 'value' sets the initially selected option
        outputs="html",  # make_prediction returns an HTML string
        title="Amazon arboreal species classification",
        description="Upload an image and select the taxonomic level to classify the species."
    )
    # Embed the Interface inside the enclosing Blocks page.
    interface.render()
# Start the app. `auth` is a list of (username, password) pairs; visitors must
# log in with one of them before seeing the demo.
# NOTE(review): these passwords are hard-coded in plain text in source. If this
# file lives in a public repository, anyone can read them. Consider loading the
# credentials from a secret / environment variable instead.
demo.launch(auth=[
    ("Luca Santini", "lucasantini"),
    ("Ana Benítez López", "anabenítezlópez")
])