Spaces:
Sleeping
Sleeping
Commit
·
750c18c
1
Parent(s):
7171d1b
Update accessibility_classifier/classifier/model.py
Browse files
accessibility_classifier/classifier/model.py
CHANGED
@@ -92,10 +92,8 @@ class Model:
|
|
92 |
with torch.no_grad():
|
93 |
probabilities = self.classifier(input_ids, attention_mask)
|
94 |
|
95 |
-
prediction = F.softmax(probabilities.logits,
|
96 |
-
|
97 |
-
prediction_index = np.where(F.softmax(probabilities.logits,
|
98 |
-
dim=1).cpu().numpy() == prediction)[1][0]
|
99 |
label = self.labels[prediction_index]
|
100 |
|
101 |
all_predictions = F.softmax(
|
|
|
92 |
with torch.no_grad():
|
93 |
probabilities = self.classifier(input_ids, attention_mask)
|
94 |
|
95 |
+
prediction = F.softmax(probabilities.logits, dim=1).cpu().numpy().flatten().max()
|
96 |
+
prediction_index = np.where(F.softmax(probabilities.logits, dim=1).cpu().numpy() == prediction)[1][0]
|
|
|
|
|
97 |
label = self.labels[prediction_index]
|
98 |
|
99 |
all_predictions = F.softmax(
|