andrewzamp committed
Commit 49a2c57 · Parent(s): fdf94ed

Update app.py

Files changed (1):
  1. app.py (+54 -33)
app.py CHANGED
@@ -1,9 +1,9 @@
 # Import the libraries
 import numpy as np
 import pandas as pd
-from tensorflow.keras.models import load_model # type: ignore
-from tensorflow.keras.preprocessing.image import load_img, img_to_array # type: ignore
-from tensorflow.keras.applications.convnext import preprocess_input # type: ignore
+from tensorflow.keras.models import load_model
+from tensorflow.keras.preprocessing.image import load_img, img_to_array
+from tensorflow.keras.applications.convnext import preprocess_input
 import gradio as gr
 
 # Load the model
@@ -13,62 +13,83 @@ model = load_model('models/ConvNeXtBase_80_tresh_spp.tf')
 taxo_df = pd.read_csv('taxonomy/taxonomy_mapping.csv', sep=';')
 taxo_df['species'] = taxo_df['species'].str.replace('_', ' ')
 
-# Extract unique class names from the 'species' column
-class_names = sorted(taxo_df['species'].unique())
+# Available taxonomic levels
+taxonomic_levels = ['species', 'genus', 'family', 'order', 'class']
 
-# Function to map predicted class index to class name
-def get_class_name(predicted_class):
-    return class_names[predicted_class]
+# Function to map predicted class index to class name at the selected taxonomic level
+def get_class_name(predicted_class, taxonomic_level):
+    unique_labels = sorted(taxo_df[taxonomic_level].unique())
+    return unique_labels[predicted_class]
+
+# Function to aggregate predictions to the specified taxonomic level
+def aggregate_predictions(predicted_probs, taxonomic_level):
+    unique_labels = sorted(taxo_df[taxonomic_level].unique())
+    aggregated_predictions = np.zeros((predicted_probs.shape[0], len(unique_labels)))
+
+    for idx, row in taxo_df.iterrows():
+        species = row['species']
+        higher_level = row[taxonomic_level]
+
+        species_index = class_names.index(species) # Index of the species in the prediction array
+        higher_level_index = unique_labels.index(higher_level) # Index of the higher taxonomic level
+
+        aggregated_predictions[:, higher_level_index] += predicted_probs[:, species_index]
+
+    return aggregated_predictions, unique_labels
 
 # Function to load and preprocess the image
 def load_and_preprocess_image(image, target_size=(224, 224)):
-    # Resize the image (assuming image is a PIL image)
+    # Resize the image
     img_array = img_to_array(image.resize(target_size))
-    # Expand the dimensions of the array to match model input
+    # Expand the dimensions to match model input
     img_array = np.expand_dims(img_array, axis=0)
-    # Preprocess using the appropriate function (for example, ResNet50)
+    # Preprocess the image
     img_array = preprocess_input(img_array)
     return img_array
 
-# Function to make predictions
-def make_prediction(image):
+# Function to make predictions at the selected taxonomic level
+def make_prediction(image, taxonomic_level):
     # Preprocess the image
     img_array = load_and_preprocess_image(image)
-    # Make a prediction
+    # Make a prediction at the species level
    prediction = model.predict(img_array)
 
+    # Aggregate predictions to the selected taxonomic level
+    aggregated_predictions, aggregated_class_labels = aggregate_predictions(prediction, taxonomic_level)
+
     # Get the top 5 predictions
-    top_indices = np.argsort(prediction[0])[-5:][::-1] # Get indices of top 5 classes
+    top_indices = np.argsort(aggregated_predictions[0])[-5:][::-1] # Indices of top 5 predictions
+
+    # Get the predicted class and common name for the top prediction
+    predicted_class_index = np.argmax(aggregated_predictions)
+    predicted_class_name = aggregated_class_labels[predicted_class_index]
 
-    # Get predicted class and common name for the top prediction
-    predicted_class_index = np.argmax(prediction)
-    predicted_class_name = get_class_name(predicted_class_index)
-    predicted_common_name = taxo_df[taxo_df['species'] == predicted_class_name]['common_name'].values[0] # Get common name
-    confidence = prediction[0][predicted_class_index] * 100 # Confidence of the predicted class
+    # Get common name for the top predicted class
+    predicted_common_name = taxo_df[taxo_df[taxonomic_level] == predicted_class_name]['common_name'].values[0]
+    confidence = aggregated_predictions[0][predicted_class_index] * 100 # Confidence of the predicted class
 
     # Create output text with HTML formatting
-    output_text = f"<h1 style='font-weight: bold;'><span style='font-style: italic;'>{predicted_class_name}</span> ({predicted_common_name})</h1>" # Large bold for predicted class, italic for class name
-    output_text += "<h4 style='font-weight: bold; font-size: 1.2em;'>Top 5 Predictions:</h4>" # Bold and larger font for predictions
+    output_text = f"<h1 style='font-weight: bold;'><span style='font-style: italic;'>{predicted_class_name}</span> ({predicted_common_name})</h1>"
+    output_text += "<h4 style='font-weight: bold; font-size: 1.2em;'>Top 5 Predictions:</h4>"
 
     for i in top_indices:
-        class_name = get_class_name(i)
-        common_name = taxo_df[taxo_df['species'] == class_name]['common_name'].values[0] # Get common name from CSV
-        confidence_percentage = prediction[0][i] * 100
-
-        # Format the output with space between class name and common name
+        class_name = aggregated_class_labels[i]
+        common_name = taxo_df[taxo_df[taxonomic_level] == class_name]['common_name'].values[0]
+        confidence_percentage = aggregated_predictions[0][i] * 100
         output_text += f"<div style='display: flex; justify-content: space-between;'>" \
                        f"<span style='font-style: italic;'>{class_name}</span>&nbsp;(<span>{common_name}</span>)" \
                        f"<span style='margin-left: auto;'>{confidence_percentage:.2f}%</span></div>"
-
+
     return output_text
 
 # Define the Gradio interface
 interface = gr.Interface(
-    fn=make_prediction, # Function to be called for predictions
-    inputs=gr.Image(type="pil"), # Input type: Image (PIL format)
-    outputs="html", # Output type: HTML for formatting
-    title="Amazon arboreal species classification",
-    description="Upload an image to classify the species."
+    fn=make_prediction, # Function to be called for predictions
+    inputs=[gr.Image(type="pil"), # Input type: Image (PIL format)
+            gr.Dropdown(choices=taxonomic_levels, label="Taxonomic Level", default="species")], # Dropdown for taxonomic level
+    outputs="html", # Output type: HTML for formatting
+    title="Amazon Arboreal Species Classification",
+    description="Upload an image and select the taxonomic level to classify the species."
 )
 
 # Launch the Gradio interface
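
One caveat on the committed code: `aggregate_predictions` still resolves `species_index` through `class_names`, but the `class_names = sorted(taxo_df['species'].unique())` line is deleted in this same commit, so the function would raise a NameError at prediction time. Below is a minimal sketch of one way to restore that lookup; it assumes, as the old code did, that the model's output columns follow the alphabetically sorted species names, and `species_to_index` is a hypothetical helper rather than anything in the commit. (Recent Gradio releases also take `value=` rather than `default=` for a Dropdown's initial selection.)

# Sketch only (assumption): rebuild the species -> model-output-index mapping
# that the removed `class_names` list used to provide.
class_names = sorted(taxo_df['species'].unique())
species_to_index = {name: i for i, name in enumerate(class_names)}  # hypothetical helper

def aggregate_predictions(predicted_probs, taxonomic_level):
    # Sum species-level probabilities into the chosen higher taxonomic level
    unique_labels = sorted(taxo_df[taxonomic_level].unique())
    aggregated = np.zeros((predicted_probs.shape[0], len(unique_labels)))
    for _, row in taxo_df.iterrows():
        species_index = species_to_index[row['species']]                # column in the model output
        higher_level_index = unique_labels.index(row[taxonomic_level])  # column at the chosen level
        aggregated[:, higher_level_index] += predicted_probs[:, species_index]
    return aggregated, unique_labels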