mColBERT / colbert /evaluation /ranking_logger.py
vjeronymo2's picture
Adding model and checkpoint
828992f
import os
from contextlib import contextmanager
from colbert.utils.utils import print_message, NullContextManager
from colbert.utils.runs import Run
class RankingLogger:
    """Write per-query ranked lists to a TSV file, optionally with an annotated copy.

    Typical use: construct the logger, enter :meth:`context` to open the
    output file(s), and call :meth:`log` once per query inside that block.
    """

    def __init__(self, directory, qrels=None, log_scores=False):
        """Store the configuration; no file is opened here.

        Args:
            directory: Directory that the filename given to ``context`` is joined onto.
            qrels: Optional relevance judgments, indexed as ``qrels[qid]`` and
                tested with ``pid in qrels[qid]``.
            log_scores: When true, each row of the main file gets the score
                appended as an extra column.
        """
        self.directory = directory
        self.qrels = qrels
        # Both remain None until context() is entered; context() asserts on
        # that, and never resets them, so each logger can be opened only once.
        self.filename, self.also_save_annotations = None, None
        self.log_scores = log_scores
        # Output handles, assigned by context(). Defined here so the
        # attributes exist for the whole lifetime of the object.
        self.f, self.g = None, None

    @contextmanager
    def context(self, filename, also_save_annotations=False):
        """Open the output file(s) and yield this logger while they are open.

        Args:
            filename: Name of the ranking file, relative to ``self.directory``.
            also_save_annotations: When true, a second file named
                ``<filename>.annotated`` is opened alongside the main one.

        Yields:
            This ``RankingLogger``, with ``self.f`` (and ``self.g``) ready for ``log``.

        Raises:
            AssertionError: If this logger's context was already entered before.
        """
        assert self.filename is None
        assert self.also_save_annotations is None

        filename = os.path.join(self.directory, filename)
        self.filename, self.also_save_annotations = filename, also_save_annotations

        print_message("#> Logging ranked lists to {}".format(self.filename))

        with open(filename, 'w') as f:
            self.f = f
            # Without annotations, self.g is whatever NullContextManager yields
            # (presumably a falsy placeholder such as None; log() only checks
            # its truthiness) -- confirm against colbert.utils.utils.
            with (open(filename + '.annotated', 'w') if also_save_annotations else NullContextManager()) as g:
                self.g = g
                # The enclosing `with` blocks close the files on exit, including
                # when the caller's block raises, so no extra cleanup is needed.
                yield self

    def log(self, qid, ranking, is_ranked=True, print_positions=()):
        """Append one query's ranking to the open output file(s).

        Main-file rows are ``qid, pid, rank[, score]``; annotated-file rows are
        ``qid, pid, rank, is_relevant``. Columns are tab-separated.

        Args:
            qid: Query identifier, written as the first column of every row.
            ranking: Iterable of ``(score, pid, passage)`` triples, best first.
            is_ranked: When false, every row is written with rank ``-1``
                instead of its 1-based position.
            print_positions: Ranks whose rows are additionally echoed to stdout.
        """
        print_positions = set(print_positions)

        ranking_rows = []
        annotated_rows = []

        for position, (score, pid, passage) in enumerate(ranking, start=1):
            # If self.qrels is falsy (e.g. None) this evaluates to self.qrels
            # itself rather than 0/1; otherwise qrels[qid] must exist.
            is_relevant = self.qrels and int(pid in self.qrels[qid])
            rank = position if is_ranked else -1

            columns = [qid, pid, rank]
            if self.log_scores:
                columns.append(score)
            ranking_rows.append('\t'.join(map(str, columns)) + "\n")

            if self.g:
                annotated_rows.append('\t'.join(map(str, [qid, pid, rank, is_relevant])) + "\n")

            if rank in print_positions:
                prefix = "** " if is_relevant else ""
                prefix += str(rank)
                print("#> ( QID {} ) ".format(qid) + prefix + ") ", pid, ":", score, ' ', passage)

        # One write per file per query keeps each query's rows contiguous.
        self.f.write(''.join(ranking_rows))
        if self.g:
            self.g.write(''.join(annotated_rows))