File size: 7,016 Bytes
9f198ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d33d3aa
 
 
 
9f198ef
 
 
 
 
 
 
 
 
 
 
 
 
 
ad3d1fe
c80713a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68bd012
c80713a
 
 
68bd012
c80713a
 
 
 
e014372
22a6037
e493bfa
 
22a6037
c80713a
 
 
 
 
 
 
 
 
 
 
 
098b334
c80713a
 
e014372
c80713a
 
e014372
c80713a
 
 
 
 
 
 
 
 
 
 
 
 
d33d3aa
c80713a
9f198ef
1364165
 
e93a56e
 
 
 
96a0181
e93a56e
68bd012
e93a56e
 
1364165
d33d3aa
1364165
f147a22
 
1364165
 
9f198ef
50d52be
f147a22
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# Import the libraries
import numpy as np
import pandas as pd
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.applications.convnext import preprocess_input
import gradio as gr

# Trained classifier. Downstream code treats its output columns as per-species
# scores ordered like the sorted unique species names of the taxonomy table.
model = load_model('models/ConvNeXtBase_80_tresh_spp.tf')

# Taxonomy lookup table (species plus higher ranks and common names).
# Species names are stored with underscores in the CSV; turn them into spaces
# so they match the displayed binomial form.
taxo_df = pd.read_csv('taxonomy/taxonomy_mapping.csv').assign(
    species=lambda table: table['species'].str.replace('_', ' ')
)

# Ranks from most to least specific; predictions fall back along this order.
taxonomic_levels = ['species', 'genus', 'family', 'order', 'class']

# Map a class index to its name at the chosen taxonomic level
def get_class_name(predicted_class, taxonomic_level):
    """Return the name at position ``predicted_class`` among the sorted
    unique values of the ``taxonomic_level`` column of ``taxo_df``."""
    ordered_names = sorted(taxo_df[taxonomic_level].unique())
    chosen = ordered_names[predicted_class]
    return chosen

# Function to aggregate predictions to a higher taxonomic level
def aggregate_predictions(predicted_probs, taxonomic_level, class_names, taxonomy=None):
    """Sum per-species scores into the classes of a (coarser) taxonomic level.

    Args:
        predicted_probs: 2-D array-like of scores, one row per image and one
            column per species, with columns ordered like ``class_names``.
        taxonomic_level: Column of the taxonomy table to aggregate into
            (e.g. ``'genus'``; ``'species'`` leaves scores unaggregated).
        class_names: Species names, in the same order as the columns of
            ``predicted_probs``.
        taxonomy: Optional taxonomy table with a ``'species'`` column and a
            ``taxonomic_level`` column. Defaults to the module-level ``taxo_df``.

    Returns:
        ``(aggregated, labels)``: ``labels`` is the sorted list of unique values
        of ``taxonomic_level``. ``aggregated`` has shape
        ``(n_rows, len(labels))``, and ``aggregated[:, j]`` is the sum of the
        scores of every taxonomy row mapped to ``labels[j]``.

    Raises:
        ValueError: if a species listed in the taxonomy table is not in
            ``class_names``.
    """
    table = taxo_df if taxonomy is None else taxonomy
    predicted_probs = np.asarray(predicted_probs)
    unique_labels = sorted(table[taxonomic_level].unique())

    # Precompute name -> column lookups once. The previous list.index() call per
    # taxonomy row made this loop quadratic in the number of species.
    species_to_col = {}
    for col, name in enumerate(class_names):
        species_to_col.setdefault(name, col)  # first occurrence wins, like list.index
    label_to_col = {label: col for col, label in enumerate(unique_labels)}

    aggregated_predictions = np.zeros((predicted_probs.shape[0], len(unique_labels)))
    # Rows are visited in table order, so the floating-point summation order is
    # the same as before.
    for species, higher_level in zip(table['species'], table[taxonomic_level]):
        try:
            species_index = species_to_col[species]
        except KeyError:
            raise ValueError(f"{species!r} is not in class_names") from None
        aggregated_predictions[:, label_to_col[higher_level]] += predicted_probs[:, species_index]

    return aggregated_predictions, unique_labels

# Function to load and preprocess the image
def load_and_preprocess_image(image, target_size=(224, 224)):
    """Turn a PIL image into a model-ready batch containing one image.

    Args:
        image: PIL image (e.g. as delivered by ``gr.Image(type="pil")``).
        target_size: ``(width, height)`` passed to ``PIL.Image.Image.resize``.

    Returns:
        Array with a leading batch axis of size 1, after ConvNeXt
        ``preprocess_input``. Assuming Keras' default channels_last format,
        the shape is ``(1, height, width, 3)``.
    """
    # Grayscale, palette, LA or RGBA uploads would otherwise produce 1, 2 or 4
    # channels. The classifier presumably expects 3-channel RGB input.
    if image.mode != "RGB":
        image = image.convert("RGB")
    img_array = img_to_array(image.resize(target_size))
    # Add the batch dimension expected by model.predict
    img_array = np.expand_dims(img_array, axis=0)
    return preprocess_input(img_array)

# Function to make predictions
def _common_name(species_name):
    """Return ``common_name`` from the first taxonomy row whose species matches."""
    return taxo_df[taxo_df['species'] == species_name]['common_name'].values[0]


def _resolve_level(prediction, start_index, class_names, threshold):
    """Find the first level, from ``taxonomic_levels[start_index]`` towards
    coarser ranks, whose top aggregated score reaches ``threshold``.

    Returns ``(level_index, aggregated, labels)``, or ``(None, None, None)``
    when no level is confident enough.
    """
    for level_index in range(start_index, len(taxonomic_levels)):
        aggregated, labels = aggregate_predictions(
            prediction, taxonomic_levels[level_index], class_names)
        if aggregated[0][np.argmax(aggregated)] >= threshold:
            return level_index, aggregated, labels
    return None, None, None


def _render_result(level, aggregated, labels):
    """Build the HTML output: a headline for the top class, then the top-5 list.

    Common names are shown only at species level.
    """
    scores = aggregated[0]
    is_species = level == "species"

    best = labels[np.argmax(aggregated)]
    if is_species:
        headline = (f"<h1 style='font-weight: bold;'><span style='font-style: italic;'>{best}</span>"
                    f" ({_common_name(best)})</h1>")
    else:
        headline = f"<h1 style='font-weight: bold;'>{best}</h1>"

    parts = [headline, "<h4 style='font-weight: bold; font-size: 1.2em;'>Top-5 predictions:</h4>"]
    # Highest-scoring (up to) five classes, best first
    for i in np.argsort(scores)[-5:][::-1]:
        name = labels[i]
        if is_species:
            label_html = (f"<span style='font-style: italic;'>{name}</span>"
                          f"&nbsp;(<span>{_common_name(name)}</span>)")
        else:
            label_html = f"<span>{name}</span>"
        parts.append("<div style='display: flex; justify-content: space-between;'>"
                     f"{label_html}"
                     f"<span style='margin-left: auto;'>{scores[i] * 100:.2f}%</span></div>")
    return "".join(parts)


def make_prediction(image, taxonomic_decision, taxonomic_level, *, threshold=0.80):
    """Classify an image and render the result as HTML.

    The model's per-species scores are aggregated to successively coarser
    levels of ``taxonomic_levels`` until the top class reaches ``threshold``.

    Args:
        image: PIL image from the upload widget, or ``None`` when the user
            submitted without uploading anything.
        taxonomic_decision: Radio-button text. "No, I will let the model decide"
            starts the search at species level; any other value starts it at
            ``taxonomic_level``.
        taxonomic_level: One of ``taxonomic_levels``. Used only when the user
            chose to specify the level.
        threshold: Minimum top-class score for a level to be accepted
            (keyword-only, default 0.80).

    Returns:
        HTML with the top prediction (plus its common name at species level)
        and the top-5 classes. Returns an "Unknown animal" heading when no
        level reaches ``threshold``, or an upload prompt when ``image`` is None.
    """
    # Gradio passes None when Submit is pressed without an uploaded image.
    if image is None:
        return "<h1 style='font-weight: bold;'>Please upload an image</h1>"

    img_array = load_and_preprocess_image(image)
    # Column order of the model output is assumed to follow the sorted species names
    class_names = sorted(taxo_df['species'].unique())
    prediction = model.predict(img_array)

    if taxonomic_decision == "No, I will let the model decide":
        start_index = 0
    else:
        # NOTE(review): a user-specified level is only the starting point. If it
        # is not confident enough, the search still moves to coarser levels;
        # confirm this is intended.
        start_index = taxonomic_levels.index(taxonomic_level)

    level_index, aggregated, labels = _resolve_level(prediction, start_index, class_names, threshold)
    if level_index is None:
        return "<h1 style='font-weight: bold;'>Unknown animal</h1>"  # No level met the confidence criteria

    return _render_result(taxonomic_levels[level_index], aggregated, labels)

# Web UI: an image upload, a yes/no choice about fixing the taxonomic level,
# and the level drop-down. make_prediction's result is rendered as HTML.
interface = gr.Interface(
    make_prediction,
    [
        # Uploaded picture, handed to make_prediction as a PIL image
        gr.Image(type="pil", label="Upload Image"),
        # Decides whether the drop-down value is used as the starting level
        gr.Radio(
            choices=["Yes, I want to specify the taxonomic level", "No, I will let the model decide"],
            value="No, I will let the model decide",
            label="Do you want to specify the taxonomic resolution for predictions? If you select 'No', the 'Taxonomic level' drop-down menu will be bypassed.",
        ),
        # Level to start from when the user opts to specify one
        gr.Dropdown(choices=taxonomic_levels, value="species", label="Taxonomic level:"),
    ],
    "html",
    description="Upload an image and select the taxonomic level (optional) to classify the species.",
    title="Amazon arboreal species classification",
)

# Launch the Gradio interface with authentication for the specified users.
# Guarded so that importing this module does not start a server.
if __name__ == "__main__":
    import os

    # NOTE(security): these credentials are stored in plain text in the source
    # and are trivially guessable. Prefer supplying them through the
    # GRADIO_AUTH environment variable ("user:password,user:password"), then
    # remove and rotate this fallback.
    _default_auth = [
        ("Andrea Zampetti", "andreazampetti"),
        ("Luca Santini", "lucasantini"),
        ("Ana Benítez", "anabenítez"),
    ]

    _env_auth = os.environ.get("GRADIO_AUTH", "").strip()
    if _env_auth:
        _auth = []
        for _entry in _env_auth.split(","):
            _user, _sep, _password = _entry.partition(":")
            if not _sep:
                raise ValueError("GRADIO_AUTH entries must look like 'user:password'")
            _auth.append((_user.strip(), _password))
    else:
        _auth = _default_auth

    interface.launch(auth=_auth)