# train.py
# --------------------------------------------------
# Training script with:
#   * bug‑free evaluation (no IndexError)
#   * faster throughput (pre‑tokenised data, DataLoader workers)
#   * higher GPU utilisation (larger batch, torch.compile, TF32)
# --------------------------------------------------
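# Expected inputs (as used below): train_complete.jsonl / valid_complete.jsonl /
# test_complete.jsonl in the working directory, and a CUDA-capable GPU
# (the model and every batch are moved to .cuda()).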
import time

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)

from mcqa_dataset import MCQADataset
from mcqa_bert    import MCQABERT
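
# NOTE (assumed batch layout, inferred from the usage below): each MCQADataset
# item is a dict with at least
#   "input_ids", "attention_mask"  : tokenised (question, candidate) pair,
#   "label"                        : 1 if the candidate is the gold answer, else 0,
#   "qid", "cid"                   : question / candidate ids used to regroup
#                                    candidates for the 4-way metric.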


# ──────────────────────────────────────────────────────────────────────────
def run_split(model, loader):
    """
    Returns: (binary CLS accuracy, 4‑way accuracy)
    Bug fixed – no more IndexError.
    """
    model.eval()
    bin_correct = tot = 0
    per_qid = {}                                # {qid: [(cid, logit, gold_flag), …]}

    with torch.no_grad():
        for batch in loader:
            ids   = batch["input_ids"].cuda()
            mask  = batch["attention_mask"].cuda()
            label = batch["label"].cuda()

            logits = model(ids, mask)           # (B)
            preds  = (torch.sigmoid(logits) > 0.5).long()

            bin_correct += (preds == label).sum().item()
            tot         += len(label)

            # stash logits for 4‑way metric
            for qid, cid, logit, gold_flag in zip(
                batch["qid"], batch["cid"], logits.cpu(), batch["label"]
            ):
                per_qid.setdefault(qid, []).append(
                    (cid, logit.item(), gold_flag.item())
                )

    correct4 = 0
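    # Example with hypothetical values: per_qid["q1"] might hold
    #   [("A", -0.3, 0), ("B", 1.2, 1), ("C", 0.1, 0), ("D", -0.8, 0)]
    # The highest-logit candidate ("B") becomes the 4-way prediction; it carries
    # the gold flag, so this question counts as correct.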
    for qid, opts in per_qid.items():
        pred_cid = max(opts, key=lambda x: x[1])[0]            # highest logit
        gold_cid = [cid for cid, _, flag in opts if flag == 1][0]
        if pred_cid == gold_cid:
            correct4 += 1

    return bin_correct / tot, correct4 / len(per_qid)


# ──────────────────────────────────────────────────────────────────────────
def main():
    # ---------------- data -----------------
    tok = AutoTokenizer.from_pretrained("bert-base-uncased")

    train_ds = MCQADataset("train_complete.jsonl", tok)
    val_ds   = MCQADataset("valid_complete.jsonl", tok)
    test_ds  = MCQADataset("test_complete.jsonl",  tok)

    train_loader = DataLoader(
        train_ds, batch_size=64, shuffle=True,
        num_workers=4, pin_memory=True, persistent_workers=True
    )
    val_loader = DataLoader(
        val_ds, batch_size=128, num_workers=4,
        pin_memory=True, persistent_workers=True
    )
    test_loader = DataLoader(
        test_ds, batch_size=128, num_workers=4,
        pin_memory=True, persistent_workers=True
    )
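    # pin_memory plus the non_blocking .cuda() copies in the training loop let
    # host-to-device transfers overlap with compute; persistent_workers keeps
    # the 4 worker processes alive between epochs instead of re-forking them.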

    # ---------------- model -----------------
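    # TF32 runs float32 matmuls on tensor cores (Ampere and newer GPUs) with a
    # shorter mantissa, giving a sizeable speed-up for BERT-sized models at
    # negligible accuracy cost.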
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.set_float32_matmul_precision("high")

    model = MCQABERT().cuda()
    # Optional: compile for extra speed (PyTorch >= 2.0)
    if hasattr(torch, "compile"):
        model = torch.compile(model)

    # AdamW: use the fused CUDA kernel when this PyTorch build exposes the
    # `fused` flag, otherwise fall back to the default implementation.
    fused_ok = "fused" in torch.optim.AdamW.__init__.__code__.co_varnames
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=3e-5, **({"fused": True} if fused_ok else {})
    )

    total_steps = len(train_loader) * 3          # 3 epochs
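    # get_linear_schedule_with_warmup: linear warmup over the first 10% of
    # steps, then linear decay to zero over the rest of training.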
    scheduler   = get_linear_schedule_with_warmup(
        optimizer, int(0.1 * total_steps), total_steps
    )

    # ---------------- training -----------------
    for epoch in range(1, 4):
        model.train()
        t0 = time.time()
        for batch in train_loader:
            ids   = batch["input_ids"].cuda(non_blocking=True)
            mask  = batch["attention_mask"].cuda(non_blocking=True)
            label = batch["label"].float().cuda(non_blocking=True)

            logits = model(ids, mask)
            loss   = F.binary_cross_entropy_with_logits(logits, label)

            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)

        dur = time.time() - t0
        bin_acc, mc_acc = run_split(model, val_loader)
        print(f"Epoch {epoch}: "
              f"val‑CLS={bin_acc:.3f} | val‑4way={mc_acc:.3f} "
              f"| time={dur/60:.1f}β€―min")

    # ---------------- test -----------------
    _, test_acc = run_split(model, test_loader)
    mem = torch.cuda.max_memory_allocated() / (1024 ** 3)
    print(f"Test 4‑way accuracy = {test_acc:.3f}")
    print(f"Peak GPU memory     = {mem:.1f}β€―GB")


if __name__ == "__main__":
    main()