0-ma commited on
Commit
91d639b
·
verified ·
1 Parent(s): 6f08475

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -15,19 +15,20 @@ labels = [
15
  feature_extractor = AutoImageProcessor.from_pretrained('0-ma/vit-geometric-shapes-tiny')
16
  model = AutoModelForImageClassification.from_pretrained('0-ma/vit-geometric-shapes-tiny')
17
 
18
- labels = []
19
  def predict(image):
20
  feature_extractor = AutoImageProcessor.from_pretrained('0-ma/vit-geometric-shapes-tiny')
21
  model = AutoModelForImageClassification.from_pretrained('0-ma/vit-geometric-shapes-tiny')
22
  inputs = feature_extractor(images=[image], return_tensors="pt")
23
  logits = model(**inputs)['logits'].cpu().detach().numpy()
24
- predictions = np.argmax(logits, axis=1)
25
  #predicted_labels = [labels[prediction] for prediction in predictions]
26
  #print(predicted_labels[0],logits[0][predictions[0]])
27
-
28
- return {"predictions" : predictions }
 
29
 
30
- return {"predicted_label" : predicted_labels[0] }
31
 
32
  title = "Geometric Shape Classifier"
33
  description = "A geometric shape setector."
 
15
  feature_extractor = AutoImageProcessor.from_pretrained('0-ma/vit-geometric-shapes-tiny')
16
  model = AutoModelForImageClassification.from_pretrained('0-ma/vit-geometric-shapes-tiny')
17
 
18
+
19
  def predict(image):
20
  feature_extractor = AutoImageProcessor.from_pretrained('0-ma/vit-geometric-shapes-tiny')
21
  model = AutoModelForImageClassification.from_pretrained('0-ma/vit-geometric-shapes-tiny')
22
  inputs = feature_extractor(images=[image], return_tensors="pt")
23
  logits = model(**inputs)['logits'].cpu().detach().numpy()
24
+ #predictions = np.argmax(logits, axis=1)
25
  #predicted_labels = [labels[prediction] for prediction in predictions]
26
  #print(predicted_labels[0],logits[0][predictions[0]])
27
+ confidences = {labels[i]: float(logits[i]) for i in range(len(labels))}
28
+ return confidences
29
+ #return {"predictions" : predictions }
30
 
31
+ #return {"predicted_label" : predicted_labels[0] }
32
 
33
  title = "Geometric Shape Classifier"
34
  description = "A geometric shape setector."