File size: 2,905 Bytes
d6585f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from email.policy import default
import argparse
from random import choices
from datasets import load_dataset
import os

def def_args(parser):
    """Register the query-augmentation command-line options on ``parser``.

    Options added, in this order:
      --data_split   topic split to augment: 'validation' or 'test' (default 'test').
      --dataset      QA collection: 'nq' or 'trivia' (default 'nq').
      --output_path  destination file (default './augmented_topics.tsv').
      --k            number of leading predictions of each kind to use (default 1).
      --answers, --titles, --sentences
                     boolean switches (default False), each enabling one kind
                     of prediction.
    """
    parser.add_argument('--data_split', type=str, default='test',
                        choices=['validation', 'test'])
    parser.add_argument('--dataset', type=str, default='nq',
                        choices=['nq', 'trivia'])
    parser.add_argument('--output_path', type=str,
                        default='./augmented_topics.tsv',
                        help="output txt path")
    parser.add_argument('--k', type=int, default=1,
                        help="first k augmentations to be added to the query")
    # The three switches share one shape: absent -> False, present -> True.
    for switch in ('--answers', '--titles', '--sentences'):
        parser.add_argument(switch, action='store_true', default=False)

# Anserini topic files, relative to the $ANSERINI checkout, for each
# (--dataset, --data_split) pair accepted by def_args.
TOPIC_FILES = {
    ('nq', 'validation'): 'src/main/resources/topics-and-qrels/topics.nq.dev.txt',
    ('nq', 'test'): 'src/main/resources/topics-and-qrels/topics.nq.test.txt',
    ('trivia', 'validation'): 'src/main/resources/topics-and-qrels/topics.dpr.trivia.dev.txt',
    ('trivia', 'test'): 'src/main/resources/topics-and-qrels/topics.dpr.trivia.test.txt',
}

# Hugging Face dataset holding the GAR-T5 expansions for each --dataset value.
EXPANSION_DATASETS = {
    'nq': 'castorini/nq_gar-t5_expansions',
    'trivia': 'castorini/triviaqa_gar-t5_expansions',
}

# load_dataset(data_files=...) names each split after its key in this mapping,
# so there is no 'validation' split: the CLI value is translated via SPLIT_NAMES.
DATA_FILES = {"dev": "dev/dev.jsonl", "test": "test/test.jsonl"}
SPLIT_NAMES = {'validation': 'dev', 'test': 'test'}

# (CLI switch attribute, expansion record field), in the order appended to a query.
AUGMENTATION_FIELDS = (
    ('answers', 'predicted_answers'),
    ('titles', 'predicted_titles'),
    ('sentences', 'predicted_sentences'),
)


def topics_path(anserini_path, dataset, data_split):
    """Return the Anserini topic file for ``dataset`` / ``data_split``.

    Raises:
        KeyError: if the pair is not one of TOPIC_FILES' keys.
    """
    return os.path.join(anserini_path, TOPIC_FILES[(dataset, data_split)])


def read_questions(path):
    """Return the question, i.e. the first tab-separated column, of every line in ``path``."""
    with open(path, 'r', encoding='utf-8') as topics:
        return [line.split('\t')[0] for line in topics]


def selected_fields(args):
    """Return the expansion fields enabled by --answers/--titles/--sentences, in output order."""
    return [field for flag, field in AUGMENTATION_FIELDS if getattr(args, flag)]


def format_augmented_line(record, question, fields, k):
    """Format one output row as ``id<TAB>question<SPACE>augmentations<NEWLINE>``.

    For each name in ``fields`` (in order), the first ``k`` entries of
    ``record[name]`` are space-joined. Those groups are then space-joined
    into the augmentation text. With no fields, the row still ends in
    ``question + ' '``, matching the script's historical output format.
    """
    augmentations = ' '.join(' '.join(record[name][:k]) for name in fields)
    return f"{record['id']}\t{question} {augmentations}\n"


def main():
    """Write the augmented topic file described by the command-line options.

    Row i pairs expansion record i with line i of the Anserini topic file, so
    the two sources are assumed to be in the same order (NOTE(review): nothing
    here verifies the ids match -- confirm against the dataset).
    """
    parser = argparse.ArgumentParser(description='Query augmentations.')
    def_args(parser)
    args = parser.parse_args()

    anserini_path = os.environ.get('ANSERINI')
    if anserini_path is None:
        parser.error('the ANSERINI environment variable must point to an Anserini checkout')

    questions = read_questions(topics_path(anserini_path, args.dataset, args.data_split))
    records = load_dataset(EXPANSION_DATASETS[args.dataset],
                           data_files=DATA_FILES)[SPLIT_NAMES[args.data_split]]
    # Every record needs a question on the same line; fewer questions used to
    # surface as a bare IndexError halfway through the loop.
    if len(questions) < len(records):
        raise SystemExit(
            f'topic file has {len(questions)} questions but the expansion dataset '
            f'has {len(records)} records; they must be aligned line by line')

    fields = selected_fields(args)
    lines = [format_augmented_line(record, question, fields, args.k)
             for record, question in zip(records, questions)]

    with open(args.output_path, 'w', encoding='utf-8') as output_file:
        output_file.writelines(lines)
    print("Done")


if __name__ == '__main__':
    main()