andrewzamp commited on
Commit
d33d3aa
·
1 Parent(s): 9669b45

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -52
app.py CHANGED
@@ -30,19 +30,10 @@ def aggregate_predictions(predicted_probs, taxonomic_level, class_names):
30
  species = row['species']
31
  higher_level = row[taxonomic_level]
32
 
33
- if species in class_names: # Check if species exists in class names
34
- species_index = class_names.index(species) # Index of the species in the prediction array
35
-
36
- if higher_level in unique_labels: # Check if higher level exists
37
- higher_level_index = unique_labels.index(higher_level)
38
-
39
- # Only update if indices are valid
40
- if species_index < predicted_probs.shape[1] and higher_level_index < aggregated_predictions.shape[1]:
41
- if predicted_probs[:, species_index].max() >= 0.80: # Check confidence level
42
- aggregated_predictions[:, higher_level_index] += predicted_probs[:, species_index]
43
- else:
44
- # Stop aggregation at the current level if confidence is below 0.80
45
- break
46
 
47
  return aggregated_predictions, unique_labels
48
 
@@ -57,51 +48,26 @@ def load_and_preprocess_image(image, target_size=(224, 224)):
57
  return img_array
58
 
59
  # Function to make predictions
60
- def make_prediction(image, taxonomic_resolution_choice, taxonomic_level):
61
  # Preprocess the image
62
  img_array = load_and_preprocess_image(image)
63
-
64
  # Get the class names from the 'species' column
65
- class_names = sorted(taxo_df['species'].unique())
66
-
67
  # Make a prediction
68
  prediction = model.predict(img_array)
69
-
70
- # Initialize the current taxonomic level index based on the user selection
71
- current_taxonomic_level_index = taxonomic_levels.index(taxonomic_level)
72
-
73
- # If the user chose to let the model decide, check the confidence levels
74
- if taxonomic_resolution_choice == "No, I will let the model decide":
75
- aggregated_predictions = prediction
76
- while current_taxonomic_level_index < len(taxonomic_levels):
77
- # Aggregate predictions based on the current taxonomic level
78
- aggregated_predictions, aggregated_class_labels = aggregate_predictions(
79
- aggregated_predictions, taxonomic_levels[current_taxonomic_level_index], class_names
80
- )
81
-
82
- # Check if the max confidence in the aggregated predictions is >= 0.80
83
- if np.max(aggregated_predictions) >= 0.80:
84
- break
85
-
86
- # Move to the next higher taxonomic level
87
- current_taxonomic_level_index += 1
88
-
89
- # Ensure we don't go out of bounds
90
- if current_taxonomic_level_index < len(taxonomic_levels):
91
- taxonomic_level = taxonomic_levels[current_taxonomic_level_index]
92
- else:
93
- taxonomic_level = taxonomic_levels[-1] # fallback to the highest taxonomic level
94
- else:
95
- # Aggregate predictions based on the selected taxonomic level
96
- aggregated_predictions, aggregated_class_labels = aggregate_predictions(prediction, taxonomic_level, class_names)
97
-
98
  # Get the top 5 predictions
99
  top_indices = np.argsort(aggregated_predictions[0])[-5:][::-1]
100
-
101
  # Get predicted class for the top prediction
102
  predicted_class_index = np.argmax(aggregated_predictions)
103
  predicted_class_name = aggregated_class_labels[predicted_class_index]
104
-
105
  # Check if common name should be displayed (only at species level)
106
  if taxonomic_level == "species":
107
  predicted_common_name = taxo_df[taxo_df[taxonomic_level] == predicted_class_name]['common_name'].values[0]
@@ -111,10 +77,10 @@ def make_prediction(image, taxonomic_resolution_choice, taxonomic_level):
111
 
112
  # Add the top 5 predictions
113
  output_text += "<h4 style='font-weight: bold; font-size: 1.2em;'>Top 5 Predictions:</h4>"
114
-
115
  for i in top_indices:
116
  class_name = aggregated_class_labels[i]
117
-
118
  if taxonomic_level == "species":
119
  # Display common names only at species level and make it italic
120
  common_name = taxo_df[taxo_df[taxonomic_level] == class_name]['common_name'].values[0]
@@ -128,7 +94,7 @@ def make_prediction(image, taxonomic_resolution_choice, taxonomic_level):
128
  output_text += f"<div style='display: flex; justify-content: space-between;'>" \
129
  f"<span>{class_name}</span>" \
130
  f"<span style='margin-left: auto;'>{confidence_percentage:.2f}%</span></div>"
131
-
132
  return output_text
133
 
134
  # Define the Gradio interface
@@ -143,7 +109,7 @@ interface = gr.Interface(
143
  ],
144
  outputs="html", # Output type: HTML for formatting
145
  title="Amazon arboreal species classification",
146
- description="Upload an image and select the taxonomic level to classify the species."
147
  )
148
 
149
  # Launch the Gradio interface with authentication for the specified users
 
30
  species = row['species']
31
  higher_level = row[taxonomic_level]
32
 
33
+ species_index = class_names.index(species) # Index of the species in the prediction array
34
+ higher_level_index = unique_labels.index(higher_level)
35
+
36
+ aggregated_predictions[:, higher_level_index] += predicted_probs[:, species_index]
 
 
 
 
 
 
 
 
 
37
 
38
  return aggregated_predictions, unique_labels
39
 
 
48
  return img_array
49
 
50
  # Function to make predictions
51
+ def make_prediction(image, taxonomic_level):
52
  # Preprocess the image
53
  img_array = load_and_preprocess_image(image)
54
+
55
  # Get the class names from the 'species' column
56
+ class_names = sorted(taxo_df['species'].unique()) # Add this line to define class_names
57
+
58
  # Make a prediction
59
  prediction = model.predict(img_array)
60
+
61
+ # Aggregate predictions based on the selected taxonomic level
62
+ aggregated_predictions, aggregated_class_labels = aggregate_predictions(prediction, taxonomic_level, class_names)
63
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  # Get the top 5 predictions
65
  top_indices = np.argsort(aggregated_predictions[0])[-5:][::-1]
66
+
67
  # Get predicted class for the top prediction
68
  predicted_class_index = np.argmax(aggregated_predictions)
69
  predicted_class_name = aggregated_class_labels[predicted_class_index]
70
+
71
  # Check if common name should be displayed (only at species level)
72
  if taxonomic_level == "species":
73
  predicted_common_name = taxo_df[taxo_df[taxonomic_level] == predicted_class_name]['common_name'].values[0]
 
77
 
78
  # Add the top 5 predictions
79
  output_text += "<h4 style='font-weight: bold; font-size: 1.2em;'>Top 5 Predictions:</h4>"
80
+
81
  for i in top_indices:
82
  class_name = aggregated_class_labels[i]
83
+
84
  if taxonomic_level == "species":
85
  # Display common names only at species level and make it italic
86
  common_name = taxo_df[taxo_df[taxonomic_level] == class_name]['common_name'].values[0]
 
94
  output_text += f"<div style='display: flex; justify-content: space-between;'>" \
95
  f"<span>{class_name}</span>" \
96
  f"<span style='margin-left: auto;'>{confidence_percentage:.2f}%</span></div>"
97
+
98
  return output_text
99
 
100
  # Define the Gradio interface
 
109
  ],
110
  outputs="html", # Output type: HTML for formatting
111
  title="Amazon arboreal species classification",
112
+ description="Upload an image and select the taxonomic level (optional) to classify the species."
113
  )
114
 
115
  # Launch the Gradio interface with authentication for the specified users