Spaces:
Runtime error
Runtime error
File size: 1,984 Bytes
58627fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import torch
from colbert.search.strided_tensor import StridedTensor
from .strided_tensor_core import _create_mask, _create_view
class CandidateGeneration:
    """Centroid-probing candidate generation helpers for a searcher.

    The class itself only stores ``use_gpu``.  Its methods additionally read
    ``self.codec.centroids``, ``self.ivf`` and ``self.lookup_eids``, none of
    which are defined here, so it is presumably meant to be combined with a
    class that provides them -- NOTE(review): confirm against the concrete
    searcher that uses it.
    """

    def __init__(self, use_gpu=True):
        """Store whether intermediate results should be moved to CUDA."""
        self.use_gpu = use_gpu

    def get_cells(self, Q, ncells):
        """Pick the centroids ("cells") to probe for a query.

        Args:
            Q: Query embedding matrix; assumed to be
                ``(num_query_vectors, dim)`` -- TODO confirm with callers.
            ncells: How many best-scoring centroids to keep per query vector.
                Requests larger than the number of centroids are clamped so
                that every centroid is probed instead of ``topk`` raising.

        Returns:
            A pair ``(cells, scores)``.  ``cells`` is a 1-D tensor of the
            distinct selected centroid indices (order is not guaranteed,
            because ``unique`` is called with ``sorted=False``) and ``scores``
            is the full ``self.codec.centroids @ Q.T`` matrix.
        """
        scores = self.codec.centroids @ Q.T

        # ``topk`` rejects a k larger than the dimension it selects from, so
        # never ask for more cells than there are centroids.
        ncells = min(ncells, scores.size(0))

        if ncells == 1:
            per_query = scores.argmax(dim=0, keepdim=True)
        else:
            per_query = scores.topk(ncells, dim=0, sorted=False).indices

        # (ncells, num_query_vectors) -> (num_query_vectors, ncells) -> 1-D.
        cells = per_query.permute(1, 0).flatten().contiguous()
        cells = cells.unique(sorted=False)
        return cells, scores

    def generate_candidate_eids(self, Q, ncells):
        """Look up the entries stored under the probed cells, as int64.

        Args:
            Q: Query embedding matrix, forwarded to :meth:`get_cells`.
            ncells: Number of cells to probe per query vector.

        Returns:
            A pair ``(eids, scores)``: the packed ids returned by
            ``self.ivf.lookup`` for the selected cells, cast to ``long`` (and
            moved to CUDA when ``use_gpu`` is set), plus the centroid score
            matrix from :meth:`get_cells`.
        """
        cells, scores = self.get_cells(Q, ncells)

        # ``lookup`` also returns a second value (the original code called it
        # ``cell_lengths``); it is not needed here.
        eids, _ = self.ivf.lookup(cells)
        eids = eids.long()
        if self.use_gpu:
            eids = eids.cuda()
        return eids, scores

    def generate_candidate_pids(self, Q, ncells):
        """Look up the entries stored under the probed cells.

        Same as :meth:`generate_candidate_eids` except that the looked-up ids
        keep whatever dtype ``self.ivf.lookup`` produced.

        Args:
            Q: Query embedding matrix, forwarded to :meth:`get_cells`.
            ncells: Number of cells to probe per query vector.

        Returns:
            A pair ``(pids, scores)``; ``pids`` may contain duplicates and is
            moved to CUDA when ``use_gpu`` is set.
        """
        cells, scores = self.get_cells(Q, ncells)

        # Second return value of ``lookup`` is intentionally ignored.
        pids, _ = self.ivf.lookup(cells)
        if self.use_gpu:
            pids = pids.cuda()
        return pids, scores

    def generate_candidate_scores(self, Q, eids):
        """Score every looked-up embedding against every query vector.

        Args:
            Q: Query embedding matrix; assumed ``(num_query_vectors, dim)``
                -- TODO confirm.
            eids: Ids passed to ``self.lookup_eids`` to fetch embeddings.

        Returns:
            A matrix whose entry ``[i, j]`` is the dot product of query vector
            ``i`` with the ``j``-th looked-up embedding.
        """
        E = self.lookup_eids(eids)
        if self.use_gpu:
            E = E.cuda()

        # (1, q, d) @ (n, d, 1) broadcasts to (n, q, 1); drop the trailing
        # axis and transpose so rows follow the query vectors.
        return (Q.unsqueeze(0) @ E.unsqueeze(2)).squeeze(-1).T

    def generate_candidates(self, config, Q):
        """Return the sorted, de-duplicated candidate ids for a query.

        Args:
            config: Object exposing an ``ncells`` attribute.
            Q: Query tensor.  A leading dimension of size 1 is squeezed away
                and the result must be 2-D.  When ``use_gpu`` is set it is
                moved to CUDA and converted to half precision.

        Returns:
            A pair ``(pids, centroid_scores)``: the ascending, unique ids
            found in the probed cells and the centroid score matrix computed
            by :meth:`get_cells`.

        Raises:
            AssertionError: If ``self.ivf`` is not a ``StridedTensor`` or the
                squeezed query is not 2-D.
        """
        ncells = config.ncells

        assert isinstance(self.ivf, StridedTensor), "self.ivf must be a StridedTensor"

        Q = Q.squeeze(0)
        if self.use_gpu:
            Q = Q.cuda().half()
        assert Q.dim() == 2, "expected a single query with 2-D embeddings"

        pids, centroid_scores = self.generate_candidate_pids(Q, ncells)

        # Sorting first makes duplicates adjacent so that
        # ``unique_consecutive`` removes all of them.  The per-id counts were
        # computed and then discarded before, so they are no longer requested.
        pids = torch.unique_consecutive(pids.sort().values)
        if self.use_gpu:
            pids = pids.cuda()

        return pids, centroid_scores
|