# mColBERT: colbert/ranking/rankers.py
# Last commit 828992f by vjeronymo2: "Adding model and checkpoint"
import torch
from functools import partial
from colbert.ranking.index_part import IndexPart
from colbert.ranking.faiss_index import FaissIndex
from colbert.utils.utils import flatten, zipstar
class Ranker:
    """Encode queries and score/sort candidate passage ids for a single query.

    Candidate pids are either supplied by the caller or obtained from
    ``FaissIndex.retrieve`` (only when a FAISS depth is configured). They are
    then scored with ``IndexPart.rank`` and returned in descending score order.
    """

    def __init__(self, args, inference, faiss_depth=1024):
        """Build the ranker and load its indexes.

        Args:
            args: configuration namespace; this reads ``index_path``,
                ``faiss_index_path``, ``nprobe`` and ``part_range``.
            inference: query encoder exposing ``queryFromText`` and
                ``colbert.dim``.
            faiss_depth: passed as the first argument to
                ``FaissIndex.retrieve`` (presumably the number of candidates
                per query; confirm against FaissIndex). ``None`` disables FAISS
                retrieval, in which case :meth:`rank` must be given ``pids``.
        """
        self.inference = inference
        self.faiss_depth = faiss_depth

        # Defaults so a FAISS-less ranker fails with a clear message in rank()
        # rather than an AttributeError on a never-assigned attribute.
        self.faiss_index = None
        self.retrieve = None

        if faiss_depth is not None:
            self.faiss_index = FaissIndex(args.index_path, args.faiss_index_path, args.nprobe,
                                          part_range=args.part_range)
            # Bind the depth once so rank() can call self.retrieve(Q, verbose=...).
            self.retrieve = partial(self.faiss_index.retrieve, self.faiss_depth)

        self.index = IndexPart(args.index_path, dim=inference.colbert.dim,
                               part_range=args.part_range, verbose=True)

    def encode(self, queries):
        """Encode query strings with ``inference.queryFromText``.

        Args:
            queries: a list or tuple of queries.

        Returns:
            Whatever ``queryFromText`` returns (the encoded queries).

        ``bsize=512`` is passed only when there are more than 512 queries;
        otherwise ``bsize=None`` is passed.
        """
        assert type(queries) in [list, tuple], type(queries)

        Q = self.inference.queryFromText(queries, bsize=512 if len(queries) > 512 else None)
        return Q

    def rank(self, Q, pids=None):
        """Score candidate pids for one encoded query and sort them.

        Args:
            Q: encoded query tensor; ``Q.size(0)`` must be 1. Its last two
                axes are swapped before scoring (presumably
                (1, query_len, dim) -> (1, dim, query_len); TODO confirm).
            pids: optional list/tuple of int passage ids. When ``None``, the
                candidates come from the FAISS retrieval set up in ``__init__``.

        Returns:
            ``(pids, scores)``: two Python lists sorted by descending score.
            Both are empty when there are no candidates.

        Raises:
            ValueError: ``pids`` is None but the ranker was built with
                ``faiss_depth=None``, so there is nothing to retrieve from.
        """
        if pids is None:
            if self.retrieve is None:
                raise ValueError("pids must be provided when the Ranker was built with faiss_depth=None")
            # Presumably one pid list per query; Q holds a single query, so take the first.
            pids = self.retrieve(Q, verbose=False)[0]

        assert type(pids) in [list, tuple], type(pids)
        assert Q.size(0) == 1, (len(pids), Q.size())
        assert all(type(pid) is int for pid in pids)

        if len(pids) == 0:
            return [], []

        # IndexPart.rank expects Q with its last two axes swapped.
        Q = Q.permute(0, 2, 1)
        scores = self.index.rank(Q, pids)

        scores_sorter = torch.tensor(scores).sort(descending=True)
        pids = torch.tensor(pids)[scores_sorter.indices].tolist()
        scores = scores_sorter.values.tolist()

        return pids, scores