mnavas commited on
Commit
83254d7
·
1 Parent(s): f48f009
Files changed (1) hide show
  1. app.py +2 -3
app.py CHANGED
@@ -22,9 +22,8 @@ def askcpv(description):
22
  encoding = {k: v.to(model.device) for k,v in encoding.items()}
23
  outputs = model(**encoding)
24
  sigmoid = torch.nn.Sigmoid()
25
- probs = sigmoid(logits.squeeze().cpu())
26
- probabilites = torch.nn.functional.softmax(out[0], dim=0)
27
- values, indices = torch.topk(probabilites, k=10)
28
  # turn predicted id's into actual label names
29
  # predicted_labels = [id2label[idx] for idx, label in enumerate(predictions) if label == 1.0]
30
  # return predicted_labels
 
22
  encoding = {k: v.to(model.device) for k,v in encoding.items()}
23
  outputs = model(**encoding)
24
  sigmoid = torch.nn.Sigmoid()
25
+ probs = sigmoid(outputs.logits.squeeze().cpu())
26
+ values, indices = torch.topk(probs, k=10)
 
27
  # turn predicted id's into actual label names
28
  # predicted_labels = [id2label[idx] for idx, label in enumerate(predictions) if label == 1.0]
29
  # return predicted_labels