File size: 2,061 Bytes
828992f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import os
import time
import faiss
import random
import torch
import itertools
from colbert.utils.runs import Run
from multiprocessing import Pool
from colbert.modeling.inference import ModelInference
from colbert.evaluation.ranking_logger import RankingLogger
from colbert.utils.utils import print_message, batch
from colbert.ranking.rankers import Ranker
def retrieve(args):
    """Rank every query in ``args.queries`` and log the results to ``ranking.tsv``.

    Query ids are walked in batches of 100, but each query is encoded and
    ranked on its own so that its latency can be timed individually. A running
    average latency (in milliseconds per query) is printed after every query
    that returned at least one passage. After a batch has been ranked, the top
    ``args.depth`` results of each of its queries are written through the
    ranking logger.

    Args:
        args: Run configuration. This function reads ``args.colbert``,
            ``args.amp``, ``args.faiss_depth``, ``args.queries`` (a mapping
            keyed by qid; values are presumably the query text -- confirm
            against the caller) and ``args.depth``. The whole object is also
            passed to ``Ranker``.

    Returns:
        None. Results are written via ``RankingLogger`` and progress is
        printed to stdout.
    """
    inference = ModelInference(args.colbert, amp=args.amp)
    ranker = Ranker(args, inference, faiss_depth=args.faiss_depth)

    ranking_logger = RankingLogger(Run.path, qrels=None)

    # Cumulative ranking time over all queries processed so far.
    milliseconds = 0
    query_batch_size = 100

    with ranking_logger.context('ranking.tsv', also_save_annotations=False) as rlogger:
        queries = args.queries
        qids_in_order = list(queries.keys())

        for qoffset, qbatch in batch(qids_in_order, query_batch_size, provide_offset=True):
            qbatch_text = [queries[qid] for qid in qbatch]

            rankings = []

            for query_idx, q in enumerate(qbatch_text):
                # Wait for pending GPU work so it is not billed to this query.
                torch.cuda.synchronize('cuda:0')
                # perf_counter() is monotonic and high-resolution; time.time()
                # follows the wall clock and can jump if the clock is adjusted.
                start = time.perf_counter()

                Q = ranker.encode([q])
                pids, scores = ranker.rank(Q)

                # Wait for the ranking kernels to finish before stopping the timer.
                torch.cuda.synchronize()
                milliseconds += (time.perf_counter() - start) * 1000.0

                if len(pids):
                    print(qoffset+query_idx, q, len(scores), len(pids), scores[0], pids[0],
                          milliseconds / (qoffset+query_idx+1), 'ms')

                # Stored lazily; consumed exactly once by the logging loop below.
                rankings.append(zip(pids, scores))

            for batch_idx, (qid, ranking) in enumerate(zip(qbatch, rankings)):
                # Position of this query within the full query list.
                global_idx = qoffset + batch_idx

                if global_idx % 100 == 0:
                    print_message(f"#> Logging query #{global_idx} (qid {qid}) now...")

                # Keep only the first args.depth results, as (score, pid, None) rows.
                ranking = [(score, pid, None) for pid, score in itertools.islice(ranking, args.depth)]
                rlogger.log(qid, ranking, is_ranked=True)

    print('\n\n')
    print(ranking_logger.filename)
    print("#> Done.")
    print('\n\n')
|