|
import torch
|
|
|
|
from functools import partial
|
|
|
|
from colbert.ranking.index_part import IndexPart
|
|
from colbert.ranking.faiss_index import FaissIndex
|
|
from colbert.utils.utils import flatten, zipstar
|
|
|
|
|
|
class Ranker():
    """Score candidate passages for one encoded query against a ColBERT index part.

    Candidate passage ids are either supplied by the caller or, when a FAISS
    depth was configured at construction, fetched through a FAISS retrieval step.
    """

    def __init__(self, args, inference, faiss_depth=1024):
        """Set up the query encoder, optional FAISS retrieval, and the index part.

        Args:
            args: configuration object; this constructor reads ``index_path``,
                ``faiss_index_path``, ``nprobe`` and ``part_range`` from it.
            inference: query encoder; must expose ``queryFromText`` and
                ``colbert.dim``.
            faiss_depth: value bound as the first argument of
                ``FaissIndex.retrieve`` (presumably the number of candidates
                per query — TODO confirm), or ``None`` to skip building the
                FAISS index; callers must then always pass ``pids`` to ``rank``.
        """
        self.inference = inference
        self.faiss_depth = faiss_depth

        if faiss_depth is not None:
            self.faiss_index = FaissIndex(args.index_path, args.faiss_index_path, args.nprobe,
                                          part_range=args.part_range)
            # Pre-bind the depth so `rank` can call `self.retrieve(Q, verbose=...)`.
            self.retrieve = partial(self.faiss_index.retrieve, self.faiss_depth)

        self.index = IndexPart(args.index_path, dim=inference.colbert.dim,
                               part_range=args.part_range, verbose=True)

    def encode(self, queries):
        """Encode a batch of query strings with the inference model.

        Args:
            queries: list or tuple of query texts.

        Returns:
            The result of ``inference.queryFromText``; ``rank`` expects a tensor
            with a leading batch dimension of size 1 per call.
        """
        assert type(queries) in [list, tuple], type(queries)

        # Batch the encoding only for more than 512 queries; `bsize=None` is
        # presumably "encode everything in one pass" — TODO confirm.
        Q = self.inference.queryFromText(queries, bsize=512 if len(queries) > 512 else None)

        return Q

    def rank(self, Q, pids=None):
        """Score and sort candidate passages for a single query.

        Args:
            Q: encoded query with ``Q.size(0) == 1``; assumed to be shaped
                (1, query_len, dim) as produced by ``encode`` — TODO confirm.
            pids: list/tuple of ``int`` passage ids to score. When ``None``,
                candidates are fetched through FAISS, which requires the ranker
                to have been constructed with a non-``None`` ``faiss_depth``.

        Returns:
            ``(pids, scores)`` as two Python lists sorted by descending score.

        Raises:
            ValueError: if ``pids`` is ``None`` and no FAISS retriever exists.
            AssertionError: if inputs have the wrong type or batch size.
        """
        if pids is None:
            # `retrieve` only exists when faiss_depth was given; report that
            # clearly instead of surfacing an AttributeError.
            retrieve = getattr(self, "retrieve", None)
            if retrieve is None:
                raise ValueError(
                    "Ranker.rank() requires explicit `pids` when the ranker "
                    "was constructed with faiss_depth=None"
                )
            # Single-query path: keep the first (only) candidate list.
            pids = retrieve(Q, verbose=False)[0]

        assert type(pids) in [list, tuple], type(pids)
        assert Q.size(0) == 1, (len(pids), Q.size())
        assert all(type(pid) is int for pid in pids)

        if len(pids) == 0:
            return [], []

        # Swap the last two axes before scoring; presumably the layout
        # IndexPart.rank expects (embedding dim before tokens) — TODO confirm.
        Q = Q.permute(0, 2, 1)
        scores = self.index.rank(Q, pids)

        scores_sorter = torch.tensor(scores).sort(descending=True)
        pids = torch.tensor(pids)[scores_sorter.indices].tolist()
        scores = scores_sorter.values.tolist()

        return pids, scores
|
|
|