File size: 5,834 Bytes
64d14f7
 
 
49a2c57
 
 
64d14f7
 
0c48d27
34ecf7a
0c48d27
 
fb6d3d4
0c48d27
 
49a2c57
 
0c48d27
f147a22
 
 
 
 
2759d0e
 
49a2c57
 
 
 
 
 
 
 
2759d0e
49a2c57
 
 
 
0c48d27
 
 
49a2c57
0c48d27
49a2c57
0c48d27
49a2c57
0c48d27
 
 
f147a22
 
0c48d27
 
d7e6465
 
 
 
2759d0e
dad1fea
0c48d27
2759d0e
 
49a2c57
0c48d27
2759d0e
49a2c57
7735f05
49a2c57
 
0c48d27
7735f05
 
 
 
 
 
 
 
49a2c57
0c48d27
 
49a2c57
2759d0e
7735f05
 
f147a22
7735f05
 
 
 
 
f147a22
7735f05
 
 
 
 
0c48d27
 
f147a22
 
 
fd81c83
5c16577
f147a22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# Import the libraries
import os

import numpy as np
import pandas as pd
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.applications.convnext import preprocess_input
import gradio as gr

# Trained Keras classifier, loaded once at import time and reused by every prediction.
model = load_model('models/ConvNeXtBase_80_tresh_spp.tf')

# Taxonomy lookup table. Underscores in species names are turned into spaces so
# the names read naturally when displayed and when matched against each other.
taxo_df = (
    pd.read_csv('taxonomy/taxonomy_mapping.csv')
    .assign(species=lambda frame: frame['species'].str.replace('_', ' ', regex=False))
)

# Ranks offered in the UI, from most to least specific.
taxonomic_levels = ['species', 'genus', 'family', 'order', 'class']

# Function to map predicted class index to class name at the selected taxonomic level
def get_class_name(predicted_class, taxonomic_level):
    unique_labels = sorted(taxo_df[taxonomic_level].unique())
    return unique_labels[predicted_class]

# Function to aggregate predictions to a higher taxonomic level
def aggregate_predictions(predicted_probs, taxonomic_level, class_names, taxonomy_df=None):
    """Sum per-species scores into scores for the labels of ``taxonomic_level``.

    Args:
        predicted_probs: 2-D array of scores, one row per image and one column
            per species, with columns in the same order as ``class_names``.
        taxonomic_level: Name of the taxonomy column to aggregate to, e.g.
            ``'genus'``. ``'species'`` keeps one score per species.
        class_names: Species names in the column order of ``predicted_probs``.
        taxonomy_df: Taxonomy table with a ``'species'`` column and a
            ``taxonomic_level`` column. Defaults to the module-level ``taxo_df``.

    Returns:
        ``(aggregated_predictions, unique_labels)``. ``unique_labels`` is the
        sorted list of distinct values of ``taxonomic_level``, and
        ``aggregated_predictions[:, j]`` is the summed score of every species
        mapped to ``unique_labels[j]``.

    Raises:
        ValueError: if a species listed in the taxonomy table is not in
            ``class_names``.
    """
    if taxonomy_df is None:
        taxonomy_df = taxo_df
    predicted_probs = np.asarray(predicted_probs)

    unique_labels = sorted(taxonomy_df[taxonomic_level].unique())
    label_index = {label: i for i, label in enumerate(unique_labels)}

    # First occurrence wins, matching list.index semantics for duplicate names.
    species_index = {}
    for i, name in enumerate(class_names):
        species_index.setdefault(name, i)

    # Count each distinct (species, label) pair once: a species repeated on
    # several identical rows must not have its score added more than once.
    pairs = dict.fromkeys(zip(taxonomy_df['species'], taxonomy_df[taxonomic_level]))

    aggregated_predictions = np.zeros((predicted_probs.shape[0], len(unique_labels)))
    for species, higher_level in pairs:
        if species not in species_index:
            raise ValueError(f"{species!r} is not in class_names")
        aggregated_predictions[:, label_index[higher_level]] += (
            predicted_probs[:, species_index[species]]
        )

    return aggregated_predictions, unique_labels

# Function to load and preprocess the image
def load_and_preprocess_image(image, target_size=(224, 224)):
    """Turn a PIL image into a one-image batch ready for ``model.predict``.

    Args:
        image: PIL image to classify.
        target_size: ``(width, height)`` passed to ``PIL.Image.resize``.

    Returns:
        Array of shape ``(1, height, width, 3)`` after ConvNeXt ``preprocess_input``.
    """
    # Grayscale, palette or RGBA uploads (e.g. PNGs with transparency) would
    # yield the wrong number of channels; the model presumably expects 3-channel
    # RGB input — TODO confirm against the model's input shape.
    if image.mode != 'RGB':
        image = image.convert('RGB')
    # Resize and convert to an (height, width, channels) float array.
    img_array = img_to_array(image.resize(target_size))
    # Add the leading batch dimension the model expects.
    img_array = np.expand_dims(img_array, axis=0)
    # Apply the preprocessing that matches the ConvNeXt backbone.
    img_array = preprocess_input(img_array)
    return img_array

# Function to make predictions
def make_prediction(image, taxonomic_level):
    """Classify an image and render the result as an HTML snippet for Gradio.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the form is submitted without an image.
        taxonomic_level: One of ``taxonomic_levels``. Per-species model scores
            are summed up to this level before ranking.

    Returns:
        HTML containing the top-ranked label as a heading, followed by the five
        highest-scoring labels with their scores as percentages. At species
        level the ``common_name`` from the taxonomy table is shown as well.
    """
    # Without an image, preprocessing would fail on ``None.resize``.
    if image is None:
        return "<p>Please upload an image to classify.</p>"

    img_array = load_and_preprocess_image(image)

    # NOTE(review): assumes the model's output columns follow the alphabetical
    # order of species names — confirm against the training label order.
    class_names = sorted(taxo_df['species'].unique())

    prediction = model.predict(img_array)
    aggregated_predictions, aggregated_class_labels = aggregate_predictions(
        prediction, taxonomic_level, class_names
    )

    # Scores for the single input image and the indices of its five best labels,
    # best first.
    scores = aggregated_predictions[0]
    top_indices = np.argsort(scores)[-5:][::-1]

    show_common_names = taxonomic_level == "species"

    # Take the headline from the same ranking as the list so the two always agree,
    # even when several labels tie for the highest score.
    predicted_class_name = aggregated_class_labels[top_indices[0]]
    if show_common_names:
        predicted_common_name = _common_name(predicted_class_name)
        output_text = (
            "<h1 style='font-weight: bold;'><span style='font-style: italic;'>"
            f"{predicted_class_name}</span> ({predicted_common_name})</h1>"
        )
    else:
        output_text = f"<h1 style='font-weight: bold;'>{predicted_class_name}</h1>"

    output_text += "<h4 style='font-weight: bold; font-size: 1.2em;'>Top 5 Predictions:</h4>"
    output_text += "".join(
        _format_prediction_row(aggregated_class_labels[i], scores[i] * 100, show_common_names)
        for i in top_indices
    )
    return output_text


def _common_name(species):
    """Return the ``common_name`` of the first taxonomy row for ``species``."""
    return taxo_df[taxo_df['species'] == species]['common_name'].values[0]


def _format_prediction_row(class_name, confidence_percentage, show_common_name):
    """Render one top-5 entry as a flex row: label on the left, score on the right."""
    if show_common_name:
        # Scientific name in italics, common name in parentheses.
        label = (
            f"<span style='font-style: italic;'>{class_name}</span>"
            f"&nbsp;(<span>{_common_name(class_name)}</span>)"
        )
    else:
        label = f"<span>{class_name}</span>"
    return (
        "<div style='display: flex; justify-content: space-between;'>"
        f"{label}"
        f"<span style='margin-left: auto;'>{confidence_percentage:.2f}%</span></div>"
    )

# Greeting shown at the top of the page for the logged-in user.
def update_message(request: gr.Request):
    """Return a personalised welcome line built from ``request.username``.

    Gradio fills in ``request`` because of the ``gr.Request`` annotation.
    """
    greeting = "Welcome to the demo, Dr. {}!"
    return greeting.format(request.username)

# Build the Gradio app: a welcome banner filled in when the page loads, plus the
# image-classification interface.
with gr.Blocks() as demo:
    # Markdown placeholder; its text is produced by update_message on page load.
    welcome_message = gr.Markdown()
    
    # On page load, call update_message (no input components) and write its
    # return value into welcome_message.
    demo.load(update_message, None, welcome_message)

    # Prediction interface: the image and the selected taxonomic level are passed
    # to make_prediction, and its HTML string is shown as the output.
    interface = gr.Interface(
        fn=make_prediction,            # Called as make_prediction(image, taxonomic_level)
        inputs=[gr.Image(type="pil"),   # Uploaded image delivered as a PIL image
                gr.Dropdown(choices=taxonomic_levels, label="Taxonomic level", value="species")],  # Preselects the species level
        outputs="html",                # make_prediction returns an HTML snippet
        title="Amazon arboreal species classification",
        description="Upload an image and select the taxonomic level to classify the species."
    )

    # Render the Interface inside this Blocks layout.
    interface.render()

# Launch the Gradio app behind a username/password login.
def _load_auth_users():
    """Return the ``(username, password)`` pairs allowed to log in.

    Pairs are read from the ``GRADIO_AUTH_USERS`` environment variable, formatted
    as ``"user1:password1;user2:password2"``. Whitespace around each entry is
    stripped, and entries without a ``:`` or with an empty username are ignored.
    If the variable yields no valid pair, the built-in fallback list is used.
    """
    users = []
    for entry in os.environ.get("GRADIO_AUTH_USERS", "").split(";"):
        user, sep, password = entry.strip().partition(":")
        if sep and user:
            users.append((user, password))
    if users:
        return users
    # NOTE(review): plaintext credentials committed to source control. Move them
    # into the GRADIO_AUTH_USERS secret and delete this fallback once it is set.
    return [
        ("Luca Santini", "lucasantini"),
        ("Ana Benítez López", "anabenítezlópez"),
    ]


demo.launch(auth=_load_auth_users())