Spaces:
Runtime error
Runtime error
# mcqa_dataset.py
# --------------------------------------------------
# Pre‑tokenised dataset for 4‑choice MCQA
# --------------------------------------------------
import json

import torch
from torch.utils.data import Dataset
class MCQADataset(Dataset):
    """Pre-tokenised multiple-choice QA dataset: one example per (question, choice).

    Every non-blank line of the JSONL file is parsed as one question and
    expanded into one example per entry of ``question.choices``.  The text
    passed to the tokenizer is ``"{fact1} {stem} {choice text}"``.

    Each item returned by indexing is a dict holding:
        <tokenizer keys> : one tensor per key the tokenizer returned
                           (typically ``input_ids`` and ``attention_mask``,
                           each of length ``max_len``)
        label            : scalar LongTensor, 1 if the choice's label equals
                           the record's ``answerKey``, else 0
        qid, cid         : strings (question ``id``, choice ``label``)
    """

    def __init__(self, path: str, tokenizer, max_len: int = 128) -> None:
        """Read ``path`` and tokenise every (question, choice) pair up front.

        Args:
            path: JSONL file; each record must provide ``id``, ``fact1``,
                ``answerKey`` and ``question`` (with ``stem`` and ``choices``,
                each choice having ``label`` and ``text``).
            tokenizer: callable invoked as
                ``tokenizer(text, max_length=..., truncation=True,
                padding="max_length")`` and returning a mapping of lists.
            max_len: length every sequence is truncated / padded to.

        Raises:
            json.JSONDecodeError: if a non-blank line is not valid JSON.
            KeyError: if a record lacks one of the required fields.
        """
        per_example = []
        self.labels, self.qids, self.cids = [], [], []

        with open(path, encoding="utf-8") as f:
            for line in f:
                # Blank lines (e.g. a trailing newline at end of file) are
                # not records; json.loads would raise on them.
                if not line.strip():
                    continue
                obj = json.loads(line)
                stem = obj["question"]["stem"]
                fact = obj["fact1"]
                gold = obj["answerKey"]
                for ch in obj["question"]["choices"]:
                    text = f"{fact} {stem} {ch['text']}"
                    enc = tokenizer(
                        text,
                        max_length=max_len,
                        truncation=True,
                        padding="max_length",
                    )
                    per_example.append(enc)
                    self.labels.append(1 if ch["label"] == gold else 0)
                    self.qids.append(obj["id"])
                    self.cids.append(ch["label"])

        # Convert list of dicts -> dict of lists for cheaper indexing.
        # With no examples there is no first encoding to take the keys
        # from, so fall back to an empty mapping instead of IndexError.
        if per_example:
            self.encodings = {
                k: [d[k] for d in per_example] for k in per_example[0]
            }
        else:
            self.encodings = {}

    # --------------------------------------------------
    def __len__(self) -> int:
        """Return the number of (question, choice) examples."""
        return len(self.labels)

    def __getitem__(self, idx: int) -> dict:
        """Return example ``idx`` as a dict of tensors plus its string ids."""
        item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
        item["label"] = torch.tensor(self.labels[idx], dtype=torch.long)
        item["qid"] = self.qids[idx]
        item["cid"] = self.cids[idx]
        return item