|
import os
|
|
import sys
|
|
import git
|
|
import tqdm
|
|
import ujson
|
|
import random
|
|
|
|
from argparse import ArgumentParser
|
|
from colbert.utils.utils import print_message, load_ranking, groupby_first_item
|
|
|
|
|
|
# Cap on the number of triples written out: main() randomly sub-samples the
# generated triples down to this many when more are produced.
MAX_NUM_TRIPLES = 40_000_000
|
|
|
|
|
|
def sample_negatives(negatives, num_sampled, biased=False, top_k=100):
    """Randomly pick up to ``num_sampled`` distinct negatives.

    Args:
        negatives: Sequence of candidate negative pids (presumably in ranking
            order, so the first ``top_k`` are the highest-ranked candidates).
        num_sampled: Maximum number of negatives to return. It is clamped to
            ``len(negatives)``.
        biased: When True, aim for half of the sample from the first ``top_k``
            candidates and half from the remainder. If one side is too short,
            the shortfall is taken from the other side, so the result still
            has ``min(len(negatives), num_sampled)`` items.
        top_k: Size of the leading slice used by the biased split.

    Returns:
        A list of distinct elements of ``negatives``.
    """
    num_sampled = min(len(negatives), num_sampled)

    if not biased:
        return random.sample(negatives, num_sampled)

    head, tail = negatives[:top_k], negatives[top_k:]

    # Start from a half/half split, clamp each side to what is available
    # (random.sample raises ValueError when asked for more than it has), and
    # give any shortfall on one side to the other. Since
    # num_sampled <= len(head) + len(tail), the two sides always add up to
    # num_sampled.
    num_head = min(num_sampled // 2, len(head))
    num_tail = min(num_sampled - num_head, len(tail))
    num_head = min(num_sampled - num_tail, len(head))

    return random.sample(head, num_head) + random.sample(tail, num_tail)
|
|
|
|
|
|
def sample_for_query(qid, ranking, npositives, depth_positive, depth_negative, cutoff_negative,
                     num_sampled=100):
    """Build ``(qid, positive(s), negative)`` training triples for one query.

    Requires that the ranks are sorted per qid (ascending). The scan stops at
    the first row whose rank exceeds ``depth_negative``.

    Args:
        qid: Query identifier, copied into every triple.
        ranking: Iterable of ``(pid, rank, ...)`` rows for this query. Ranks
            start at 1.
        npositives: Number of positives drawn per triple. When it is 1, the
            positive is stored as a bare pid; otherwise as a list of pids.
        depth_positive: Rows with ``rank <= depth_positive`` are positive
            candidates.
        depth_negative: Rows with ``rank > depth_negative`` are ignored.
        cutoff_negative: Rows with ``cutoff_negative < rank <= depth_negative``
            are negative candidates.
        num_sampled: Maximum number of negatives, and hence triples, for this
            query.

    Returns:
        A list of ``(qid, positive_or_positives, negative_pid)`` tuples, one per
        sampled negative. The list is empty when there are no negative
        candidates or fewer than ``npositives`` positive candidates.
    """
    assert npositives <= depth_positive < cutoff_negative < depth_negative

    positives, negatives = [], []

    for pid, rank, *_ in ranking:
        assert rank >= 1, f"ranks should start at 1 \t\t got rank = {rank}"

        if rank > depth_negative:
            # Ranks are sorted, so no later row can fall in either window.
            break

        if rank <= depth_positive:
            positives.append(pid)
        elif rank > cutoff_negative:
            negatives.append(pid)
        # Rows with depth_positive < rank <= cutoff_negative are used as neither.

    # With gaps in the ranks there may be too few positives to draw from;
    # random.sample would raise ValueError, so skip this query instead.
    if len(positives) < npositives:
        return []

    triples = []
    for neg in sample_negatives(negatives, num_sampled):
        drawn = random.sample(positives, npositives)
        positive = drawn[0] if npositives == 1 else drawn
        triples.append((qid, positive, neg))

    return triples
|
|
|
|
|
|
def main(args):
    """Build training triples from a ranking file and write them as JSON lines.

    The steps are:
    1. Load ``args.ranking`` with ``load_ranking`` and group the rows by query
       id with ``groupby_first_item``.
    2. Draw triples per query with ``sample_for_query``.
    3. Randomly sub-sample to at most ``MAX_NUM_TRIPLES``, then shuffle.
    4. Write one JSON value per line to ``args.output``.
    5. Write the arguments, plus the command line (``args.cmd``) and current
       git commit (``args.git_hash``), as JSON to ``<args.output>.meta``.

    Note: ``args`` is mutated; ``cmd`` and ``git_hash`` attributes are added.
    """
    rankings = load_ranking(args.ranking, types=[int, int, int, float, int])

    print_message("#> Group by QID")
    qid2rankings = groupby_first_item(tqdm.tqdm(rankings))

    triples = []
    non_empty_qids = 0

    for processing_idx, qid in enumerate(qid2rankings):
        query_triples = sample_for_query(qid, qid2rankings[qid], args.positives, args.depth_positive,
                                         args.depth_negative, args.cutoff_negative)
        non_empty_qids += (len(query_triples) > 0)
        triples.extend(query_triples)

        if processing_idx % (10_000) == 0:
            print_message(f"#> Done with {processing_idx+1} questions!\t\t "
                          f"{str(len(triples) / 1000)}k triples for {non_empty_qids} unique QIDs.")

    print_message(f"#> Sub-sample the triples (if > {MAX_NUM_TRIPLES})..")
    print_message(f"#> len(Triples) = {len(triples)}")

    if len(triples) > MAX_NUM_TRIPLES:
        triples = random.sample(triples, MAX_NUM_TRIPLES)

    print_message("#> Shuffling the triples...")
    random.shuffle(triples)

    print_message("#> Writing {}M examples to file.".format(len(triples) / 1000.0 / 1000.0))

    with open(args.output, 'w') as f:
        for example in triples:
            ujson.dump(example, f)
            f.write('\n')

    # Gather provenance before opening the .meta file. A failing git lookup
    # (e.g. when not run inside a repository) then leaves no .meta file,
    # rather than an empty or truncated one.
    args.cmd = ' '.join(sys.argv)
    args.git_hash = git.Repo(search_parent_directories=True).head.object.hexsha

    with open(f'{args.output}.meta', 'w') as f:
        ujson.dump(args.__dict__, f, indent=4)
        f.write('\n')

    print('\n\n', args, '\n\n')
    print(args.output)
    print_message("#> Done.")
|
|
|
|
|
|
if __name__ == "__main__":
    # Fixed seed so the random sampling is reproducible across runs.
    random.seed(12345)

    parser = ArgumentParser(description='Create training triples from ranked list.')

    parser.add_argument('--ranking', dest='ranking', required=True, type=str,
                        help='ranking file to sample triples from')
    parser.add_argument('--output', dest='output', required=True, type=str,
                        help='output path for the triples (must not exist yet)')

    parser.add_argument('--positives', dest='positives', required=True, type=int,
                        help='number of positives drawn per triple')
    parser.add_argument('--depth+', dest='depth_positive', required=True, type=int,
                        help='positives are drawn from ranks <= this value')

    parser.add_argument('--depth-', dest='depth_negative', required=True, type=int,
                        help='negatives are drawn from ranks <= this value')
    parser.add_argument('--cutoff-', dest='cutoff_negative', required=True, type=int,
                        help='negatives are drawn from ranks > this value')

    args = parser.parse_args()

    # Use parser.error (exit status 2) rather than `assert`. An assert is
    # stripped under `python -O`, which would silently allow overwriting an
    # existing output file.
    if os.path.exists(args.output):
        parser.error(f"output path already exists: {args.output}")

    # sample_for_query asserts this same ordering for every query. Checking it
    # here fails fast, before the ranking file is loaded.
    if not (args.positives <= args.depth_positive < args.cutoff_negative < args.depth_negative):
        parser.error("expected --positives <= --depth+ < --cutoff- < --depth-")

    main(args)
|
|
|