File size: 5,200 Bytes
828992f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
"""
Example: --positives 5,50 1,1000 ~~> best-5 (in top-50) + best-1 (in top-1000)
"""
import os
import sys
import git
import tqdm
import ujson
import random
from argparse import ArgumentParser
from colbert.utils.utils import print_message, load_ranking, groupby_first_item, create_directory
from utility.utils.save_metadata import save_metadata
MAX_NUM_TRIPLES = 40_000_000
def sample_negatives(negatives, num_sampled, biased=None):
assert biased in [None, 100, 200], "NOTE: We bias 50% from the top-200 negatives, if there are twice or more."
num_sampled = min(len(negatives), num_sampled)
if biased and num_sampled < len(negatives):
assert num_sampled % 2 == 0, num_sampled
num_sampled_top100 = num_sampled // 2
num_sampled_rest = num_sampled - num_sampled_top100
oversampled, undersampled = negatives[:biased], negatives[biased:]
if len(oversampled) < len(undersampled):
return random.sample(oversampled, num_sampled_top100) + random.sample(undersampled, num_sampled_rest)
return random.sample(negatives, num_sampled)
def sample_for_query(qid, ranking, args_positives, depth, permissive, biased):
"""
Requires that the ranks are sorted per qid.
"""
positives, negatives, triples = [], [], []
for pid, rank, *_, label in ranking:
assert rank >= 1, f"ranks should start at 1 \t\t got rank = {rank}"
assert label in [0, 1]
if rank > depth:
break
if label:
take_this_positive = any(rank <= maxDepth and len(positives) < maxBest for maxBest, maxDepth in args_positives)
if take_this_positive:
positives.append((pid, 0))
elif permissive:
positives.append((pid, rank)) # utilize with a few negatives, starting at (next) rank
else:
negatives.append(pid)
for pos, neg_start in positives:
num_sampled = 100 if neg_start == 0 else 5
negatives_ = negatives[neg_start:]
biased_ = biased if neg_start == 0 else None
for neg in sample_negatives(negatives_, num_sampled, biased=biased_):
triples.append((qid, pos, neg))
return triples
def main(args):
try:
rankings = load_ranking(args.ranking, types=[int, int, int, float, int])
except:
rankings = load_ranking(args.ranking, types=[int, int, int, int])
print_message("#> Group by QID")
qid2rankings = groupby_first_item(tqdm.tqdm(rankings))
Triples = []
NonEmptyQIDs = 0
for processing_idx, qid in enumerate(qid2rankings):
l = sample_for_query(qid, qid2rankings[qid], args.positives, args.depth, args.permissive, args.biased)
NonEmptyQIDs += (len(l) > 0)
Triples.extend(l)
if processing_idx % (10_000) == 0:
print_message(f"#> Done with {processing_idx+1} questions!\t\t "
f"{str(len(Triples) / 1000)}k triples for {NonEmptyQIDs} unqiue QIDs.")
print_message(f"#> Sub-sample the triples (if > {MAX_NUM_TRIPLES})..")
print_message(f"#> len(Triples) = {len(Triples)}")
if len(Triples) > MAX_NUM_TRIPLES:
Triples = random.sample(Triples, MAX_NUM_TRIPLES)
### Prepare the triples ###
print_message("#> Shuffling the triples...")
random.shuffle(Triples)
print_message("#> Writing {}M examples to file.".format(len(Triples) / 1000.0 / 1000.0))
with open(args.output, 'w') as f:
for example in Triples:
ujson.dump(example, f)
f.write('\n')
save_metadata(f'{args.output}.meta', args)
print('\n\n', args, '\n\n')
print(args.output)
print_message("#> Done.")
if __name__ == "__main__":
parser = ArgumentParser(description='Create training triples from ranked list.')
# Input / Output Arguments
parser.add_argument('--ranking', dest='ranking', required=True, type=str)
parser.add_argument('--output', dest='output', required=True, type=str)
# Weak Supervision Arguments.
parser.add_argument('--positives', dest='positives', required=True, nargs='+')
parser.add_argument('--depth', dest='depth', required=True, type=int) # for negatives
parser.add_argument('--permissive', dest='permissive', default=False, action='store_true')
# parser.add_argument('--biased', dest='biased', default=False, action='store_true')
parser.add_argument('--biased', dest='biased', default=None, type=int)
parser.add_argument('--seed', dest='seed', required=False, default=12345, type=int)
args = parser.parse_args()
random.seed(args.seed)
assert not os.path.exists(args.output), args.output
args.positives = [list(map(int, configuration.split(','))) for configuration in args.positives]
assert all(len(x) == 2 for x in args.positives)
assert all(maxBest <= maxDepth for maxBest, maxDepth in args.positives), args.positives
create_directory(os.path.dirname(args.output))
assert args.biased in [None, 100, 200]
main(args)
|