# mcqa_dataset.py
# --------------------------------------------------
# Pre-tokenised dataset for 4-choice MCQA
# --------------------------------------------------
import json

import torch
from torch.utils.data import Dataset


class MCQADataset(Dataset):
    """Pre-tokenised multiple-choice QA dataset (one item per choice).

    Each line of *path* is a JSON object with the schema
    ``{"id", "answerKey", "fact1", "question": {"stem", "choices": [...]}}``
    (OpenBookQA-style — assumed from the fields read here; confirm against
    the data producer).  Every (question, choice) pair becomes one
    binary-classification example.

    Each item returns:
        input_ids, attention_mask : LongTensor (max_len)
        label                     : 0/1 (1 → correct choice)
        qid, cid                  : strings (question id, choice id)
    """

    def __init__(self, path: str, tokenizer, max_len: int = 128):
        """Read *path* (JSONL) and pre-tokenise every question/choice pair.

        Args:
            path: JSONL file, one question object per line.
            tokenizer: callable accepting ``(text, max_length=, truncation=,
                padding=)`` and returning a mapping of sequence fields
                (e.g. a HuggingFace tokenizer).
            max_len: fixed sequence length; everything is padded/truncated
                to this size.
        """
        encodings: list = []  # per-choice tokenizer outputs, flattened over questions
        self.labels: list[int] = []
        self.qids: list[str] = []
        self.cids: list[str] = []
        with open(path, encoding="utf-8") as f:
            for line in f:
                # Skip blank lines (trailing newlines in hand-edited JSONL
                # would otherwise crash json.loads).
                if not line.strip():
                    continue
                obj = json.loads(line)
                stem = obj["question"]["stem"]
                fact = obj["fact1"]
                gold = obj["answerKey"]
                for ch in obj["question"]["choices"]:
                    # Fact + stem + candidate answer scored as one sequence.
                    text = f"{fact} {stem} {ch['text']}"
                    enc = tokenizer(
                        text,
                        max_length=max_len,
                        truncation=True,
                        padding="max_length",
                    )
                    encodings.append(enc)
                    self.labels.append(1 if ch["label"] == gold else 0)
                    self.qids.append(obj["id"])
                    self.cids.append(ch["label"])

        # Convert list of dicts → dict of lists for cheaper indexing.
        # Guarded: indexing encodings[0] on an empty input file would
        # raise IndexError; an empty dataset is valid and has len() == 0.
        self.encodings = (
            {k: [d[k] for d in encodings] for k in encodings[0]}
            if encodings
            else {}
        )

    # --------------------------------------------------
    def __len__(self) -> int:
        """Number of (question, choice) examples."""
        return len(self.labels)

    def __getitem__(self, idx: int) -> dict:
        """Return one example as a dict of tensors plus string ids."""
        item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
        item["label"] = torch.tensor(self.labels[idx], dtype=torch.long)
        item["qid"] = self.qids[idx]
        item["cid"] = self.cids[idx]
        return item