File size: 1,431 Bytes
0c3992e
 
 
 
a00d62c
0c3992e
 
 
 
a00d62c
 
 
0c3992e
 
a00d62c
 
 
0c3992e
 
a00d62c
 
0c3992e
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import os.path as osp
from src.benchmarks.qa_datasets import AmazonSTaRKDataset, PrimeKGSTaRKDataset, MAGSTaRKDataset, STaRKDataset


def get_qa_dataset(name, root='data/', human_generated_eval=False):
    """Construct the STaRK QA dataset called ``name`` from files under ``root``.

    Every dataset is read from ``<root>/<name>/``. That directory is passed to
    the dataset class as two sub-directories: ``stark_qa`` and ``split``.

    Args:
        name: Dataset name. ``'amazon'``, ``'primekg'`` and ``'mag'`` select
            their dedicated dataset classes. Any other name is treated as an
            external dataset and loaded with the generic ``STaRKDataset``.
        root: Directory that contains one sub-directory per dataset.
        human_generated_eval: Forwarded to the dedicated dataset classes only.

    Returns:
        The constructed dataset instance.

    Raises:
        Whatever the ``STaRKDataset`` constructor raises for an external
        dataset. The exception is re-raised unchanged after a hint is printed.
    """
    qa_root = osp.join(root, name)
    split_dir = osp.join(qa_root, 'split')
    stark_qa_dir = osp.join(qa_root, 'stark_qa')

    # Built-in datasets share one constructor signature, so dispatch by name.
    builtin_datasets = {
        'amazon': AmazonSTaRKDataset,
        'primekg': PrimeKGSTaRKDataset,
        'mag': MAGSTaRKDataset,
    }
    dataset_cls = builtin_datasets.get(name)
    if dataset_cls is not None:
        return dataset_cls(stark_qa_dir, split_dir,
                           human_generated_eval=human_generated_eval)

    # Unknown name: fall back to the generic loader for external data.
    # NOTE(review): human_generated_eval is not forwarded here, as in the
    # original code; confirm whether STaRKDataset accepts it.
    print('loading dataset from external data')
    try:
        return STaRKDataset(stark_qa_dir, split_dir)
    except Exception:
        print('Please check dataset name, path, or format\n')
        # Bare raise re-raises the active exception as-is.
        raise