|
import ujson
|
|
|
|
from collections import defaultdict
|
|
from colbert.utils.runs import Run
|
|
|
|
|
|
class Metrics:
    """Running accumulator for MRR@k, Success@k and Recall@k over ranked queries.

    Each call to :meth:`add` folds one query's ranking into per-depth sums.
    Reported scores divide those sums by ``query_idx + 1``, so a query that
    contributed nothing (no gold passage in its ranking) counts as zero.
    """

    def __init__(self, mrr_depths: set, recall_depths: set, success_depths: set, total_queries=None):
        """Create empty accumulators.

        Args:
            mrr_depths: Cut-off depths at which MRR is tracked.
            recall_depths: Cut-off depths at which recall is tracked.
            success_depths: Cut-off depths at which success is tracked.
            total_queries: Expected number of queries, checked by
                :meth:`output_final_metrics` when it is not ``None``.
        """
        # Rankings stored per query key, exactly as passed to add().
        self.results = {}

        self.mrr_sums = {depth: 0.0 for depth in mrr_depths}
        self.recall_sums = {depth: 0.0 for depth in recall_depths}
        self.success_sums = {depth: 0.0 for depth in success_depths}
        self.total_queries = total_queries

        # Largest query index passed to log() so far; -1 means "never logged".
        self.max_query_idx = -1
        self.num_queries_added = 0

    @staticmethod
    def _averages(sums, query_idx):
        """Return ``(depth, sums[depth] / (query_idx + 1))`` pairs by increasing depth."""
        num_seen = query_idx + 1.0
        return [(depth, sums[depth] / num_seen) for depth in sorted(sums)]

    def add(self, query_idx, query_key, ranking, gold_positives):
        """Fold the ranking of one query into the running sums.

        Args:
            query_idx: Zero-based index of the query; must be at least the
                number of queries already stored.
            query_key: Key under which the ranking is stored; must be new.
            ranking: Sequence of 3-tuples whose second element is the passage
                id (the first and third elements are not inspected here).
                Passage ids must be unique within the ranking.
            gold_positives: Collection of relevant passage ids, without
                duplicates.

        Raises:
            AssertionError: If any of the preconditions above is violated.
        """
        self.num_queries_added += 1

        # Build the set once: it serves the duplicate check and gives O(1)
        # membership tests while scanning the ranking below.
        gold_set = set(gold_positives)

        assert query_key not in self.results
        assert len(self.results) <= query_idx
        assert len(gold_set) == len(gold_positives)
        assert len(set([pid for _, pid, _ in ranking])) == len(ranking)

        self.results[query_key] = ranking

        # Zero-based rank positions at which a gold passage was retrieved.
        positives = [i for i, (_, pid, _) in enumerate(ranking) if pid in gold_set]

        if len(positives) == 0:
            # Nothing relevant retrieved: every metric gains zero for this query.
            return

        first_positive = positives[0]

        for depth in self.mrr_sums:
            self.mrr_sums[depth] += (1.0 / (first_positive + 1.0)) if first_positive < depth else 0.0

        for depth in self.success_sums:
            self.success_sums[depth] += 1.0 if first_positive < depth else 0.0

        for depth in self.recall_sums:
            num_positives_up_to_depth = len([pos for pos in positives if pos < depth])
            self.recall_sums[depth] += num_positives_up_to_depth / len(gold_positives)

    def print_metrics(self, query_idx):
        """Print the current MRR, Success and Recall averages for every depth."""
        for depth, score in self._averages(self.mrr_sums, query_idx):
            print("MRR@" + str(depth), "=", score)

        for depth, score in self._averages(self.success_sums, query_idx):
            print("Success@" + str(depth), "=", score)

        for depth, score in self._averages(self.recall_sums, query_idx):
            print("Recall@" + str(depth), "=", score)

    def log(self, query_idx):
        """Send the current averages to ``Run.log_metric`` with ``query_idx`` as the step.

        Raises:
            AssertionError: If ``query_idx`` is smaller than a previously
                logged index.
        """
        assert query_idx >= self.max_query_idx
        self.max_query_idx = query_idx

        Run.log_metric("ranking/max_query_idx", query_idx, query_idx)
        Run.log_metric("ranking/num_queries_added", self.num_queries_added, query_idx)

        for depth, score in self._averages(self.mrr_sums, query_idx):
            Run.log_metric("ranking/MRR." + str(depth), score, query_idx)

        for depth, score in self._averages(self.success_sums, query_idx):
            Run.log_metric("ranking/Success." + str(depth), score, query_idx)

        for depth, score in self._averages(self.recall_sums, query_idx):
            Run.log_metric("ranking/Recall." + str(depth), score, query_idx)

    def output_final_metrics(self, path, query_idx, num_queries):
        """Log, print and write the final averages as JSON to ``path``.

        Args:
            path: Destination file; it is opened in write mode.
            query_idx: Index of the last query; must equal ``num_queries - 1``.
            num_queries: Total number of queries evaluated.

        Raises:
            AssertionError: If ``query_idx + 1 != num_queries``, or if a
                ``total_queries`` was given at construction and differs from
                ``num_queries``.
        """
        assert query_idx + 1 == num_queries

        # total_queries is optional; comparing against the default None would
        # always fail, so only enforce the check when a total was supplied.
        if self.total_queries is not None:
            assert num_queries == self.total_queries

        # Make sure the final state has been logged exactly once.
        if self.max_query_idx < query_idx:
            self.log(query_idx)

        self.print_metrics(query_idx)

        output = defaultdict(dict)

        for depth, score in self._averages(self.mrr_sums, query_idx):
            output['mrr'][depth] = score

        for depth, score in self._averages(self.success_sums, query_idx):
            output['success'][depth] = score

        for depth, score in self._averages(self.recall_sums, query_idx):
            output['recall'][depth] = score

        with open(path, 'w') as f:
            ujson.dump(output, f, indent=4)
            f.write('\n')
|
|
|
|
|
|
def evaluate_recall(qrels, queries, topK_pids):
    """Print the mean recall of ``topK_pids`` against ``qrels``.

    For every query id in ``qrels`` the recall is the fraction of its gold
    passage ids that also appear in ``topK_pids[qid]``; a query with no gold
    passages contributes zero. The mean over all queries is rounded to three
    decimals and printed.

    Args:
        qrels: Mapping from query id to a collection of gold passage ids, or
            ``None`` to skip evaluation.
        queries: Mapping keyed by query id; its keys must equal those of
            ``qrels``.
        topK_pids: Mapping from query id to the retrieved passage ids.

    Raises:
        AssertionError: If ``qrels`` and ``queries`` have different keys.
    """
    if qrels is None:
        return

    assert set(qrels.keys()) == set(queries.keys())

    # An empty mapping has nothing to average; bail out instead of dividing
    # by zero below.
    if len(qrels) == 0:
        return

    per_query_recall = [
        # max(1.0, ...) guards against queries with an empty gold list.
        len(set(qrels[qid]) & set(topK_pids[qid])) / max(1.0, len(qrels[qid]))
        for qid in qrels
    ]

    recall_at_k = round(sum(per_query_recall) / len(qrels), 3)
    print("Recall @ maximum depth =", recall_at_k)
|
|
|
|
|
|
|
|
|