Spaces:
Sleeping
Sleeping
Brice Vandeputte
commited on
Commit
•
ffa34e1
1
Parent(s):
e04fcaa
add new method api_classification
Browse files
app.py
CHANGED
@@ -153,6 +153,41 @@ def open_domain_classification(img, rank: int) -> dict[str, float]:
|
|
153 |
return {name: output[name] for name in topk_names}
|
154 |
|
155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
def change_output(choice):
|
157 |
return gr.Label(num_top_classes=k, label=ranks[choice], show_label=True, value=None)
|
158 |
|
|
|
153 |
return {name: output[name] for name in topk_names}
|
154 |
|
155 |
|
156 |
+
|
157 |
+
@torch.no_grad()
|
158 |
+
def api_classification(img, rank: int) -> dict[str, float]:
|
159 |
+
"""
|
160 |
+
Predicts from the entire tree of life.
|
161 |
+
If targeting a higher rank than species, then this function predicts among all
|
162 |
+
species, then sums up species-level probabilities for the given rank.
|
163 |
+
"""
|
164 |
+
img = preprocess_img(img).to(device)
|
165 |
+
img_features = model.encode_image(img.unsqueeze(0))
|
166 |
+
img_features = F.normalize(img_features, dim=-1)
|
167 |
+
|
168 |
+
logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze()
|
169 |
+
probs = F.softmax(logits, dim=0)
|
170 |
+
|
171 |
+
# If predicting species, no need to sum probabilities.
|
172 |
+
if rank + 1 == len(ranks):
|
173 |
+
topk = probs.topk(k)
|
174 |
+
return {
|
175 |
+
format_name(*txt_names[i]): prob for i, prob in zip(topk.indices, topk.values)
|
176 |
+
}
|
177 |
+
|
178 |
+
# Sum up by the rank
|
179 |
+
output = collections.defaultdict(float)
|
180 |
+
for i in torch.nonzero(probs > min_prob).squeeze():
|
181 |
+
output[" ".join(txt_names[i][0][: rank + 1])] += probs[i]
|
182 |
+
|
183 |
+
logger.info(">>>>")
|
184 |
+
logger.info(probs[0])
|
185 |
+
|
186 |
+
topk_names = heapq.nlargest(k, output, key=output.get)
|
187 |
+
|
188 |
+
return {name: output[name] for name in topk_names}
|
189 |
+
|
190 |
+
|
191 |
def change_output(choice):
|
192 |
return gr.Label(num_top_classes=k, label=ranks[choice], show_label=True, value=None)
|
193 |
|