Create train_q1.py
train_q1.py (ADDED)
# train.py
# --------------------------------------------------
# Training script with:
# * bug-free evaluation (no IndexError)
# * faster throughput (pre-tokenised data, DataLoader workers)
# * higher GPU utilisation (larger batch, torch.compile, TF32)
# --------------------------------------------------
import time

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)

from mcqa_dataset import MCQADataset
from mcqa_bert import MCQABERT


# --------------------------------------------------------------------------
def run_split(model, loader):
    """
    Returns: (binary CLS accuracy, 4-way accuracy).
    Bug fixed: the gold option is recovered per question id, so the 4-way
    metric no longer raises an IndexError.
    """
    model.eval()
    bin_correct = tot = 0
    per_qid = {}  # {qid: [(cid, logit, gold_flag), ...]}

    with torch.no_grad():
        for batch in loader:
            ids = batch["input_ids"].cuda()
            mask = batch["attention_mask"].cuda()
            label = batch["label"].cuda()

            logits = model(ids, mask)                      # shape (B,)
            preds = (torch.sigmoid(logits) > 0.5).long()   # binary CLS decision

            bin_correct += (preds == label).sum().item()
            tot += len(label)

            # stash logits for the 4-way metric
            for qid, cid, logit, gold_flag in zip(
                batch["qid"], batch["cid"], logits.cpu(), batch["label"]
            ):
                per_qid.setdefault(qid, []).append(
                    (cid, logit.item(), gold_flag.item())
                )

    # 4-way accuracy: for each question, the option with the highest logit
    # must be the gold option.
    correct4 = 0
    for opts in per_qid.values():
        pred_cid = max(opts, key=lambda x: x[1])[0]               # highest logit
        gold_cid = [cid for cid, _, flag in opts if flag == 1][0]
        if pred_cid == gold_cid:
            correct4 += 1

    return bin_correct / tot, correct4 / len(per_qid)

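# Toy illustration of the 4-way reduction above, using hypothetical option
# ids and logits (not values from the real data):
#   opts = [("A", -1.2, 0), ("B", 0.7, 1), ("C", 0.1, 0), ("D", -0.3, 0)]
#   max(opts, key=lambda x: x[1])[0]               -> "B"  (highest logit)
#   [cid for cid, _, flag in opts if flag == 1][0] -> "B"  (gold option)
# The predicted and gold option ids match, so this question counts as correct.
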
# --------------------------------------------------------------------------
def main():
    # ---------------- data -----------------
    tok = AutoTokenizer.from_pretrained("bert-base-uncased")

    train_ds = MCQADataset("train_complete.jsonl", tok)
    val_ds = MCQADataset("valid_complete.jsonl", tok)
    test_ds = MCQADataset("test_complete.jsonl", tok)

    train_loader = DataLoader(
        train_ds, batch_size=64, shuffle=True,
        num_workers=4, pin_memory=True, persistent_workers=True
    )
    val_loader = DataLoader(
        val_ds, batch_size=128, num_workers=4,
        pin_memory=True, persistent_workers=True
    )
    test_loader = DataLoader(
        test_ds, batch_size=128, num_workers=4,
        pin_memory=True, persistent_workers=True
    )

    # ---------------- model -----------------
    # Allow TF32 matmuls: faster on Ampere-class GPUs at slightly reduced
    # matmul precision.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.set_float32_matmul_precision("high")

    model = MCQABERT().cuda()
    # Optional: compile for extra speed (torch.compile exists on PyTorch >= 2.0)
    if hasattr(torch, "compile"):
        model = torch.compile(model)

    # AdamW: use the fused CUDA kernel only when this PyTorch build supports
    # it; older builds reject the keyword entirely, so pass it conditionally.
    fused_ok = "fused" in torch.optim.AdamW.__init__.__code__.co_varnames
    opt_kwargs = {"fused": True} if fused_ok else {}
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, **opt_kwargs)

    total_steps = len(train_loader) * 3   # 3 epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer, int(0.1 * total_steps), total_steps   # 10% warm-up
    )

    # ---------------- training -----------------
    for epoch in range(1, 4):
        model.train()
        t0 = time.time()
        for batch in train_loader:
            ids = batch["input_ids"].cuda(non_blocking=True)
            mask = batch["attention_mask"].cuda(non_blocking=True)
            label = batch["label"].float().cuda(non_blocking=True)

            logits = model(ids, mask)
            loss = F.binary_cross_entropy_with_logits(logits, label)

            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)

        dur = time.time() - t0
        bin_acc, mc_acc = run_split(model, val_loader)
        print(f"Epoch {epoch}: "
              f"val-CLS={bin_acc:.3f} | val-4way={mc_acc:.3f} "
              f"| time={dur/60:.1f} min")

    # ---------------- test -----------------
    _, test_acc = run_split(model, test_loader)
    mem = torch.cuda.max_memory_allocated() / (1024 ** 3)
    print(f"Test 4-way accuracy = {test_acc:.3f}")
    print(f"Peak GPU memory = {mem:.1f} GB")


if __name__ == "__main__":
    main()
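
The commit does not include mcqa_dataset.py or mcqa_bert.py, but the script pins down their contracts: MCQADataset must yield one (question, option) pair per example with the keys input_ids, attention_mask, label, qid, and cid, and MCQABERT must map a token batch to one logit per example. Below is a minimal sketch of what those modules could look like, inferred only from how train_q1.py uses them; the JSONL field names (question, option_text, qid, cid, label) are guesses, and the real files in the Space may differ.

# Minimal sketch (assumed interfaces, not part of this commit).
# mcqa_dataset.py / mcqa_bert.py as train_q1.py appears to expect them.
import json

import torch
from torch import nn
from torch.utils.data import Dataset
from transformers import AutoModel


class MCQADataset(Dataset):
    """One example per (question, option) pair; label is 1 for the gold option."""

    def __init__(self, path, tok, max_len=128):
        self.rows = [json.loads(line) for line in open(path)]
        self.tok, self.max_len = tok, max_len

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, i):
        r = self.rows[i]
        enc = self.tok(
            r["question"], r["option_text"],
            truncation=True, padding="max_length",
            max_length=self.max_len, return_tensors="pt",
        )
        return {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "label": torch.tensor(r["label"]),   # 1 = gold option, 0 = distractor
            "qid": r["qid"],                     # question id (string)
            "cid": r["cid"],                     # option/choice id (string)
        }


class MCQABERT(nn.Module):
    """BERT encoder with a single-logit binary head over the [CLS] token."""

    def __init__(self):
        super().__init__()
        self.bert = AutoModel.from_pretrained("bert-base-uncased")
        self.head = nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        h = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.head(h.last_hidden_state[:, 0]).squeeze(-1)   # shape (B,)

One detail worth noting: keeping qid and cid as strings means the default collate function returns them as plain lists, which is exactly what run_split's zip-and-group logic relies on; if they were integers, default collation would turn them into tensors and the per-question grouping would silently break. With modules like these in place, the script runs as python train_q1.py on a CUDA machine.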