File size: 2,003 Bytes
c2e4b4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# mcqa_dataset.py
# --------------------------------------------------
# Pre‑tokenised dataset for 4‑choice MCQA
# --------------------------------------------------
import json
import torch
from torch.utils.data import Dataset


class MCQADataset(Dataset):
    """
    Pre-tokenised dataset for 4-choice MCQA, flattened to one binary
    example per (question, choice) pair.

    Each item returns:
        input_ids, attention_mask : LongTensor  (max_len)
        label                     : 0/1  (1 → correct choice)
        qid, cid                  : strings (question id, choice id)
    """
    def __init__(self, path: str, tokenizer, max_len: int = 128):
        """
        Args:
            path:      JSONL file; each line is an object with "id",
                       "fact1", "answerKey", and "question" (containing
                       "stem" and a list of "choices" with "label"/"text").
            tokenizer: HuggingFace-style callable accepting
                       (text, max_length=, truncation=, padding=) and
                       returning a mapping of encoding lists
                       (e.g. "input_ids", "attention_mask").
            max_len:   fixed sequence length; shorter inputs are padded,
                       longer ones truncated.
        """
        self.encodings, self.labels, self.qids, self.cids = [], [], [], []

        with open(path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank/trailing lines: json.loads("") raises.
                    continue
                obj   = json.loads(line)
                stem  = obj["question"]["stem"]
                fact  = obj["fact1"]
                gold  = obj["answerKey"]

                # One binary example per choice: "<fact> <stem> <choice>".
                for ch in obj["question"]["choices"]:
                    text = f"{fact} {stem} {ch['text']}"
                    enc  = tokenizer(
                        text,
                        max_length=max_len,
                        truncation=True,
                        padding="max_length",
                    )
                    self.encodings.append(enc)
                    self.labels.append(1 if ch["label"] == gold else 0)
                    self.qids.append(obj["id"])
                    self.cids.append(ch["label"])

        # Convert list of dicts → dict of lists for cheaper indexing.
        # Guard the empty case: self.encodings[0] would raise IndexError
        # on an empty input file.
        keys = self.encodings[0].keys() if self.encodings else []
        self.encodings = {k: [d[k] for d in self.encodings] for k in keys}

    # --------------------------------------------------

    def __len__(self):
        """Number of (question, choice) examples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return one example as tensors plus its string ids."""
        item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
        item["label"] = torch.tensor(self.labels[idx], dtype=torch.long)
        item["qid"]   = self.qids[idx]
        item["cid"]   = self.cids[idx]
        return item