# train.py
# --------------------------------------------------
# Training script for a BERT multiple-choice QA model (MCQABERT):
#   * each (question, option) pair gets one logit, trained with binary
#     cross-entropy against a 0/1 "is this the right option" label
#   * evaluation reports per-option binary accuracy and per-question
#     (4-way) accuracy; questions without a labelled answer are skipped
#   * faster throughput (pre-tokenised data, DataLoader workers)
#   * higher GPU utilisation (larger batch, torch.compile, TF32)
# --------------------------------------------------
import inspect
import math
import os
import time

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)

from mcqa_dataset import MCQADataset
from mcqa_bert import MCQABERT


# ──────────────────────────────────────────────────────────────────────────
def _as_key(value):
    """Return ``value`` as a hashable Python scalar.

    DataLoader's default collate turns integer fields into tensors, and a
    0-d tensor hashes by object identity. Used directly as a dict key, it
    would put every example under its own key. Non-tensors (e.g. strings)
    are returned unchanged.
    """
    return value.item() if isinstance(value, torch.Tensor) else value


def run_split(model, loader):
    """Evaluate ``model`` on every batch of ``loader``.

    Each example is one (question, option) pair scored by a single logit.

    Two metrics are computed:

    * binary accuracy: ``sigmoid(logit) > 0.5`` compared with the 0/1
      ``label`` of every example;
    * 4-way accuracy: examples are grouped by ``qid``. The option (``cid``)
      with the highest logit is the prediction, compared with the option
      whose label is 1.

    Questions with no option labelled 1 cannot be scored. They are skipped
    (with a printed warning) instead of raising IndexError. An empty split
    yields 0.0 instead of raising ZeroDivisionError.

    Args:
        model: module called as ``model(input_ids, attention_mask)``.
            It is assumed to return one logit per example, shape (B,), which
            is what the BCE loss in ``main`` requires.
        loader: iterable of dict batches with keys "input_ids",
            "attention_mask", "label", "qid" and "cid".

    Returns:
        ``(binary accuracy, 4-way accuracy)`` as floats.
    """
    model.eval()
    bin_correct = tot = 0
    per_qid = {}  # {qid: [(cid, logit, gold_flag), ...]}

    with torch.no_grad():
        for batch in loader:
            ids = batch["input_ids"].cuda(non_blocking=True)
            mask = batch["attention_mask"].cuda(non_blocking=True)
            label = batch["label"].cuda(non_blocking=True)

            logits = model(ids, mask)  # (B,)
            preds = (torch.sigmoid(logits) > 0.5).long()
            bin_correct += (preds == label).sum().item()
            tot += len(label)

            # Stash logits for the per-question metric. A single host copy
            # per batch instead of one .item() sync per example.
            for qid, cid, logit, gold_flag in zip(
                batch["qid"],
                batch["cid"],
                logits.cpu().tolist(),
                batch["label"].tolist(),
            ):
                per_qid.setdefault(_as_key(qid), []).append(
                    (_as_key(cid), logit, gold_flag)
                )

    correct4 = scored = skipped = 0
    for opts in per_qid.values():
        gold = [cid for cid, _, flag in opts if flag == 1]
        if not gold:
            skipped += 1  # no labelled answer -> cannot score this question
            continue
        pred_cid = max(opts, key=lambda opt: opt[1])[0]  # highest logit
        scored += 1
        if pred_cid == gold[0]:
            correct4 += 1

    if skipped:
        print(f"warning: {skipped} question(s) had no gold option and were "
              f"skipped for 4-way accuracy")

    bin_acc = bin_correct / tot if tot else 0.0
    mc_acc = correct4 / scored if scored else 0.0
    return bin_acc, mc_acc


# ──────────────────────────────────────────────────────────────────────────
def main():
    """Fine-tune MCQABERT and report validation, test and memory stats.

    Validation metrics are printed after every epoch. At the end, test
    4-way accuracy and peak GPU memory are printed.
    """
    num_epochs = 3  # single source of truth for schedule length and loop

    # ---------------- data -----------------
    tok = AutoTokenizer.from_pretrained("bert-base-uncased")
    train_ds = MCQADataset("train_complete.jsonl", tok)
    val_ds = MCQADataset("valid_complete.jsonl", tok)
    test_ds = MCQADataset("test_complete.jsonl", tok)

    loader_kwargs = dict(num_workers=4, pin_memory=True, persistent_workers=True)
    train_loader = DataLoader(
        train_ds, batch_size=64, shuffle=True, **loader_kwargs
    )
    val_loader = DataLoader(val_ds, batch_size=128, **loader_kwargs)
    test_loader = DataLoader(test_ds, batch_size=128, **loader_kwargs)

    # ---------------- model -----------------
    # Let float32 matmuls use TF32 tensor cores.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.set_float32_matmul_precision("high")

    model = MCQABERT().cuda()
    # Optional: compile for extra speed (only if this PyTorch has torch.compile)
    if hasattr(torch, "compile"):
        model = torch.compile(model)

    # Request the fused AdamW kernel only when this PyTorch's AdamW accepts
    # `fused`. Passing fused=False to an AdamW without that parameter would
    # raise TypeError rather than fall back.
    adamw_kwargs = {}
    if "fused" in inspect.signature(torch.optim.AdamW).parameters:
        adamw_kwargs["fused"] = True
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, **adamw_kwargs)

    total_steps = len(train_loader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),  # 10 % linear warm-up
        num_training_steps=total_steps,
    )

    # ---------------- training -----------------
    for epoch in range(1, num_epochs + 1):
        model.train()
        t0 = time.perf_counter()
        for batch in train_loader:
            ids = batch["input_ids"].cuda(non_blocking=True)
            mask = batch["attention_mask"].cuda(non_blocking=True)
            # Cast after the transfer so the copy comes from pinned memory
            # and can actually be asynchronous.
            label = batch["label"].cuda(non_blocking=True).float()

            logits = model(ids, mask)
            loss = F.binary_cross_entropy_with_logits(logits, label)
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)

        # CUDA work is queued asynchronously; wait for it so the epoch time
        # reflects the real training cost.
        torch.cuda.synchronize()
        dur = time.perf_counter() - t0

        bin_acc, mc_acc = run_split(model, val_loader)
        print(f"Epoch {epoch}: "
              f"val‑CLS={bin_acc:.3f} | val‑4way={mc_acc:.3f} "
              f"| time={dur/60:.1f} min")

    # ---------------- test -----------------
    _, test_acc = run_split(model, test_loader)
    mem = torch.cuda.max_memory_allocated() / (1024 ** 3)
    print(f"Test 4‑way accuracy = {test_acc:.3f}")
    print(f"Peak GPU memory = {mem:.1f} GB")


if __name__ == "__main__":
    main()