Refactor prediction function: return top 3 class probabilities with corresponding class names
Browse files
app.py
CHANGED
@@ -69,10 +69,15 @@ def predict(data):
|
|
69 |
|
70 |
preds = preds[0]
|
71 |
print(preds)
|
|
|
|
|
|
|
72 |
|
|
|
|
|
|
|
73 |
|
74 |
-
|
75 |
-
return {label: float(pred) for label, pred in zip(labels, preds)}
|
76 |
|
77 |
# Top 3 classes
|
78 |
label = gr.Label(num_top_classes=3)
|
|
|
69 |
|
70 |
preds = preds[0]
|
71 |
print(preds)
|
72 |
+
|
73 |
+
top_3_classes = np.argsort(preds)[-3:][::-1]
|
74 |
+
top_3_probs = preds[top_3_classes]
|
75 |
|
76 |
+
class_names = [labels[i] for i in top_3_classes]
|
77 |
+
|
78 |
+
print(class_names, top_3_probs, top_3_classes)
|
79 |
|
80 |
+
return {class_names[i]: top_3_probs[i] for i in range(3)}
|
|
|
81 |
|
82 |
# Top 3 classes
|
83 |
label = gr.Label(num_top_classes=3)
|