# train.py
# --------------------------------------------------
# Training script with:
#   * bug-free evaluation (no IndexError)
#   * faster throughput (pre-tokenised data, DataLoader workers)
#   * higher GPU utilisation (larger batch, torch.compile, TF32)
# --------------------------------------------------
import time

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)

from mcqa_dataset import MCQADataset
from mcqa_bert import MCQABERT


# ──────────────────────────────────────────────────────────────────────────
def run_split(model, loader):
    """
    Returns: (binary CLS accuracy, 4-way accuracy)
    Bug fixed: no more IndexError.
    """
    model.eval()
    bin_correct = tot = 0
    per_qid = {}  # {qid: [(cid, logit, gold_flag), ...]}
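    # assumes the DataLoader collates "qid" and "cid" into hashable Python
    # values (e.g. strings or ints) so they can be used as dict keys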
    with torch.no_grad():
        for batch in loader:
            ids = batch["input_ids"].cuda()
            mask = batch["attention_mask"].cuda()
            label = batch["label"].cuda()
            logits = model(ids, mask)                    # (B)
            preds = (torch.sigmoid(logits) > 0.5).long()
            bin_correct += (preds == label).sum().item()
            tot += len(label)
            # stash logits for the 4-way metric
            for qid, cid, logit, gold_flag in zip(
                batch["qid"], batch["cid"], logits.cpu(), batch["label"]
            ):
                per_qid.setdefault(qid, []).append(
                    (cid, logit.item(), gold_flag.item())
                )
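    # 4-way accuracy: for each question, the choice with the highest logit
    # must match the gold choice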
    correct4 = 0
    for qid, opts in per_qid.items():
        pred_cid = max(opts, key=lambda x: x[1])[0]       # highest logit
        gold_cid = [cid for cid, _, flag in opts if flag == 1][0]
        if pred_cid == gold_cid:
            correct4 += 1
    return bin_correct / tot, correct4 / len(per_qid)


# ──────────────────────────────────────────────────────────────────────────
def main():
    # ---------------- data -----------------
    tok = AutoTokenizer.from_pretrained("bert-base-uncased")
    train_ds = MCQADataset("train_complete.jsonl", tok)
    val_ds = MCQADataset("valid_complete.jsonl", tok)
    test_ds = MCQADataset("test_complete.jsonl", tok)
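    # MCQADataset is assumed to tokenise every example once at construction
    # time, so the DataLoader workers only collate pre-built tensors instead
    # of re-tokenising each epoch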
    train_loader = DataLoader(
        train_ds, batch_size=64, shuffle=True,
        num_workers=4, pin_memory=True, persistent_workers=True,
    )
    val_loader = DataLoader(
        val_ds, batch_size=128, num_workers=4,
        pin_memory=True, persistent_workers=True,
    )
    test_loader = DataLoader(
        test_ds, batch_size=128, num_workers=4,
        pin_memory=True, persistent_workers=True,
    )
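    # pin_memory=True keeps host batches in page-locked memory, which is what
    # lets the non_blocking=True copies in the training loop overlap with
    # GPU compute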

    # ---------------- model -----------------
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.set_float32_matmul_precision("high")
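    # (TF32 runs FP32 matmuls on the tensor cores of Ampere-or-newer GPUs at
    # close to half-precision throughput while keeping FP32 dynamic range)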
    model = MCQABERT().cuda()

    # Optional: compile for extra speed (PyTorch >= 2.0)
    if hasattr(torch, "compile"):
        model = torch.compile(model)
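    # (the first training step pays a one-off graph-compilation cost;
    # subsequent steps reuse the compiled graph)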

    # AdamW, using the fused CUDA kernel when this PyTorch build exposes it
    fused_ok = "fused" in torch.optim.AdamW.__init__.__code__.co_varnames
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=3e-5, fused=fused_ok
    )
    total_steps = len(train_loader) * 3               # 3 epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer, int(0.1 * total_steps), total_steps
    )
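    # linear warmup over the first 10% of steps, then linear decay to zero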

    # ---------------- training -----------------
    for epoch in range(1, 4):
        model.train()
        t0 = time.time()
        for batch in train_loader:
            ids = batch["input_ids"].cuda(non_blocking=True)
            mask = batch["attention_mask"].cuda(non_blocking=True)
            label = batch["label"].float().cuda(non_blocking=True)
            logits = model(ids, mask)
            loss = F.binary_cross_entropy_with_logits(logits, label)
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)
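            # set_to_none=True drops the grad tensors instead of zero-filling
            # them, saving a kernel launch and a little memory each step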
        dur = time.time() - t0
        bin_acc, mc_acc = run_split(model, val_loader)
        print(f"Epoch {epoch}: "
              f"val-CLS={bin_acc:.3f} | val-4way={mc_acc:.3f} "
              f"| time={dur/60:.1f} min")

    # ---------------- test -----------------
    _, test_acc = run_split(model, test_loader)
    mem = torch.cuda.max_memory_allocated() / (1024 ** 3)
    print(f"Test 4-way accuracy = {test_acc:.3f}")
    print(f"Peak GPU memory = {mem:.1f} GB")


if __name__ == "__main__":
    main()