mnavas commited on
Commit
12e78de
·
1 Parent(s): 20a65d7
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -23,14 +23,14 @@ def askcpv(description):
23
  encoding = {k: v.to(model.device) for k,v in encoding.items()}
24
  outputs = model(**encoding)
25
  sigmoid = torch.nn.Sigmoid()
26
- probs = sigmoid(outputs.logits.squeeze().cpu()).detach().numpy()
27
  values, indices = torch.topk(probs, k=10)
28
  # turn predicted id's into actual label names
29
  # predicted_labels = [id2label[idx] for idx, label in enumerate(predictions) if label == 1.0]
30
  # return predicted_labels
31
- print("probs: ", probs)
32
- print("values: ", values)
33
- print("indices: ", indices)
34
  print({i: v.item() for i, v in zip(indices, values)})
35
  return {cpv[i]: v.item() for i, v in zip(indices, values)}
36
 
 
23
  encoding = {k: v.to(model.device) for k,v in encoding.items()}
24
  outputs = model(**encoding)
25
  sigmoid = torch.nn.Sigmoid()
26
+ probs = sigmoid(outputs.logits.squeeze().cpu())
27
  values, indices = torch.topk(probs, k=10)
28
  # turn predicted id's into actual label names
29
  # predicted_labels = [id2label[idx] for idx, label in enumerate(predictions) if label == 1.0]
30
  # return predicted_labels
31
+ print("probs: ", probs.detach().numpy())
32
+ print("values: ", values.detach().numpy())
33
+ print("indices: ", indices.detach().numpy())
34
  print({i: v.item() for i, v in zip(indices, values)})
35
  return {cpv[i]: v.item() for i, v in zip(indices, values)}
36