# Import the libraries
import numpy as np
import pandas as pd
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.applications.convnext import preprocess_input
import gradio as gr

# Load the trained Keras classifier from its saved-model path
# (relative to the current working directory).
model = load_model('models/TropiCam-AI_ConvNeXtBase')

# Read the taxonomy lookup table. Underscores in species names are turned into
# spaces; '_' is replaced literally (it is not a regex metacharacter either way).
taxo_df = pd.read_csv('taxonomy/taxonomy_mapping.csv')
taxo_df['species'] = taxo_df['species'].str.replace('_', ' ', regex=False)

# Taxonomic ranks the app can report, ordered from finest to coarsest.
# The automatic mode in make_prediction walks this list in this order.
taxonomic_levels = [
    'species',
    'genus',
    'family',
    'order',
    'class',
]

# Map a predicted class index to its label at the chosen taxonomic level
def get_class_name(predicted_class, taxonomic_level):
    """Return the label at position ``predicted_class`` for ``taxonomic_level``.

    Labels are the distinct values of the ``taxonomic_level`` column of
    ``taxo_df``, sorted in ascending order; ``predicted_class`` indexes into
    that sorted list.
    """
    labels = list(taxo_df[taxonomic_level].unique())
    labels.sort()
    return labels[predicted_class]

# Function to aggregate predictions to a higher taxonomic level
def aggregate_predictions(predicted_probs, taxonomic_level, class_names, taxonomy=None):
    """Sum species-level probabilities into the groups of ``taxonomic_level``.

    Args:
        predicted_probs: 2-D array ``(batch, len(class_names))`` of species
            probabilities; column ``j`` belongs to ``class_names[j]``.
        taxonomic_level: Name of the taxonomy column to aggregate into
            (e.g. ``'genus'``); ``'species'`` yields an identity re-ordering.
        class_names: Species names in the column order of ``predicted_probs``.
        taxonomy: Optional taxonomy DataFrame with a ``species`` column and a
            ``taxonomic_level`` column. Defaults to the module-level ``taxo_df``.

    Returns:
        ``(aggregated, labels)`` where ``labels`` is the sorted list of distinct
        ``taxonomic_level`` values and ``aggregated`` has shape
        ``(batch, len(labels))`` with column ``k`` holding the summed
        probability of every species belonging to ``labels[k]``.

    Raises:
        KeyError: if a taxonomy species is not present in ``class_names``.
    """
    df = taxo_df if taxonomy is None else taxonomy
    predicted_probs = np.asarray(predicted_probs)

    unique_labels = sorted(df[taxonomic_level].unique())
    # Dict lookups replace list.index(), which made the loop O(n^2).
    label_index = {label: i for i, label in enumerate(unique_labels)}
    species_index = {name: i for i, name in enumerate(class_names)}

    aggregated_predictions = np.zeros((predicted_probs.shape[0], len(unique_labels)))

    # One row per species, so a duplicated taxonomy row cannot add the same
    # species' probability to its group twice.
    rows = df.drop_duplicates(subset='species')
    for species, higher_level in zip(rows['species'], rows[taxonomic_level]):
        aggregated_predictions[:, label_index[higher_level]] += predicted_probs[:, species_index[species]]

    return aggregated_predictions, unique_labels

# Function to load and preprocess the image
def load_and_preprocess_image(image, target_size=(224, 224)):
    """Turn a PIL image into a one-image batch for the ConvNeXt model.

    Args:
        image: PIL image (any mode; it is converted to RGB first).
        target_size: ``(width, height)`` passed to ``PIL.Image.resize``.

    Returns:
        Array with a leading batch axis of size 1, after Keras' ConvNeXt
        ``preprocess_input``. Presumably ``(1, height, width, 3)`` with the
        default channels-last Keras config; confirm if that config changes.
    """
    # Grayscale ('L'), palette ('P') or RGBA uploads would otherwise give
    # 1 or 4 channels, while the ConvNeXt pipeline expects RGB input.
    if image.mode != 'RGB':
        image = image.convert('RGB')
    img_array = img_to_array(image.resize(target_size))
    # Add the batch dimension expected by model.predict
    img_array = np.expand_dims(img_array, axis=0)
    return preprocess_input(img_array)

# Helpers and entry point for making predictions


def _common_name(species_name):
    """Return the first ``common_name`` recorded in ``taxo_df`` for *species_name*."""
    return taxo_df.loc[taxo_df['species'] == species_name, 'common_name'].values[0]


def _format_prediction_html(level, aggregated_predictions, class_labels):
    """Render the top prediction and the top-5 list at ``level`` as HTML.

    At species level the scientific names are italicised and followed by their
    common name; higher levels show the plain label only.
    """
    probs = aggregated_predictions[0]
    at_species = level == "species"

    best_name = class_labels[int(np.argmax(probs))]
    if at_species:
        header = f"<h1 style='font-weight: bold;'><span style='font-style: italic;'>{best_name}</span> ({_common_name(best_name)})</h1>"
    else:
        header = f"<h1 style='font-weight: bold;'>{best_name}</h1>"

    parts = [header, "<h4 style='font-weight: bold; font-size: 1.2em;'>Top-5 predictions:</h4>"]

    # Highest five probabilities, best first (fewer if the level has < 5 labels)
    for i in np.argsort(probs)[-5:][::-1]:
        class_name = class_labels[i]
        confidence_percentage = probs[i] * 100
        if at_species:
            name_html = f"<span style='font-style: italic;'>{class_name}</span>&nbsp;(<span>{_common_name(class_name)}</span>)"
        else:
            name_html = f"<span>{class_name}</span>"
        parts.append(
            "<div style='display: flex; justify-content: space-between;'>"
            f"{name_html}"
            f"<span style='margin-left: auto;'>{confidence_percentage:.2f}%</span></div>"
        )

    return "".join(parts)


def make_prediction(image, taxonomic_decision, taxonomic_level, *, confidence_threshold=0.75):
    """Classify *image* and return the result as an HTML snippet.

    Args:
        image: PIL image from the upload widget, or None when nothing was uploaded.
        taxonomic_decision: Radio-button string. With
            "Yes, I want to specify the taxonomic level" the result is reported
            at ``taxonomic_level`` regardless of confidence. Otherwise the
            automatic mode walks ``taxonomic_levels`` towards coarser ranks and
            stops at the first level whose top aggregated probability reaches
            ``confidence_threshold``. It starts at species for
            "No, I will let the model decide" and at ``taxonomic_level`` for any
            other value.
        taxonomic_level: Rank chosen in the drop-down; must be an entry of
            ``taxonomic_levels`` unless the decision is "No, ...".
        confidence_threshold: Minimum top probability accepted in automatic mode.

    Returns:
        HTML containing the top prediction and the top-5 list, or
        "Unknown animal" when no level is confident enough, or an upload prompt
        when *image* is None.

    Raises:
        ValueError: if ``taxonomic_level`` is needed but is not in ``taxonomic_levels``.
    """
    if image is None:
        return "<h1 style='font-weight: bold;'>Please upload an image</h1>"

    img_array = load_and_preprocess_image(image)

    # Model output columns are taken to follow alphabetical species order.
    # NOTE(review): presumably matches the label order used in training; verify.
    class_names = sorted(taxo_df['species'].unique())

    prediction = model.predict(img_array)

    if taxonomic_decision == "No, I will let the model decide":
        start_index = 0
    else:
        start_index = taxonomic_levels.index(taxonomic_level)

    # User-specified level: report the best guess there, whatever its confidence.
    if taxonomic_decision == "Yes, I want to specify the taxonomic level":
        level = taxonomic_levels[start_index]
        aggregated, labels = aggregate_predictions(prediction, level, class_names)
        return _format_prediction_html(level, aggregated, labels)

    # Automatic mode: climb to coarser ranks until one is confident enough.
    for level in taxonomic_levels[start_index:]:
        aggregated, labels = aggregate_predictions(prediction, level, class_names)
        if aggregated[0].max() >= confidence_threshold:
            return _format_prediction_html(level, aggregated, labels)

    # No level met the confidence criterion
    return "<h1 style='font-weight: bold;'>Unknown animal</h1>"

# Gradio UI. Inputs are, in order: the uploaded image (as a PIL image), whether
# the user fixes the taxonomic rank, and the rank to use when fixed. The
# result is rendered as HTML produced by make_prediction.
interface = gr.Interface(
    fn=make_prediction,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Radio(
            choices=[
                "Yes, I want to specify the taxonomic level",
                "No, I will let the model decide",
            ],
            label="Do you want to specify the taxonomic resolution for predictions? If you select 'No', the 'Taxonomic level' drop-down menu will be bypassed.",
            value="No, I will let the model decide",
        ),
        gr.Dropdown(choices=taxonomic_levels, label="Taxonomic level:", value="species"),
    ],
    outputs="html",
    title="Neotropical arboreal species classification",
    description="Upload an image and our AI will classify the animal. NOTE: it's best not to feed the whole image but just the cropped animal (in the final model this will be done automatically).",
)

# Start the web app with default launch settings (no authentication is configured here).
interface.launch()