File size: 5,672 Bytes
9f198ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1364165
 
e93a56e
 
 
 
4fe087a
e93a56e
 
 
 
1364165
 
 
f147a22
 
1364165
 
9f198ef
f147a22
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# Import the libraries
import numpy as np
import pandas as pd
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.applications.convnext import preprocess_input
import gradio as gr

# Load the trained Keras classifier once at import time; every prediction
# below reuses this single module-level instance
model = load_model('models/ConvNeXtBase_80_tresh_spp.tf')

# Load the taxonomy table, replacing any underscores in species names with
# spaces so they match the labels shown in the UI
taxo_df = pd.read_csv('taxonomy/taxonomy_mapping.csv').assign(
    species=lambda table: table['species'].str.replace('_', ' ')
)

# Taxonomy columns the user may aggregate predictions to, finest rank first
taxonomic_levels = [
    'species',
    'genus',
    'family',
    'order',
    'class',
]

# Function to map predicted class index to class name at the selected taxonomic level
def get_class_name(predicted_class, taxonomic_level, taxonomy=None):
    """Return the label at position ``predicted_class`` for ``taxonomic_level``.

    Labels are ordered as the sorted unique values of the ``taxonomic_level``
    column of the taxonomy table.

    Args:
        predicted_class: integer index into the sorted label list.
        taxonomic_level: name of a taxonomy column (e.g. "species", "genus").
        taxonomy: table to read labels from; defaults to the module-level
            ``taxo_df`` when omitted.

    Returns:
        The label at that index.

    Raises:
        IndexError: if ``predicted_class`` is out of range.
    """
    if taxonomy is None:
        taxonomy = taxo_df
    unique_labels = sorted(taxonomy[taxonomic_level].unique())
    return unique_labels[predicted_class]

# Function to aggregate predictions to a higher taxonomic level
def aggregate_predictions(predicted_probs, taxonomic_level, class_names, taxonomy=None):
    """Sum species-level probabilities into the labels of ``taxonomic_level``.

    Each species' column of ``predicted_probs`` is added to the column of the
    ``taxonomic_level`` label that the taxonomy table assigns to that species.

    Args:
        predicted_probs: 2-D array (batch, len(class_names)); column j holds
            the score for ``class_names[j]``.
        taxonomic_level: taxonomy column to aggregate into (e.g. "genus").
        class_names: species names in the same order as the prediction columns.
        taxonomy: table with a 'species' column and a ``taxonomic_level``
            column; defaults to the module-level ``taxo_df`` when omitted.

    Returns:
        ``(aggregated, unique_labels)`` where ``unique_labels`` is the sorted
        list of unique values of the level column and ``aggregated`` has shape
        (batch, len(unique_labels)), column k belonging to ``unique_labels[k]``.

    Raises:
        ValueError: if a species in the taxonomy table is not in ``class_names``.
    """
    if taxonomy is None:
        taxonomy = taxo_df
    predicted_probs = np.asarray(predicted_probs)

    unique_labels = sorted(taxonomy[taxonomic_level].unique())
    # Dict lookups replace list.index(), which made the loop O(n_species^2)
    label_column = {label: k for k, label in enumerate(unique_labels)}
    species_column = {name: j for j, name in enumerate(class_names)}

    # One entry per species: a species listed on several rows would otherwise
    # have its probability added more than once (last row's label wins)
    species_to_label = dict(zip(taxonomy['species'], taxonomy[taxonomic_level]))

    aggregated_predictions = np.zeros((predicted_probs.shape[0], len(unique_labels)))
    for species, label in species_to_label.items():
        try:
            j = species_column[species]
        except KeyError:
            raise ValueError(f"{species!r} is not in class_names") from None
        aggregated_predictions[:, label_column[label]] += predicted_probs[:, j]

    return aggregated_predictions, unique_labels

# Function to load and preprocess the image
def load_and_preprocess_image(image, target_size=(224, 224)):
    """Convert a PIL image into a single-image batch for ``model.predict``.

    Args:
        image: PIL image, e.g. as supplied by ``gr.Image(type="pil")``.
        target_size: ``(width, height)`` passed to ``PIL.Image.resize``.

    Returns:
        A batch array with a leading axis of length 1, after ConvNeXt
        ``preprocess_input``. With Keras' default channels_last format this is
        (1, height, width, 3).
    """
    # Uploads can be RGBA (PNG with alpha), palette or grayscale; force three
    # colour channels. NOTE(review): assumes the model takes RGB input, as the
    # standard ConvNeXt architecture does — confirm against model.input_shape.
    rgb_image = image.convert("RGB")
    # Resize, then convert to a float array
    img_array = img_to_array(rgb_image.resize(target_size))
    # Add the batch axis expected by the model
    img_array = np.expand_dims(img_array, axis=0)
    # Apply the ConvNeXt preprocessing
    img_array = preprocess_input(img_array)
    return img_array

# Function to make predictions
def make_prediction(image, taxonomic_level):
    """Classify ``image`` and render the result at ``taxonomic_level`` as HTML.

    Species-level model outputs are summed up to ``taxonomic_level`` with
    ``aggregate_predictions``. The output is a heading with the best label,
    followed by a "Top 5 Predictions" list with percentages. At species level
    each label is shown in italics next to its ``common_name`` from the
    taxonomy table.

    Args:
        image: PIL image to classify, or ``None`` when nothing was uploaded.
        taxonomic_level: one of ``taxonomic_levels``.

    Returns:
        An HTML string for the Gradio "html" output.
    """
    # Guard against a submission with no image, which would otherwise fail
    # inside preprocessing with an AttributeError on None
    if image is None:
        return "<p>Please upload an image.</p>"

    img_array = load_and_preprocess_image(image)

    # Prediction columns are assumed to follow the sorted unique species names
    # NOTE(review): confirm this matches the label order used in training
    class_names = sorted(taxo_df['species'].unique())

    prediction = model.predict(img_array)
    aggregated_predictions, aggregated_class_labels = aggregate_predictions(
        prediction, taxonomic_level, class_names)
    # Only one image is in the batch
    scores = aggregated_predictions[0]

    # Indices of the (up to) five highest scores, best first; index 0 is also
    # used for the heading so the heading always matches the list's first row
    top_indices = np.argsort(scores)[-5:][::-1]
    predicted_class_name = aggregated_class_labels[top_indices[0]]

    # Common names exist only for species
    show_common_names = taxonomic_level == "species"

    def common_name_of(species):
        return taxo_df[taxo_df['species'] == species]['common_name'].values[0]

    if show_common_names:
        parts = [f"<h1 style='font-weight: bold;'><span style='font-style: italic;'>{predicted_class_name}</span> ({common_name_of(predicted_class_name)})</h1>"]
    else:
        parts = [f"<h1 style='font-weight: bold;'>{predicted_class_name}</h1>"]

    parts.append("<h4 style='font-weight: bold; font-size: 1.2em;'>Top 5 Predictions:</h4>")

    for i in top_indices:
        class_name = aggregated_class_labels[i]
        confidence_percentage = scores[i] * 100
        if show_common_names:
            name_html = (f"<span style='font-style: italic;'>{class_name}</span>"
                         f"&nbsp;(<span>{common_name_of(class_name)}</span>)")
        else:
            name_html = f"<span>{class_name}</span>"
        parts.append(f"<div style='display: flex; justify-content: space-between;'>"
                     f"{name_html}"
                     f"<span style='margin-left: auto;'>{confidence_percentage:.2f}%</span></div>")

    return "".join(parts)

# Define the Gradio interface
# Radio choices, kept as constants so the callback compares against the
# exact strings shown in the UI
SPECIFY_LEVEL_CHOICE = "Yes, I want to specify the taxonomic level"
AUTO_LEVEL_CHOICE = "No, I will let the model decide"

# Minimum aggregated probability the automatic mode requires before it
# settles on a taxonomic level.
# NOTE(review): 0.8 is a placeholder default, not derived from the model; tune it.
AUTO_LEVEL_MIN_CONFIDENCE = 0.8


def choose_taxonomic_level(image, min_confidence=AUTO_LEVEL_MIN_CONFIDENCE):
    """Return the finest taxonomic level whose best prediction is confident enough.

    Levels are tried in ``taxonomic_levels`` order (species first). The first
    level whose top aggregated probability reaches ``min_confidence`` is
    returned; if none does, the last (coarsest) level is returned.
    """
    img_array = load_and_preprocess_image(image)
    class_names = sorted(taxo_df['species'].unique())
    prediction = model.predict(img_array)
    for level in taxonomic_levels:
        aggregated, _ = aggregate_predictions(prediction, level, class_names)
        if aggregated[0].max() >= min_confidence:
            return level
    return taxonomic_levels[-1]


def predict_from_interface(image, level_choice, taxonomic_level):
    """Gradio callback that maps the three interface inputs onto make_prediction.

    Gradio calls ``fn`` with one positional argument per input component
    (image, radio choice, dropdown level), while ``make_prediction`` takes
    only ``(image, taxonomic_level)``. Passing it directly would therefore
    raise TypeError. Unless the user chose to specify the level, the
    dropdown value is ignored and the level is picked automatically.
    """
    if level_choice != SPECIFY_LEVEL_CHOICE and image is not None:
        # NOTE: this runs the model once here and again in make_prediction
        taxonomic_level = choose_taxonomic_level(image)
    return make_prediction(image, taxonomic_level)


interface = gr.Interface(
    fn=predict_from_interface,  # Adapter: 3 UI inputs -> make_prediction(image, level)
    inputs=[
        gr.Image(type="pil", label="Upload Image"),  # Input type: Image (PIL format)
        gr.Radio(choices=[SPECIFY_LEVEL_CHOICE, AUTO_LEVEL_CHOICE],
                 label="Do you want to specify the taxonomic resolution for predictions?<br><br>NOTE: if you select 'No', the next drop-down menu will be bypassed",
                 value=AUTO_LEVEL_CHOICE),  # Radio button for taxonomic resolution choice
        gr.Dropdown(choices=taxonomic_levels, label="Taxonomic level", value="species")  # Dropdown for taxonomic level
    ],
    outputs="html",  # Output type: HTML for formatting
    title="Amazon arboreal species classification",
    description="Upload an image and select the taxonomic level to classify the species."
)

# Launch the Gradio interface with authentication for the specified users,
# only when run as a script so importing this module does not start a server
if __name__ == "__main__":
    # NOTE(review): these credentials are hard-coded in source and each password
    # is derived from its username; move them to a secret / environment variable.
    interface.launch(auth=[
        ("Andrea Zampetti", "andreazampetti"),
        ("Luca Santini", "lucasantini"),
        ("Ana Benítez López", "anabenítezlópez")
    ])