alibayram commited on
Commit
d6bce58
·
1 Parent(s): 363a74f

Refactor prediction function: return top 3 class probabilities with corresponding class names

Browse files
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -69,10 +69,15 @@ def predict(data):
69
 
70
  preds = preds[0]
71
  print(preds)
 
 
 
72
 
 
 
 
73
 
74
- # Return the probability for each class
75
- return {label: float(pred) for label, pred in zip(labels, preds)}
76
 
77
  # Top 3 classes
78
  label = gr.Label(num_top_classes=3)
 
69
 
70
  preds = preds[0]
71
  print(preds)
72
+
73
+ top_3_classes = np.argsort(preds)[-3:][::-1]
74
+ top_3_probs = preds[top_3_classes]
75
 
76
+ class_names = [labels[i] for i in top_3_classes]
77
+
78
+ print(class_names, top_3_probs, top_3_classes)
79
 
80
+ return {class_names[i]: top_3_probs[i] for i in range(3)}
 
81
 
82
  # Top 3 classes
83
  label = gr.Label(num_top_classes=3)