# train.py
# --------------------------------------------------
# Training script with:
#   * bug-free evaluation (no IndexError)
#   * faster throughput (pre-tokenised data, DataLoader workers)
#   * higher GPU utilisation (larger batch, torch.compile, TF32)
# --------------------------------------------------
import os, math, time, torch, torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)
from mcqa_dataset import MCQADataset
from mcqa_bert import MCQABERT
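# Each batch from MCQADataset is expected to carry (as consumed below):
#   input_ids / attention_mask : tokenised (question, option) pair
#   label                      : 1 if the option is the gold answer, else 0
#   qid / cid                  : question id and option id, used to regroup
#                                the four options per question at eval time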
# --------------------------------------------------------------------------
def run_split(model, loader):
    """
    Returns: (binary CLS accuracy, 4-way accuracy)
    Bug fixed: no more IndexError.
    """
    model.eval()
    bin_correct = tot = 0
    per_qid = {}  # {qid: [(cid, logit, gold_flag), ...]}
    with torch.no_grad():
        for batch in loader:
            ids = batch["input_ids"].cuda()
            mask = batch["attention_mask"].cuda()
            label = batch["label"].cuda()
            logits = model(ids, mask)  # (B,)
            preds = (torch.sigmoid(logits) > 0.5).long()
            bin_correct += (preds == label).sum().item()
            tot += len(label)
            # stash logits for the 4-way metric
            for qid, cid, logit, gold_flag in zip(
                batch["qid"], batch["cid"], logits.cpu(), batch["label"]
            ):
                per_qid.setdefault(qid, []).append(
                    (cid, logit.item(), gold_flag.item())
                )
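    # 4-way accuracy: for each question, pick the option with the highest
    # logit and count it correct when that option carries the gold flag.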
    correct4 = 0
    for qid, opts in per_qid.items():
        pred_cid = max(opts, key=lambda x: x[1])[0]  # highest logit
        gold_cid = [cid for cid, _, flag in opts if flag == 1][0]
        if pred_cid == gold_cid:
            correct4 += 1
    return bin_correct / tot, correct4 / len(per_qid)
# --------------------------------------------------------------------------
def main():
    # ---------------- data -----------------
    tok = AutoTokenizer.from_pretrained("bert-base-uncased")
    train_ds = MCQADataset("train_complete.jsonl", tok)
    val_ds = MCQADataset("valid_complete.jsonl", tok)
    test_ds = MCQADataset("test_complete.jsonl", tok)
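    # Worker processes keep batch preparation off the training process, while
    # pin_memory and persistent_workers cut host-to-device copy cost and
    # worker start-up overhead between epochs.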
    train_loader = DataLoader(
        train_ds, batch_size=64, shuffle=True,
        num_workers=4, pin_memory=True, persistent_workers=True
    )
    val_loader = DataLoader(
        val_ds, batch_size=128, num_workers=4,
        pin_memory=True, persistent_workers=True
    )
    test_loader = DataLoader(
        test_ds, batch_size=128, num_workers=4,
        pin_memory=True, persistent_workers=True
    )
    # ---------------- model -----------------
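    # TF32 lets Ampere-or-newer GPUs run float32 matmuls on tensor cores,
    # trading a little precision for a sizeable throughput gain.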
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.set_float32_matmul_precision("high")
    model = MCQABERT().cuda()
    # Optional: compile for extra speed (PyTorch ≥ 2.0)
    if hasattr(torch, "compile"):
        model = torch.compile(model)
    # AdamW, with the fused CUDA kernel when the installed PyTorch supports it
    fused_ok = "fused" in torch.optim.AdamW.__init__.__code__.co_varnames
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=3e-5, fused=fused_ok
    )
    total_steps = len(train_loader) * 3  # 3 epochs
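    # Linear schedule: warm the LR up over the first 10% of steps, then decay
    # it linearly to zero over the remainder.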
    scheduler = get_linear_schedule_with_warmup(
        optimizer, int(0.1 * total_steps), total_steps
    )
    # ---------------- training -----------------
    for epoch in range(1, 4):
        model.train()
        t0 = time.time()
        for batch in train_loader:
            ids = batch["input_ids"].cuda(non_blocking=True)
            mask = batch["attention_mask"].cuda(non_blocking=True)
            label = batch["label"].float().cuda(non_blocking=True)
            logits = model(ids, mask)
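            # Each (question, option) pair is scored independently, so the
            # loss is plain binary cross-entropy on the per-option logit.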
            loss = F.binary_cross_entropy_with_logits(logits, label)
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)
        dur = time.time() - t0
        bin_acc, mc_acc = run_split(model, val_loader)
        print(f"Epoch {epoch}: "
              f"val-CLS={bin_acc:.3f} | val-4way={mc_acc:.3f} "
              f"| time={dur/60:.1f} min")
    # ---------------- test -----------------
    _, test_acc = run_split(model, test_loader)
    mem = torch.cuda.max_memory_allocated() / (1024 ** 3)
    print(f"Test 4-way accuracy = {test_acc:.3f}")
    print(f"Peak GPU memory     = {mem:.1f} GB")
if __name__ == "__main__":
    main()