xcurvnubaim commited on
Commit
1f96cd2
·
1 Parent(s): fba92a9

feat: return top 3

Browse files
Files changed (1) hide show
  1. main.py +4 -1
main.py CHANGED
@@ -19,7 +19,10 @@ def classify_image(img):
19
  img_array = tf.keras.applications.efficientnet.preprocess_input(img_array)
20
  prediction = model.predict(img_array).flatten()
21
  confidences = {labels[i]: float(prediction[i]) for i in range(90)}
22
- return confidences
 
 
 
23
 
24
  @app.post("/predict")
25
  async def predict(file: bytes = File(...)):
 
19
  img_array = tf.keras.applications.efficientnet.preprocess_input(img_array)
20
  prediction = model.predict(img_array).flatten()
21
  confidences = {labels[i]: float(prediction[i]) for i in range(90)}
22
+ # Sort the confidences dictionary by value and get the top 3 items
23
+ top_3_confidences = dict(sorted(confidences.items(), key=lambda item: item[1], reverse=True)[:3])
24
+
25
+ return top_3_confidences
26
 
27
  @app.post("/predict")
28
  async def predict(file: bytes = File(...)):