# mcqa_dataset.py
# --------------------------------------------------
# Pre-tokenised dataset for 4-choice MCQA
# --------------------------------------------------
import json
import torch
from torch.utils.data import Dataset
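
# Expected input: one JSON object per line (JSONL), matching the fields read
# below. The concrete values here are illustrative only:
#
#   {"id": "q-0001",
#    "fact1": "metals conduct electricity",
#    "question": {"stem": "Which material conducts electricity?",
#                 "choices": [{"label": "A", "text": "copper"},
#                             {"label": "B", "text": "rubber"},
#                             {"label": "C", "text": "wood"},
#                             {"label": "D", "text": "glass"}]},
#    "answerKey": "A"}
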
class MCQADataset(Dataset):
"""
Each item returns:
input_ids, attention_mask : LongTensor (max_len)
label : 0/1 (1 → correct choice)
qid, cid : strings (question id, choice id)
"""
    def __init__(self, path: str, tokenizer, max_len: int = 128):
        self.encodings, self.labels, self.qids, self.cids = [], [], [], []
        with open(path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:  # tolerate blank lines in the JSONL file
                    continue
                obj = json.loads(line)
stem = obj["question"]["stem"]
fact = obj["fact1"]
gold = obj["answerKey"]
for ch in obj["question"]["choices"]:
text = f"{fact} {stem} {ch['text']}"
enc = tokenizer(
text,
max_length=max_len,
truncation=True,
padding="max_length",
)
self.encodings.append(enc)
self.labels.append(1 if ch["label"] == gold else 0)
self.qids.append(obj["id"])
self.cids.append(ch["label"])
        # Convert list of dicts → dict of lists for cheaper indexing;
        # fail loudly on an empty file instead of raising IndexError below.
        if not self.encodings:
            raise ValueError(f"No examples could be read from {path}")
        self.encodings = {
            k: [d[k] for d in self.encodings] for k in self.encodings[0]
        }
# --------------------------------------------------
    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
item["label"] = torch.tensor(self.labels[idx], dtype=torch.long)
item["qid"] = self.qids[idx]
item["cid"] = self.cids[idx]
return item
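

# --------------------------------------------------
# Minimal usage sketch (assumptions, not repo code): the checkpoint name
# "bert-base-uncased" and the path "train.jsonl" are illustrative stand-ins.
# --------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
    ds = MCQADataset("train.jsonl", tok, max_len=128)         # assumed path
    loader = DataLoader(ds, batch_size=8, shuffle=True)

    batch = next(iter(loader))
    # Tensor fields are stacked to (batch, max_len); string fields such as
    # qid/cid are collated into lists by the default collate_fn.
    print(batch["input_ids"].shape, batch["label"].tolist(), batch["qid"])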