Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 1,158 Bytes
0c3992e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import os.path as osp
from src.benchmarks.qa_datasets import AmazonSTaRKDataset, PrimeKGSTaRKDataset, MAGSTaRKDataset, STaRKDataset
def get_qa_dataset(name, root='data/'):
    """Construct the STaRK QA dataset called *name* from files under *root*.

    Every dataset is read from ``<root>/<name>/stark_qa`` (QA data) and
    ``<root>/<name>/split`` (split definitions).

    Args:
        name: Dataset name. ``'amazon'``, ``'primekg'`` and ``'mag'`` use
            their dedicated dataset classes. Any other name is loaded with
            the generic ``STaRKDataset``.
        root: Directory that contains one sub-directory per dataset.

    Returns:
        The constructed dataset object.

    Raises:
        Whatever exception the generic ``STaRKDataset`` constructor raises
        for an unrecognised name. It is re-raised unchanged after a hint is
        printed.
    """
    qa_root = osp.join(root, name)
    # All datasets share the same on-disk layout, so build the paths once.
    split_dir = osp.join(qa_root, 'split')
    stark_qa_dir = osp.join(qa_root, 'stark_qa')

    dedicated = {
        'amazon': AmazonSTaRKDataset,
        'primekg': PrimeKGSTaRKDataset,
        'mag': MAGSTaRKDataset,
    }
    # A known name returns immediately. This way it can never fall through to
    # the generic loader below, which would rebuild (or fail to rebuild) the
    # dataset with the wrong class.
    if name in dedicated:
        return dedicated[name](stark_qa_dir, split_dir)

    print('loading dataset from external data')
    try:
        dataset = STaRKDataset(stark_qa_dir, split_dir)
    except Exception:
        print('Please check dataset name, path, or format\n')
        raise  # bare raise keeps the original traceback intact
    return dataset
|