CesarLeblanc commited on
Commit
4e59324
1 Parent(s): 142304a
Files changed (1) hide show
  1. app.py +48 -34
app.py CHANGED
@@ -79,37 +79,50 @@ def classification(text, k):
79
  image_output = return_habitat_image(habitat_labels[0])
80
  return text, image_output
81
 
82
- def masking(text):
83
  text = gbif_normalization(text)
84
  text_split = text.split(', ')
85
 
86
- max_score = 0
87
- best_prediction = None
88
- best_position = None
89
- best_sentence = None
 
 
 
 
 
90
 
91
- for i in range(len(text_split) + 1):
92
- masked_text = ', '.join(text_split[:i] + ['[MASK]'] + text_split[i:])
93
-
94
- j = 0
95
- while True:
96
- prediction = mask_model(masked_text)[j]
97
- species = prediction['token_str']
98
- if species in text_split:
99
- j += 1
100
- else:
101
- break
102
 
103
- score = prediction['score']
104
- sentence = prediction['sequence']
105
 
106
- if score > max_score:
107
- max_score = score
108
- best_prediction = species
109
- best_position = i
110
- best_sentence = sentence
111
-
112
- text = f"The most likely missing species is {best_prediction} (position {best_position}).\nThe new vegetation plot is {best_sentence}."
 
 
 
 
 
 
 
 
113
  image = return_species_image(best_prediction)
114
  return text, image
115
 
@@ -122,26 +135,27 @@ with gr.Blocks() as demo:
122
  with gr.Row():
123
  with gr.Column():
124
  species_classification = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
125
- k_classification = gr.Slider(1, 5, value=1, label="Top-k", info="Choose the number of top habitats to display.")
126
  with gr.Column():
127
- text_output_1 = gr.Textbox()
128
- text_output_2 = gr.Image()
129
  button_classification = gr.Button("Classify")
130
  gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
131
- gr.Examples([["sparganium erectum, calystegia sepium, persicaria amphibia", 1]], [species_classification, k_classification], [text_output_1, text_output_2], classification, True)
132
 
133
  with gr.Tab("Missing species finding"):
134
  gr.Markdown("""<h3 style="text-align: center;">Finding the missing species!</h3>""")
135
  with gr.Row():
136
  species_masking = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
 
137
  with gr.Column():
138
- image_output_1 = gr.Textbox()
139
- image_output_2 = gr.Image()
140
  button_masking = gr.Button("Find")
141
  gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
142
- gr.Examples([["vaccinium myrtillus, dryopteris dilatata, molinia caerulea"]], [species_masking], [image_output_1, image_output_2], masking, True)
143
 
144
- button_classification.click(classification, inputs=[species_classification, k_classification], outputs=[text_output_1, text_output_2])
145
- button_masking.click(masking, inputs=[species_masking], outputs=[image_output_1, image_output_2])
146
 
147
  demo.launch()
 
79
  image_output = return_habitat_image(habitat_labels[0])
80
  return text, image_output
81
 
82
+ def masking(text, k):
83
  text = gbif_normalization(text)
84
  text_split = text.split(', ')
85
 
86
+ best_predictions = []
87
+ best_positions = []
88
+ best_sentences = []
89
+
90
+ for _ in range(k):
91
+ max_score = 0
92
+ best_prediction = None
93
+ best_position = None
94
+ best_sentence = None
95
 
96
+ for i in range(len(text_split) + 1):
97
+ masked_text = ', '.join(text_split[:i] + ['[MASK]'] + text_split[i:])
98
+
99
+ j = 0
100
+ while True:
101
+ prediction = mask_model(masked_text)[j]
102
+ species = prediction['token_str']
103
+ if species in text_split or species in best_predictions:
104
+ j += 1
105
+ else:
106
+ break
107
 
108
+ score = prediction['score']
109
+ sentence = prediction['sequence']
110
 
111
+ if score > max_score:
112
+ max_score = score
113
+ best_prediction = species
114
+ best_position = i
115
+ best_sentence = sentence
116
+
117
+ best_predictions.append(best_prediction)
118
+ best_positions.append(best_position)
119
+ best_sentences.append(best_sentence)
120
+ text_split.insert(best_position, best_prediction)
121
+ if k == 1:
122
+ text = f"The most likely missing species is {best_predictions[0]} (position {best_positions[0]})."
123
+ else:
124
+ text = f"The most likely missing species are {', '.join(best_predictions[:-1])} and {best_predictions[-1]} (positions {', '.join(map(str, best_positions[:-1]))} and {best_positions[-1]})."
125
+ text += f"\nThe new vegetation plot is {best_sentences[-1]}. (see image of the most likely species below)."
126
  image = return_species_image(best_prediction)
127
  return text, image
128
 
 
135
  with gr.Row():
136
  with gr.Column():
137
  species_classification = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
138
+ k_classification = gr.Slider(1, 5, value=1, step=1, label="Top-k", info="Choose the number of top habitats to display.")
139
  with gr.Column():
140
+ text_classification = gr.Textbox()
141
+ image_classification = gr.Image()
142
  button_classification = gr.Button("Classify")
143
  gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
144
+ gr.Examples([["sparganium erectum, calystegia sepium, persicaria amphibia", 1]], [species_classification, k_classification], [text_classification, image_classification], classification, True)
145
 
146
  with gr.Tab("Missing species finding"):
147
  gr.Markdown("""<h3 style="text-align: center;">Finding the missing species!</h3>""")
148
  with gr.Row():
149
  species_masking = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
150
+ k_masking = gr.Slider(1, 5, value=1, step=1, label="Top-k", info="Choose the number of top missing species to find.")
151
  with gr.Column():
152
+ text_masking = gr.Textbox()
153
+ image_masking = gr.Image()
154
  button_masking = gr.Button("Find")
155
  gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
156
+ gr.Examples([["vaccinium myrtillus, dryopteris dilatata, molinia caerulea", 1]], [species_masking, k_masking], [text_masking, image_masking], masking, True)
157
 
158
+ button_classification.click(classification, inputs=[species_classification, k_classification], outputs=[textclassification, image_classification])
159
+ button_masking.click(masking, inputs=[species_masking, k_masking], outputs=[text_masking, image_masking])
160
 
161
  demo.launch()