malmukhtar commited on
Commit
750c18c
·
1 Parent(s): 7171d1b

Update accessibility_classifier/classifier/model.py

Browse files
accessibility_classifier/classifier/model.py CHANGED
@@ -92,10 +92,8 @@ class Model:
92
  with torch.no_grad():
93
  probabilities = self.classifier(input_ids, attention_mask)
94
 
95
- prediction = F.softmax(probabilities.logits,
96
- dim=1).cpu().numpy().flatten().max()
97
- prediction_index = np.where(F.softmax(probabilities.logits,
98
- dim=1).cpu().numpy() == prediction)[1][0]
99
  label = self.labels[prediction_index]
100
 
101
  all_predictions = F.softmax(
 
92
  with torch.no_grad():
93
  probabilities = self.classifier(input_ids, attention_mask)
94
 
95
+ prediction = F.softmax(probabilities.logits, dim=1).cpu().numpy().flatten().max()
96
+ prediction_index = np.where(F.softmax(probabilities.logits, dim=1).cpu().numpy() == prediction)[1][0]
 
 
97
  label = self.labels[prediction_index]
98
 
99
  all_predictions = F.softmax(