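"""
annotate_EM.py: Annotate a retrieval ranking with exact-match (EM) labels.

Each retrieved passage is labeled according to whether it contains a gold
answer for its query; success@k and hit counts are then computed at several
cutoff depths.
"""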
import sys

from multiprocessing import Pool

from colbert.infra.run import Run
from colbert.data.collection import Collection
from colbert.data.ranking import Ranking
from colbert.utils.utils import groupby_first_item, print_message

from utility.utils.qa_loaders import load_qas_, load_collection_
from utility.utils.save_metadata import format_metadata, get_metadata
from utility.evaluate.annotate_EM_helpers import *
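
# Note on the star import: the helpers used below (tokenize_all_answers,
# assign_label_to_passage, check_sizes) come from annotate_EM_helpers. Their
# interfaces, as inferred from the call sites in this file (not from their
# canonical definitions): tokenize_all_answers maps a (qid, question, answers)
# item to (qid, question, tokenized_answers); assign_label_to_passage maps
# (idx, (qid, pid, rank, passage, tokenized_answers)) to (qid, pid, rank, label);
# check_sizes(qid2answers, qid2rankings) returns (num_judged, num_ranked).

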
class AnnotateEM:
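    """Annotate a ranking with exact-match (EM) labels given a collection and gold Q&A pairs."""
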
def __init__(self, collection, qas):
# TODO: These should just be Queries! But Queries needs to support looking up answers as qid2answers below.
        qas = load_qas_(qas)
        collection = Collection.cast(collection)  # .tolist() #load_collection_(collection, retain_titles=True)

        self.parallel_pool = Pool(30)  # worker processes for answer tokenization and passage labeling

        print_message('#> Tokenize the answers in the Q&As in parallel...')
        qas = list(self.parallel_pool.map(tokenize_all_answers, qas))

        qid2answers = {qid: tok_answers for qid, _, tok_answers in qas}
        assert len(qas) == len(qid2answers), (len(qas), len(qid2answers))

        self.qas, self.collection = qas, collection
        self.qid2answers = qid2answers

def annotate(self, ranking):
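        """Attach an exact-match label to each (qid, pid, rank) entry in `ranking`.

        Returns a new Ranking whose provenance records this annotation step.
        """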
        rankings = Ranking.cast(ranking)
        # print(len(rankings), rankings[0])

        print_message('#> Lookup passages from PIDs...')
        expanded_rankings = [(qid, pid, rank, self.collection[pid], self.qid2answers[qid])
                             for qid, pid, rank, *_ in rankings.tolist()]

        print_message('#> Assign labels in parallel...')
        labeled_rankings = list(self.parallel_pool.map(assign_label_to_passage, enumerate(expanded_rankings)))

        # Group the labeled (qid, pid, rank, label) tuples by qid for evaluation and dumping.
        self.qid2rankings = groupby_first_item(labeled_rankings)

        self.num_judged_queries, self.num_ranked_queries = check_sizes(self.qid2answers, self.qid2rankings)

        # Evaluation metrics at several cutoff depths.
        self.success, self.counts = self._compute_labels(self.qid2answers, self.qid2rankings)
        print(rankings.provenance(), self.success)

        return Ranking(data=self.qid2rankings, provenance=("AnnotateEM", rankings.provenance()))

def _compute_labels(self, qid2answers, qid2rankings):
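        """Compute success@k (did any hit appear in the top k?) and total hit counts per cutoff."""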
        cutoffs = [1, 5, 10, 20, 30, 50, 100, 1000, 'all']
        success = {cutoff: 0.0 for cutoff in cutoffs}
        counts = {cutoff: 0.0 for cutoff in cutoffs}

        for qid in qid2answers:
            if qid not in qid2rankings:
                continue

            prev_rank = 0  # ranks should start at one (i.e., not zero)
            labels = []

            for pid, rank, label in qid2rankings[qid]:
                assert rank == prev_rank + 1, (qid, pid, (prev_rank, rank))
                prev_rank = rank
                labels.append(label)

            for cutoff in cutoffs:
                if cutoff != 'all':
                    success[cutoff] += sum(labels[:cutoff]) > 0
                    counts[cutoff] += sum(labels[:cutoff])
                else:
                    success[cutoff] += sum(labels) > 0
                    counts[cutoff] += sum(labels)

        return success, counts

def save(self, new_path):
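        """Save the labeled ranking to `new_path` and dump aggregate metrics to `new_path`.metrics."""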
print_message("#> Dumping output to", new_path, "...")
Ranking(data=self.qid2rankings).save(new_path)
# Dump metrics.
with Run().open(f'{new_path}.metrics', 'w') as f:
d = {'num_ranked_queries': self.num_ranked_queries, 'num_judged_queries': self.num_judged_queries}
extra = '__WARNING' if self.num_judged_queries != self.num_ranked_queries else ''
d[f'success{extra}'] = {k: v / self.num_judged_queries for k, v in self.success.items()}
d[f'counts{extra}'] = {k: v / self.num_judged_queries for k, v in self.counts.items()}
# d['arguments'] = get_metadata(args) # TODO: Need arguments...
f.write(format_metadata(d) + '\n')
if __name__ == '__main__':
    # Ranking files used during development:
    # r = '/future/u/okhattab/root/unit/experiments/2021.08/retrieve.py/2021-09-04_15.50.02/ranking.tsv'
    # r = '/future/u/okhattab/root/unit/experiments/2021.08/retrieve.py/2021-09-04_15.59.37/ranking.tsv'
    r = sys.argv[1]

    a = AnnotateEM(collection='/future/u/okhattab/root/unit/data/NQ-mini/collection.tsv',
                   qas='/future/u/okhattab/root/unit/data/NQ-mini/dev/qas.json')
    a.annotate(ranking=r)
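# Example usage (assuming this file lives at utility/evaluate/annotate_EM.py,
# per the package layout of the imports above, and that the hard-coded NQ-mini
# paths exist):
#
#   python -m utility.evaluate.annotate_EM /path/to/ranking.tsv
#
# Programmatic use follows the same steps (a minimal sketch; paths are placeholders):
#
#   annotator = AnnotateEM(collection='collection.tsv', qas='qas.json')
#   labeled = annotator.annotate(ranking='ranking.tsv')
#   annotator.save('labeled.ranking.tsv')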