0-ma commited on
Commit
6f08475
·
verified ·
1 Parent(s): 83d84bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -22,10 +22,11 @@ def predict(image):
22
  inputs = feature_extractor(images=[image], return_tensors="pt")
23
  logits = model(**inputs)['logits'].cpu().detach().numpy()
24
  predictions = np.argmax(logits, axis=1)
25
- predicted_labels = [labels[prediction] for prediction in predictions]
26
- print(predicted_labels[0],logits[0][predictions[0]])
27
-
28
 
 
 
29
  return {"predicted_label" : predicted_labels[0] }
30
 
31
  title = "Geometric Shape Classifier"
 
22
  inputs = feature_extractor(images=[image], return_tensors="pt")
23
  logits = model(**inputs)['logits'].cpu().detach().numpy()
24
  predictions = np.argmax(logits, axis=1)
25
+ #predicted_labels = [labels[prediction] for prediction in predictions]
26
+ #print(predicted_labels[0],logits[0][predictions[0]])
 
27
 
28
+ return {"predictions" : predictions }
29
+
30
  return {"predicted_label" : predicted_labels[0] }
31
 
32
  title = "Geometric Shape Classifier"