|
import os
|
|
import ujson
|
|
import random
|
|
|
|
from argparse import ArgumentParser
|
|
|
|
from colbert.utils.utils import print_message, create_directory, load_ranking, groupby_first_item
|
|
from utility.utils.qa_loaders import load_qas_
|
|
|
|
|
|
def main(args):
    """Write the rankings of a random subset of questions to ``args.output``.

    Question records are loaded with ``load_qas_(args.qas)`` and ranking rows
    with ``load_ranking(args.ranking)``; the rows are grouped with
    ``groupby_first_item`` and looked up by qid (presumably the first column
    of each ranking row -- confirm against ``groupby_first_item``).

    For every sampled question, one tab-separated line is written per grouped
    ranking row: the qid followed by that row's remaining fields.

    Args:
        args: Parsed command-line namespace providing ``qas``, ``ranking``,
            ``output`` and ``sample``.

    Raises:
        ValueError: If ``args.sample`` is negative (raised by
            ``random.sample``).
    """
    print_message("#> Loading all..")
    qas = load_qas_(args.qas)
    rankings = load_ranking(args.ranking)
    qid2rankings = groupby_first_item(rankings)

    print_message("#> Subsampling all..")

    # random.sample() raises ValueError when asked for more items than the
    # population holds; fall back to the whole set instead of crashing.
    sample_size = args.sample
    if sample_size > len(qas):
        print_message(
            "#> Requested sample of {} exceeds the {} available questions; "
            "using all of them.".format(sample_size, len(qas))
        )
        sample_size = len(qas)

    qas_sample = random.sample(qas, sample_size)

    with open(args.output, 'w') as f:
        # Only the first field (the qid) of each question record is needed.
        for qid, *_ in qas_sample:
            for ranking in qid2rankings[qid]:
                row = [qid] + ranking
                f.write('\t'.join(map(str, row)) + '\n')

    print('\n\n')
    print(args.output)
    print("#> Done.")
|
|
|
|
|
|
if __name__ == "__main__":
    # Fixed seed so that repeated runs draw the same subsample.
    random.seed(12345)

    parser = ArgumentParser(description='Subsample the dev set.')
    parser.add_argument('--qas', dest='qas', required=True, type=str)
    parser.add_argument('--ranking', dest='ranking', required=True)
    parser.add_argument('--output', dest='output', required=True)

    parser.add_argument('--sample', dest='sample', default=1500, type=int)

    args = parser.parse_args()

    # Refuse to overwrite an existing output file. This is an explicit check
    # rather than an `assert` so it is still enforced under `python -O`.
    if os.path.exists(args.output):
        parser.error("output path already exists: {}".format(args.output))

    # os.path.dirname() returns '' for a bare filename; there is no directory
    # to create in that case.
    output_dir = os.path.dirname(args.output)
    if output_dir:
        create_directory(output_dir)

    main(args)
|
|
|