Brice Vandeputte commited on
Commit
ffa34e1
1 Parent(s): e04fcaa

add new method api_classification

Browse files
Files changed (1) hide show
  1. app.py +35 -0
app.py CHANGED
@@ -153,6 +153,41 @@ def open_domain_classification(img, rank: int) -> dict[str, float]:
153
  return {name: output[name] for name in topk_names}
154
 
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  def change_output(choice):
157
  return gr.Label(num_top_classes=k, label=ranks[choice], show_label=True, value=None)
158
 
 
153
  return {name: output[name] for name in topk_names}
154
 
155
 
156
+
157
+ @torch.no_grad()
158
+ def api_classification(img, rank: int) -> dict[str, float]:
159
+ """
160
+ Predicts from the entire tree of life.
161
+ If targeting a higher rank than species, then this function predicts among all
162
+ species, then sums up species-level probabilities for the given rank.
163
+ """
164
+ img = preprocess_img(img).to(device)
165
+ img_features = model.encode_image(img.unsqueeze(0))
166
+ img_features = F.normalize(img_features, dim=-1)
167
+
168
+ logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze()
169
+ probs = F.softmax(logits, dim=0)
170
+
171
+ # If predicting species, no need to sum probabilities.
172
+ if rank + 1 == len(ranks):
173
+ topk = probs.topk(k)
174
+ return {
175
+ format_name(*txt_names[i]): prob for i, prob in zip(topk.indices, topk.values)
176
+ }
177
+
178
+ # Sum up by the rank
179
+ output = collections.defaultdict(float)
180
+ for i in torch.nonzero(probs > min_prob).squeeze():
181
+ output[" ".join(txt_names[i][0][: rank + 1])] += probs[i]
182
+
183
+ logger.info(">>>>")
184
+ logger.info(probs[0])
185
+
186
+ topk_names = heapq.nlargest(k, output, key=output.get)
187
+
188
+ return {name: output[name] for name in topk_names}
189
+
190
+
191
  def change_output(choice):
192
  return gr.Label(num_top_classes=k, label=ranks[choice], show_label=True, value=None)
193