NaolTaye commited on
Commit
5ce8ded
·
1 Parent(s): e640128
Files changed (1) hide show
  1. tasks/text.py +6 -7
tasks/text.py CHANGED
@@ -89,13 +89,12 @@ async def evaluate_text(request: TextEvaluationRequest):
89
 
90
  with torch.no_grad():
91
  print('BEFORE PREDICTION')
92
- for batch in dataloader:
93
- print('INSIDE PREDICTION')
94
- test_input_ids = batch["input_ids"].to(device)
95
- test_attention_mask = batch["attention_mask"].to(device)
96
- outputs = model(test_input_ids, test_attention_mask)
97
- p = torch.argmax(outputs.logits, dim=1)
98
- predictions = np.append(predictions, p.cpu().numpy())
99
 
100
  print("Finished prediction run")
101
 
 
89
 
90
  with torch.no_grad():
91
  print('BEFORE PREDICTION')
92
+
93
+ test_input_ids = tokenized_test["input_ids"].to(device)
94
+ test_attention_mask = tokenized_test["attention_mask"].to(device)
95
+ outputs = model(test_input_ids, test_attention_mask)
96
+ p = torch.argmax(outputs.logits, dim=1)
97
+ predictions = np.append(predictions, p.cpu().numpy())
 
98
 
99
  print("Finished prediction run")
100