|
"""
|
|
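Divide a query set (a TSV file of `qid<TAB>query` lines) into two disjoint splits, written next to the input as `<input>.a` and `<input>.b`.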
|
|
"""
|
|
|
|
import os
|
|
import math
|
|
import ujson
|
|
import random
|
|
|
|
from argparse import ArgumentParser
|
|
from collections import OrderedDict
|
|
from colbert.utils.utils import print_message
|
|
|
|
|
|
def main(args):
    """Split a TSV query file into two disjoint, reproducible subsets.

    Reads ``qid<TAB>query`` lines from ``args.input`` and writes two splits to
    ``<args.input>.a`` and ``<args.input>.b``. With ``N`` queries, the split
    sizes are ``N - args.holdout`` and ``args.holdout``; the larger of the two
    always goes to ``.a`` and the smaller to ``.b``. Members of the smaller
    split are drawn with a fixed random seed, so repeated runs on the same
    input give the same split. Within each output file the queries keep the
    order they had in the input.

    Args:
        args: object exposing ``input`` (path of the TSV query file) and
            ``holdout`` (int, number of queries to hold out).

    Raises:
        ValueError: if a line does not consist of exactly two tab-separated
            fields.
        AssertionError: if a query id occurs twice, if either split would be
            empty, or if one of the output files already exists.
    """
    # Fixed seed: the split must be identical on every run.
    random.seed(12345)

    # --- Load the queries (insertion order == file order) -----------------
    Queries = OrderedDict()

    print_message(f"#> Loading queries from {args.input}..")

    with open(args.input) as f:
        for line_number, line in enumerate(f, start=1):
            fields = line.strip().split('\t')
            if len(fields) != 2:
                raise ValueError(
                    f"{args.input}:{line_number}: expected 'qid<TAB>query', "
                    f"found {len(fields)} tab-separated field(s)"
                )
            qid, query = fields

            # Raised explicitly (not via `assert`) so the check survives
            # `python -O`; AssertionError is kept for backward compatibility.
            if qid in Queries:
                raise AssertionError(
                    f"duplicate query id {qid!r} at {args.input}:{line_number}"
                )

            Queries[qid] = query

    # --- Apply the splitting ----------------------------------------------
    size_a = len(Queries) - args.holdout
    size_b = args.holdout
    # Split "a" is always the larger one, whatever the holdout size is.
    size_a, size_b = max(size_a, size_b), min(size_a, size_b)

    if size_a <= 0 or size_b <= 0:
        raise AssertionError((len(Queries), size_a, size_b))

    print_message(f"#> Deterministically splitting the queries into ({size_a}, {size_b})-sized splits.")

    keys = list(Queries)
    # Same sampling call as before, so existing splits stay reproducible.
    sample_b_indices = sorted(random.sample(range(len(keys)), size_b))
    held_out = set(sample_b_indices)

    sample_a = [key for idx, key in enumerate(keys) if idx not in held_out]
    sample_b = [keys[idx] for idx in sample_b_indices]

    # Internal invariants of the code above, not input validation.
    assert len(sample_a) == size_a
    assert len(sample_b) == size_b

    # --- Write the output -------------------------------------------------
    output_path_a = f'{args.input}.a'
    output_path_b = f'{args.input}.b'

    # Refuse to overwrite earlier splits; check both paths before writing
    # either so that a failure leaves no partial output behind.
    for output_path in (output_path_a, output_path_b):
        if os.path.exists(output_path):
            raise AssertionError(output_path)

    print_message(f"#> Writing the splits out to {output_path_a} and {output_path_b} ...")

    for output_path, sample in [(output_path_a, sample_a), (output_path_b, sample_b)]:
        with open(output_path, 'w') as f:
            f.writelines(f"{qid}\t{Queries[qid]}\n" for qid in sample)
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the two mandatory options and run the
    # split.
    cli = ArgumentParser(description="queries_split.")

    # Path of the query file to split.
    cli.add_argument('--input', required=True, dest='input')
    # Integer holdout size passed through to main().
    cli.add_argument('--holdout', type=int, required=True, dest='holdout')

    main(cli.parse_args())
|
|
|