eval_step(val_loader, model, num_classes, loss_fn, epoch)

Run one evaluation pass over the validation set. Returns the running-average loss together with the true and predicted class indices as stacked arrays.

Source code in newsclassifier\train.py
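These functions sit in the middle of newsclassifier\train.py, so the module-level imports and helpers they rely on are not shown. A minimal preamble that would make the snippets below self-contained, reconstructed from the names they use (the collate import path and the device line are assumptions, not taken from the file):

from typing import Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from newsclassifier.data import collate  # assumed location of the collate helper

device = "cuda" if torch.cuda.is_available() else "cpu"  # assumed device setup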
def eval_step(val_loader: DataLoader, model, num_classes: int, loss_fn, epoch: int) -> Tuple[float, np.ndarray, np.ndarray]:
    """Evaluate the model for one epoch and collect true and predicted classes."""
    model.eval()
    loss = 0.0
    total_iterations = len(val_loader)
    desc = f"Validation - Epoch {epoch+1}"
    y_trues, y_preds = [], []
    with torch.inference_mode():  # no gradient tracking during evaluation
        for step, (inputs, labels) in tqdm(enumerate(val_loader), total=total_iterations, desc=desc):
            inputs = collate(inputs)
            for k, v in inputs.items():
                inputs[k] = v.to(device)  # move every input tensor to the target device
            labels = labels.to(device)
            y_pred = model(inputs)  # forward pass
            targets = F.one_hot(labels.long(), num_classes=num_classes).float()  # one-hot (for loss_fn)
            J = loss_fn(y_pred, targets).item()
            loss += (J - loss) / (step + 1)  # running average of batch losses
            y_trues.extend(labels.cpu().numpy())  # class indices, matching the argmax predictions below
            y_preds.extend(torch.argmax(y_pred, dim=1).cpu().numpy())
    return loss, np.vstack(y_trues), np.vstack(y_preds)
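The update loss += (J - loss) / (step + 1) is the incremental form of the arithmetic mean, so the returned loss is the average batch loss without accumulating a separate sum. Because eval_step returns stacked arrays of true and predicted class indices, standard metrics can be computed directly on its output. A minimal sketch, assuming the preamble above plus a trained model and a val_loader (scikit-learn is an added dependency here, not something train.py itself uses):

from sklearn.metrics import precision_recall_fscore_support

val_loss, y_true, y_pred = eval_step(val_loader, model, num_classes, loss_fn, epoch=0)
# np.vstack produced column vectors, so flatten before scoring
precision, recall, f1, _ = precision_recall_fscore_support(y_true.ravel(), y_pred.ravel(), average="weighted")
print(f"val_loss={val_loss:.4f} precision={precision:.3f} recall={recall:.3f} f1={f1:.3f}")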

train_step(train_loader, model, num_classes, loss_fn, optimizer, epoch)

Run one training epoch over the training data and return the running-average loss.

Source code in newsclassifier\train.py
def train_step(train_loader: DataLoader, model, num_classes: int, loss_fn, optimizer, epoch: int) -> float:
    """Train step."""
    model.train()
    loss = 0.0
    total_iterations = len(train_loader)
    desc = f"Training - Epoch {epoch+1}"
    for step, (inputs, labels) in tqdm(enumerate(train_loader), total=total_iterations, desc=desc):
        inputs = collate(inputs)
        for k, v in inputs.items():
            inputs[k] = v.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()  # reset gradients
        y_pred = model(inputs)  # forward pass
        targets = F.one_hot(labels.long(), num_classes=num_classes).float()  # one-hot (for loss_fn)
        J = loss_fn(y_pred, targets)  # define loss
        J.backward()  # backward pass
        optimizer.step()  # update weights
        loss += (J.detach().item() - loss) / (step + 1)  # running average of batch losses
    return loss
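train_step and eval_step each handle a single epoch, so a caller wires them into an outer loop. A minimal driver sketch under the same assumptions as above; the epoch count, checkpoint path, and best-by-validation-loss policy are illustrative choices, not taken from newsclassifier\train.py:

num_epochs = 5  # hypothetical setting

best_val_loss = float("inf")
for epoch in range(num_epochs):
    train_loss = train_step(train_loader, model, num_classes, loss_fn, optimizer, epoch)
    val_loss, y_true, y_pred = eval_step(val_loader, model, num_classes, loss_fn, epoch)
    print(f"Epoch {epoch+1}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")
    if val_loss < best_val_loss:  # checkpoint the best model by validation loss
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_model.pt")  # hypothetical path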