# Spaces: Sleeping (Hugging Face Space status banner captured along with this file)
# Import the libraries
from pathlib import Path

import gradio as gr
import numpy as np
import pandas as pd
from tensorflow.keras.applications.convnext import preprocess_input  # type: ignore
from tensorflow.keras.layers import TFSMLayer  # type: ignore
from tensorflow.keras.preprocessing.image import load_img, img_to_array  # type: ignore
# Resolve bundled data files relative to this script instead of the current
# working directory, so the app also starts when launched from another folder.
_APP_DIR = Path(__file__).resolve().parent

# Load the model: a TensorFlow SavedModel wrapped as a Keras layer whose
# 'serving_default' signature is what gets called at inference time.
model = TFSMLayer(
    str(_APP_DIR / 'models' / 'ConvNeXtBase_80_tresh_spp.tf'),
    call_endpoint='serving_default',
)

# Load the taxonomy table (semicolon-separated CSV) and show species names
# with spaces instead of underscores.
taxo_df = pd.read_csv(_APP_DIR / 'taxonomy' / 'taxonomy_mapping.csv', sep=';')
taxo_df['species'] = taxo_df['species'].str.replace('_', ' ')

# Extract unique class names from the 'species' column, sorted alphabetically.
# NOTE(review): model output index i is mapped to the i-th entry of this list,
# which assumes the classes were ordered the same way at training time
# (TODO confirm). Sorting happens *after* '_' -> ' ', which can reorder names
# where the separator sits next to digits, punctuation or uppercase letters.
class_names = sorted(taxo_df['species'].unique())
# Function to map predicted class index to class name
def get_class_name(predicted_class):
    """Look up the species name for a model output index.

    Args:
        predicted_class: Position in the module-level ``class_names`` list
            (an int or NumPy integer, e.g. from ``np.argmax``/``np.argsort``).

    Returns:
        The species name stored at that position.
    """
    species_name = class_names[predicted_class]
    return species_name
# Function to load and preprocess the image
def load_and_preprocess_image(image, target_size=(224, 224)):
    """Convert a PIL image into a single-image batch ready for the model.

    Args:
        image: PIL image (as delivered by ``gr.Image(type="pil")``).
        target_size: Size handed to ``PIL.Image.resize``, i.e.
            ``(width, height)``; defaults to 224x224.

    Returns:
        A batch of one 3-channel image (shape ``(1, height, width, 3)``
        assuming Keras' default channels_last format) after ConvNeXt
        preprocessing.
    """
    # Force 3-channel RGB first: RGBA, palette or grayscale images would
    # otherwise yield 4- or 1-channel arrays the network cannot consume.
    rgb_image = image.convert("RGB")
    img_array = img_to_array(rgb_image.resize(target_size))
    # Add a leading batch axis: the model is called on a batch of images.
    batch = np.expand_dims(img_array, axis=0)
    # ConvNeXt-specific preprocessing, matching the ConvNeXtBase backbone
    # named in the model path.
    return preprocess_input(batch)
def _first_output_row(raw_output):
    """Return the score vector of the first (only) image in the batch.

    Calling a TFSMLayer bound to a ``serving_default`` signature may return a
    dict mapping output names to tensors rather than a bare tensor; both
    forms are accepted here.
    """
    if isinstance(raw_output, dict):
        # NOTE(review): assumes the SavedModel exposes a single output.
        raw_output = next(iter(raw_output.values()))
    return np.asarray(raw_output)[0]


def _common_name_for(species):
    """Return the ``common_name`` of the first taxonomy row for *species*."""
    matches = taxo_df.loc[taxo_df['species'] == species, 'common_name']
    return matches.iloc[0]


# Function to make predictions
def make_prediction(image):
    """Classify an uploaded image and render the result as HTML.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the form is submitted without an upload.

    Returns:
        An HTML string with the top-scoring species (italic) and its common
        name as a heading, followed by the five highest-scoring species and
        their scores multiplied by 100 and shown as percentages.
    """
    if image is None:
        # Nothing uploaded: return a hint instead of crashing in preprocessing.
        return "<p>Please upload an image to classify.</p>"

    img_array = load_and_preprocess_image(image)
    # `model` is a TFSMLayer, i.e. a Keras *layer*: it has no `.predict()`
    # method, so it is invoked directly on the batch.
    scores = _first_output_row(model(img_array))

    # Indices of the 5 best classes, highest score first.
    top_indices = np.argsort(scores)[-5:][::-1]

    # Headline: the single best class and its common name.
    predicted_class_index = int(np.argmax(scores))
    predicted_class_name = get_class_name(predicted_class_index)
    predicted_common_name = _common_name_for(predicted_class_name)

    parts = [
        "<h1 style='font-weight: bold;'>"
        f"<span style='font-style: italic;'>{predicted_class_name}</span>"
        f" ({predicted_common_name})</h1>",
        "<h4 style='font-weight: bold; font-size: 1.2em;'>Top 5 Predictions:</h4>",
    ]
    for i in top_indices:
        class_name = get_class_name(i)
        common_name = _common_name_for(class_name)
        confidence_percentage = scores[i] * 100
        # One flex row per class: names on the left, score pushed right.
        parts.append(
            "<div style='display: flex; justify-content: space-between;'>"
            f"<span style='font-style: italic;'>{class_name}</span> (<span>{common_name}</span>)"
            f"<span style='margin-left: auto;'>{confidence_percentage:.2f}%</span></div>"
        )
    return "".join(parts)
# Define the Gradio interface
interface = gr.Interface(
    fn=make_prediction,  # called with the uploaded image on every submit
    inputs=gr.Image(type="pil"),  # deliver uploads as PIL images
    outputs="html",  # make_prediction returns an HTML fragment
    title="Amazon arboreal species classification",
    description="Upload an image to classify the species.",
)

# Launch the Gradio app only when this file is run as a script, so importing
# the module (e.g. for tests or reuse) does not start a web server.
if __name__ == "__main__":
    interface.launch()