test_step(test_loader, model)

Evaluation step: runs the model over the test loader and returns the true and predicted labels.

Source: `newsclassifier/inference.py`, lines 18–31.
def test_step(test_loader: DataLoader, model) -> Tuple[np.ndarray, np.ndarray]:
    """Run the model over ``test_loader`` and collect true and predicted labels.

    The model is switched to eval mode and run under ``torch.inference_mode()``.
    Each batch yielded by ``test_loader`` is unpacked as ``(inputs, labels)``.
    ``inputs`` is passed through ``collate`` and every value of the resulting
    mapping is moved to ``device`` before the forward pass. Predictions are the
    ``argmax`` of the model output over ``dim=1``.

    Args:
        test_loader: Loader yielding ``(inputs, labels)`` batches.
        model: Callable taking the collated input mapping. Its output is
            presumably logits of shape (batch, num_classes) — TODO confirm.

    Returns:
        ``(y_true, y_pred)`` built with ``np.vstack`` over the per-sample
        elements. For 1-D per-batch labels and predictions, this gives column
        vectors of shape (N, 1).

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    model.eval()
    y_trues, y_preds = [], []
    with torch.inference_mode():
        # Pass the loader itself (not enumerate(...)) so tqdm can read its
        # length and display a progress total; the batch index was unused.
        for inputs, labels in tqdm(test_loader):
            inputs = collate(inputs)
            for k, v in inputs.items():
                inputs[k] = v.to(device)
            y_pred = model(inputs)
            # Labels are only needed for the returned NumPy arrays, so skip the
            # device copy and bring them straight to host memory.
            y_trues.extend(labels.cpu().numpy())
            y_preds.extend(torch.argmax(y_pred, dim=1).cpu().numpy())
    if not y_trues:
        # np.vstack([]) would fail with an opaque message; fail clearly instead.
        raise ValueError("test_loader yielded no batches; nothing to evaluate")
    return np.vstack(y_trues), np.vstack(y_preds)