kevinwang676 commited on
Commit
d4069e8
·
verified ·
1 Parent(s): d831168

Create train_q1.py

Browse files
Files changed (1) hide show
  1. train_q1.py +132 -0
train_q1.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # train_q1.py
2
+ # --------------------------------------------------
3
+ # Training script with:
4
+ # * bug‑free evaluation (no IndexError)
5
+ # * faster throughput (pre‑tokenised data, DataLoader workers)
6
+ # * higher GPU utilisation (larger batch, torch.compile, TF32)
7
+ # --------------------------------------------------
8
+ import os, math, time, torch, torch.nn.functional as F
9
+ from torch.utils.data import DataLoader
10
+ from transformers import (
11
+ AutoTokenizer,
12
+ get_linear_schedule_with_warmup,
13
+ )
14
+
15
+ from mcqa_dataset import MCQADataset
16
+ from mcqa_bert import MCQABERT
17
+
18
+
19
+ # ──────────────────────────────────────────────────────────────────────────
20
def _as_key(value):
    """Return a plain Python scalar for one collated ``qid``/``cid`` entry."""
    # Tensors hash by identity, so two 0-d tensors holding the same id would
    # end up under different dict keys; unwrap them before grouping.
    return value.item() if isinstance(value, torch.Tensor) else value


def run_split(model, loader):
    """Evaluate ``model`` over every batch yielded by ``loader``.

    Each batch is a mapping with ``input_ids``, ``attention_mask`` and
    ``label`` tensors plus per-row ``qid`` and ``cid`` sequences (one row per
    question/choice pair).  ``model(input_ids, attention_mask)`` must return
    one logit per row.

    Args:
        model: ``torch.nn.Module`` scoring each row; it is switched to eval
            mode and left there.
        loader: iterable of batches as described above.

    Returns:
        ``(binary_accuracy, question_accuracy)``:

        * ``binary_accuracy`` - fraction of rows where ``sigmoid(logit) > 0.5``
          agrees with ``label``.
        * ``question_accuracy`` - fraction of distinct ``qid`` values whose
          highest-logit choice is flagged with ``label == 1``.  A question
          with no gold-flagged row counts as incorrect instead of raising
          ``IndexError``.

        Both values are ``0.0`` when ``loader`` yields no rows.
    """
    model.eval()

    # Run on whatever device the model already lives on rather than
    # assuming CUDA; only guess when the model has no parameters at all.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    bin_correct = tot = 0
    per_qid = {}  # {qid: [(cid, logit, gold_flag), ...]}

    with torch.no_grad():
        for batch in loader:
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            label = batch["label"].to(device).reshape(-1)

            # Flatten so (B,) and (B, 1) outputs compare element-wise with
            # the labels instead of broadcasting to (B, B).
            logits = model(ids, mask).reshape(-1)
            preds = (torch.sigmoid(logits) > 0.5).long()

            bin_correct += (preds == label).sum().item()
            tot += label.numel()

            # Stash logits for the per-question metric; one bulk transfer
            # instead of a host sync per row.
            scores = logits.detach().float().cpu().tolist()
            flags = batch["label"].reshape(-1).tolist()
            for qid, cid, score, flag in zip(
                batch["qid"], batch["cid"], scores, flags
            ):
                per_qid.setdefault(_as_key(qid), []).append(
                    (_as_key(cid), score, flag)
                )

    correct4 = 0
    for opts in per_qid.values():
        pred_cid = max(opts, key=lambda opt: opt[1])[0]  # highest logit
        if any(cid == pred_cid for cid, _, flag in opts if flag == 1):
            correct4 += 1

    bin_acc = bin_correct / tot if tot else 0.0
    mc_acc = correct4 / len(per_qid) if per_qid else 0.0
    return bin_acc, mc_acc
57
+
58
+
59
+ # ──────────────────────────────────────────────────────────────────────────
60
def main():
    """Fine-tune ``MCQABERT`` for three epochs and report accuracy.

    Builds train/validation/test ``MCQADataset`` loaders from the
    ``*_complete.jsonl`` files in the working directory, trains with AdamW and
    a linear warm-up/decay schedule on a binary cross-entropy objective, prints
    validation metrics after every epoch, and finally prints the test-set
    per-question accuracy and peak GPU memory.  Requires a CUDA device.
    """
    import inspect  # local import: only needed for the optimizer probe below

    # Single source of truth for the schedule length and the training loop.
    num_epochs = 3

    # ---------------- data -----------------
    tok = AutoTokenizer.from_pretrained("bert-base-uncased")

    train_ds = MCQADataset("train_complete.jsonl", tok)
    val_ds = MCQADataset("valid_complete.jsonl", tok)
    test_ds = MCQADataset("test_complete.jsonl", tok)

    train_loader = DataLoader(
        train_ds, batch_size=64, shuffle=True,
        num_workers=4, pin_memory=True, persistent_workers=True
    )
    val_loader = DataLoader(
        val_ds, batch_size=128, num_workers=4,
        pin_memory=True, persistent_workers=True
    )
    test_loader = DataLoader(
        test_ds, batch_size=128, num_workers=4,
        pin_memory=True, persistent_workers=True
    )

    # ---------------- model -----------------
    # Allow TF32 matmuls for extra throughput on GPUs that support it.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.set_float32_matmul_precision("high")

    model = MCQABERT().cuda()
    # Optional: compile for extra speed when this PyTorch build offers it.
    if hasattr(torch, "compile"):
        model = torch.compile(model)

    # Request the fused AdamW kernel only when the constructor accepts the
    # keyword.  Passing ``fused=False`` to a build without that parameter
    # would raise TypeError, so the keyword is omitted entirely in that case.
    adamw_kwargs = {"lr": 3e-5}
    if "fused" in inspect.signature(torch.optim.AdamW.__init__).parameters:
        adamw_kwargs["fused"] = True
    optimizer = torch.optim.AdamW(model.parameters(), **adamw_kwargs)

    total_steps = len(train_loader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer, int(0.1 * total_steps), total_steps  # 10 % warm-up
    )

    # ---------------- training -----------------
    for epoch in range(1, num_epochs + 1):
        model.train()  # run_split() leaves the model in eval mode
        t0 = time.time()
        for batch in train_loader:
            ids = batch["input_ids"].cuda(non_blocking=True)
            mask = batch["attention_mask"].cuda(non_blocking=True)
            label = batch["label"].float().cuda(non_blocking=True)

            logits = model(ids, mask)
            loss = F.binary_cross_entropy_with_logits(logits, label)

            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)

        dur = time.time() - t0  # training time only; evaluation excluded
        bin_acc, mc_acc = run_split(model, val_loader)
        print(f"Epoch {epoch}: "
              f"val‑CLS={bin_acc:.3f} | val‑4way={mc_acc:.3f} "
              f"| time={dur/60:.1f} min")

    # ---------------- test -----------------
    _, test_acc = run_split(model, test_loader)
    mem = torch.cuda.max_memory_allocated() / (1024 ** 3)  # bytes -> GiB
    print(f"Test 4‑way accuracy = {test_acc:.3f}")
    print(f"Peak GPU memory = {mem:.1f} GB")


# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()