File size: 647 Bytes
828992f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
import os
def slow_rerank(args, query, pids, passages):
    """Rerank candidate passages for one query by scoring every passage.

    The query and all passages are encoded at call time with
    ``args.inference`` and scored with ``args.colbert``. The candidates are
    then ordered by descending score.

    Args:
        args: Object exposing ``colbert`` (must provide ``score(Q, D)``),
            ``inference`` (must provide ``queryFromText`` and
            ``docFromText``), and ``bsize`` (forwarded to ``docFromText``).
        query: The query text.
        pids: Passage ids, positionally aligned with ``passages``
            (``pids[i]`` identifies ``passages[i]``).
        passages: The passage texts to rerank.

    Returns:
        A list of ``(score, pid, passage)`` tuples, highest score first.

    Raises:
        AssertionError: If the ranked candidates contain duplicate pids.
    """
    colbert = args.colbert
    inference = args.inference

    Q = inference.queryFromText([query])
    D_ = inference.docFromText(passages, bsize=args.bsize)

    scores = colbert.score(Q, D_).cpu()

    # ``sort`` yields parallel (values, indices); indices are positions into
    # ``pids`` / ``passages``.
    ordering = scores.sort(descending=True)
    ranked = ordering.indices.tolist()
    ranked_scores = ordering.values.tolist()

    ranked_pids = [pids[position] for position in ranked]
    ranked_passages = [passages[position] for position in ranked]

    # Explicit check rather than a bare ``assert`` so it is not stripped under
    # ``python -O``; AssertionError is kept for compatibility with callers.
    if len(set(ranked_pids)) != len(ranked_pids):
        raise AssertionError(
            "slow_rerank: duplicate pids among the ranked candidates"
        )

    return list(zip(ranked_scores, ranked_pids, ranked_passages))
|