# Hi-ToM_Dataset / Hi-ToM.py
# Uploaded by umwyf (commit a70e840, 2.12 kB).
import json

import datasets
from datasets import (
    DatasetBuilder,
    DatasetInfo,
    DownloadManager,
    GeneratorBasedBuilder,
    SplitGenerator,
)
from datasets.features import ClassLabel, Features, Sequence, Value
# The JSON data file. A GitHub ".../blob/..." URL serves the HTML viewer page,
# not the file, so the raw-content host is used instead.
_DATA_URL = (
    "https://raw.githubusercontent.com/ying-hui-he/Hi-ToM_dataset/main/Hi-ToM_data.json"
)

# TODO: fill in the BibTeX entry for the Hi-ToM paper. The previous code
# referenced an undefined ``CITATION`` name, which raised NameError in _info().
_CITATION = ""

# Single source of truth for the example schema. _info() publishes it and
# _generate_examples() uses its keys, so the two cannot drift apart.
_FEATURES = Features({
    "prompting_type": Value("string"),
    "deception": Value("bool"),
    "story_length": Value("int32"),
    "question_order": Value("int32"),
    "sample_id": Value("int32"),
    "story": Value("string"),
    "question": Value("string"),
    "choices": Value("string"),
    "answer": Value("string"),
})


class MyCustomDataset(datasets.GeneratorBasedBuilder):
    """Loading script exposing Hi-ToM_data.json as a single TRAIN split.

    Subclasses ``GeneratorBasedBuilder``, which is the builder type that drives
    ``_generate_examples``. The bare ``DatasetBuilder`` base leaves split
    preparation unimplemented.
    """

    VERSION = datasets.Version("1.0.0")

    def _info(self):
        """Return the dataset metadata: description, feature schema, homepage, citation."""
        return DatasetInfo(
            description="My custom dataset for tracking objects.",
            features=_FEATURES,
            supervised_keys=None,
            homepage="https://github.com/ying-hui-he/Hi-ToM_dataset",
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager: DownloadManager):
        """Download the JSON data file and map it to the TRAIN split.

        Args:
            dl_manager: Download manager supplied by the ``datasets`` library.

        Returns:
            A one-element list holding the TRAIN ``SplitGenerator``. Its
            ``filepath`` kwarg is the local path of the downloaded file.
        """
        downloaded_files = dl_manager.download_and_extract({"data_file": _DATA_URL})
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={"filepath": downloaded_files["data_file"]},
            ),
        ]

    def _generate_examples(self, filepath):
        """Yield ``(key, example)`` pairs from the downloaded JSON file.

        Args:
            filepath: Local path of the JSON file. The code expects a top-level
                object whose ``"data"`` key holds a list of records.

        Yields:
            Tuples of the record's index in that list (used as the key) and a
            dict containing exactly the fields declared in ``_FEATURES``.

        Raises:
            KeyError: If the file has no ``"data"`` key or a record is missing
                one of the declared fields.
        """
        with open(filepath, encoding="utf-8") as f:
            data = json.load(f)
        # Fields are copied in schema order. Extra keys in a record are ignored.
        for idx, item in enumerate(data["data"]):
            yield idx, {name: item[name] for name in _FEATURES}