# train.py
# --------------------------------------------------
# Training script with:
# * bug‑free evaluation (no IndexError)
# * faster throughput (pre‑tokenised data, DataLoader workers)
# * higher GPU utilisation (larger batch, torch.compile, TF32)
# --------------------------------------------------
import os, math, time, torch, torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)
from mcqa_dataset import MCQADataset
from mcqa_bert import MCQABERT
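# NOTE (interface assumptions, inferred from the usage below): MCQADataset is
# expected to yield dicts with the keys "input_ids", "attention_mask", "label"
# (1 = gold answer, 0 = distractor), "qid" (question id) and "cid" (choice id),
# one example per question/choice pair; MCQABERT is expected to return a single
# binary logit per example, i.e. a tensor of shape (B,).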
# ──────────────────────────────────────────────────────────────────────────
def run_split(model, loader):
    """
    Returns: (binary CLS accuracy, 4-way accuracy)
    Bug fixed - no more IndexError.
    """
    model.eval()
    bin_correct = tot = 0
    per_qid = {}  # {qid: [(cid, logit, gold_flag), …]}
    with torch.no_grad():
        for batch in loader:
            ids = batch["input_ids"].cuda()
            mask = batch["attention_mask"].cuda()
            label = batch["label"].cuda()
            logits = model(ids, mask)  # (B,)
            preds = (torch.sigmoid(logits) > 0.5).long()
            bin_correct += (preds == label).sum().item()
            tot += len(label)
            # stash logits for the 4-way metric
            for qid, cid, logit, gold_flag in zip(
                batch["qid"], batch["cid"], logits.cpu(), batch["label"]
            ):
                per_qid.setdefault(qid, []).append(
                    (cid, logit.item(), gold_flag.item())
                )
    correct4 = 0
    for qid, opts in per_qid.items():
        pred_cid = max(opts, key=lambda x: x[1])[0]  # highest logit
        gold_cid = [cid for cid, _, flag in opts if flag == 1][0]
        if pred_cid == gold_cid:
            correct4 += 1
    return bin_correct / tot, correct4 / len(per_qid)
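# Worked example of the 4-way metric above (illustrative values only):
#   per_qid["q1"] == [("A", -1.2, 0), ("B", 0.7, 1), ("C", 0.1, 0), ("D", -0.3, 0)]
#   -> pred_cid == "B" (highest logit), gold_cid == "B" (flag == 1), counted as correct.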
# ──────────────────────────────────────────────────────────────────────────
def main():
    # ---------------- data -----------------
    tok = AutoTokenizer.from_pretrained("bert-base-uncased")
    train_ds = MCQADataset("train_complete.jsonl", tok)
    val_ds = MCQADataset("valid_complete.jsonl", tok)
    test_ds = MCQADataset("test_complete.jsonl", tok)
    train_loader = DataLoader(
        train_ds, batch_size=64, shuffle=True,
        num_workers=4, pin_memory=True, persistent_workers=True
    )
    val_loader = DataLoader(
        val_ds, batch_size=128, num_workers=4,
        pin_memory=True, persistent_workers=True
    )
    test_loader = DataLoader(
        test_ds, batch_size=128, num_workers=4,
        pin_memory=True, persistent_workers=True
    )
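    # NOTE: the loaders rely on the default collate_fn; string "qid"/"cid"
    # fields are collated into lists that run_split can zip over directly.
    # If the dataset returned them as tensors instead, the per-question
    # grouping in run_split would need `.item()` before using them as keys.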
    # ---------------- model -----------------
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.set_float32_matmul_precision("high")
    model = MCQABERT().cuda()
    # Optional: compile for extra speed (PyTorch >= 2.0)
    if hasattr(torch, "compile"):
        model = torch.compile(model)
    # AdamW (use the fused kernel when this PyTorch build exposes it, else fall back)
    fused_ok = "fused" in torch.optim.AdamW.__init__.__code__.co_varnames
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=3e-5,
        **({"fused": True} if fused_ok else {})  # only pass `fused` when supported
    )
    total_steps = len(train_loader) * 3  # 3 epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer, int(0.1 * total_steps), total_steps
    )
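    # Example of the schedule arithmetic (hypothetical dataset size): with
    # ~10,000 training examples and batch_size=64, len(train_loader) == 157,
    # so total_steps == 471 and the warmup phase covers int(0.1 * 471) == 47
    # optimiser steps before the linear decay begins.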
    # ---------------- training -----------------
    for epoch in range(1, 4):
        model.train()
        t0 = time.time()
        for batch in train_loader:
            ids = batch["input_ids"].cuda(non_blocking=True)
            mask = batch["attention_mask"].cuda(non_blocking=True)
            label = batch["label"].float().cuda(non_blocking=True)
            logits = model(ids, mask)
            loss = F.binary_cross_entropy_with_logits(logits, label)
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)
        dur = time.time() - t0
        bin_acc, mc_acc = run_split(model, val_loader)
        print(f"Epoch {epoch}: "
              f"val-CLS={bin_acc:.3f} | val-4way={mc_acc:.3f} "
              f"| time={dur/60:.1f} min")
    # ---------------- test -----------------
    _, test_acc = run_split(model, test_loader)
    mem = torch.cuda.max_memory_allocated() / (1024 ** 3)
    print(f"Test 4-way accuracy = {test_acc:.3f}")
    print(f"Peak GPU memory = {mem:.1f} GB")
if __name__ == "__main__":
    main()
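# Usage sketch (assuming train_complete.jsonl, valid_complete.jsonl and
# test_complete.jsonl sit in the working directory and a CUDA GPU is available):
#   python train_q1.py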