import sys

from multiprocessing import Pool

from colbert.infra.run import Run
from colbert.data.collection import Collection
from colbert.data.ranking import Ranking

from colbert.utils.utils import groupby_first_item, print_message
from utility.utils.qa_loaders import load_qas_, load_collection_
from utility.utils.save_metadata import format_metadata, get_metadata
from utility.evaluate.annotate_EM_helpers import *


class AnnotateEM:
    def __init__(self, collection, qas):
        # TODO: These should just be Queries! But Queries needs to support looking up answers as qid2answers below.
        qas = load_qas_(qas)
        collection = Collection.cast(collection)  # .tolist() #load_collection_(collection, retain_titles=True)

        self.parallel_pool = Pool(30)

        print_message('#> Tokenize the answers in the Q&As in parallel...')
        qas = list(self.parallel_pool.map(tokenize_all_answers, qas))

        qid2answers = {qid: tok_answers for qid, _, tok_answers in qas}
        assert len(qas) == len(qid2answers), (len(qas), len(qid2answers))

        self.qas, self.collection = qas, collection
        self.qid2answers = qid2answers

    def annotate(self, ranking):
        rankings = Ranking.cast(ranking)

        print_message('#> Lookup passages from PIDs...')
        expanded_rankings = [(qid, pid, rank, self.collection[pid], self.qid2answers[qid])
                             for qid, pid, rank, *_ in rankings.tolist()]

        print_message('#> Assign labels in parallel...')
        labeled_rankings = list(self.parallel_pool.map(assign_label_to_passage, enumerate(expanded_rankings)))

        # Group the labeled (qid, pid, rank, label) tuples by qid.
        self.qid2rankings = groupby_first_item(labeled_rankings)

        self.num_judged_queries, self.num_ranked_queries = check_sizes(self.qid2answers, self.qid2rankings)

        # Evaluation metrics and depths.
        self.success, self.counts = self._compute_labels(self.qid2answers, self.qid2rankings)

        print(rankings.provenance(), self.success)

        return Ranking(data=self.qid2rankings, provenance=("AnnotateEM", rankings.provenance()))

    def _compute_labels(self, qid2answers, qid2rankings):
        cutoffs = [1, 5, 10, 20, 30, 50, 100, 1000, 'all']
        success = {cutoff: 0.0 for cutoff in cutoffs}
        counts = {cutoff: 0.0 for cutoff in cutoffs}

        for qid in qid2answers:
            if qid not in qid2rankings:
                continue

            prev_rank = 0  # ranks should start at one (i.e., not zero)
            labels = []

            for pid, rank, label in qid2rankings[qid]:
                assert rank == prev_rank + 1, (qid, pid, (prev_rank, rank))
                prev_rank = rank

                labels.append(label)

            for cutoff in cutoffs:
                if cutoff != 'all':
                    success[cutoff] += sum(labels[:cutoff]) > 0
                    counts[cutoff] += sum(labels[:cutoff])
                else:
                    success[cutoff] += sum(labels) > 0
                    counts[cutoff] += sum(labels)

        return success, counts

    def save(self, new_path):
        print_message("#> Dumping output to", new_path, "...")

        Ranking(data=self.qid2rankings).save(new_path)

        # Dump metrics.
        with Run().open(f'{new_path}.metrics', 'w') as f:
            d = {'num_ranked_queries': self.num_ranked_queries, 'num_judged_queries': self.num_judged_queries}
            extra = '__WARNING' if self.num_judged_queries != self.num_ranked_queries else ''
            d[f'success{extra}'] = {k: v / self.num_judged_queries for k, v in self.success.items()}
            d[f'counts{extra}'] = {k: v / self.num_judged_queries for k, v in self.counts.items()}
            # d['arguments'] = get_metadata(args)  # TODO: Need arguments...
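            # Illustrative shape of `d` at this point (values are hypothetical):
            #   {'num_ranked_queries': 500, 'num_judged_queries': 500,
            #    'success': {1: 0.41, 5: 0.62, ..., 'all': 0.89},
            #    'counts': {1: 0.41, 5: 1.7, ..., 'all': 9.3}}
            # success@k is the fraction of judged queries with at least one exact-match
            # positive in the top k; the '__WARNING' suffix marks metrics computed when
            # the judged and ranked query counts disagree.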
            f.write(format_metadata(d) + '\n')


if __name__ == '__main__':
    r = sys.argv[1]

    a = AnnotateEM(collection='/future/u/okhattab/root/unit/data/NQ-mini/collection.tsv',
                   qas='/future/u/okhattab/root/unit/data/NQ-mini/dev/qas.json')
    a.annotate(ranking=r)
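
# A minimal usage sketch (the paths below are hypothetical placeholders, not files
# shipped with the repository); annotate() must run before save(), which persists
# the labeled ranking and the metrics:
#
#   annotator = AnnotateEM(collection='/path/to/collection.tsv',
#                          qas='/path/to/qas.json')
#   annotator.annotate(ranking='/path/to/ranking.tsv')  # prints success@k per cutoff
#   annotator.save('/path/to/ranking.annotated.tsv')    # writes labels and a .metrics file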