Spaces:
Runtime error
Runtime error
import torch | |
from colbert.search.strided_tensor import StridedTensor | |
from .strided_tensor_core import _create_mask, _create_view | |
class CandidateGeneration: | |
def __init__(self, use_gpu=True): | |
self.use_gpu = use_gpu | |
def get_cells(self, Q, ncells): | |
scores = (self.codec.centroids @ Q.T) | |
if ncells == 1: | |
cells = scores.argmax(dim=0, keepdim=True).permute(1, 0) | |
else: | |
cells = scores.topk(ncells, dim=0, sorted=False).indices.permute(1, 0) # (32, ncells) | |
cells = cells.flatten().contiguous() # (32 * ncells,) | |
cells = cells.unique(sorted=False) | |
return cells, scores | |
def generate_candidate_eids(self, Q, ncells): | |
cells, scores = self.get_cells(Q, ncells) | |
eids, cell_lengths = self.ivf.lookup(cells) # eids = (packedlen,) lengths = (32 * ncells,) | |
eids = eids.long() | |
if self.use_gpu: | |
eids = eids.cuda() | |
return eids, scores | |
def generate_candidate_pids(self, Q, ncells): | |
cells, scores = self.get_cells(Q, ncells) | |
pids, cell_lengths = self.ivf.lookup(cells) | |
if self.use_gpu: | |
pids = pids.cuda() | |
return pids, scores | |
def generate_candidate_scores(self, Q, eids): | |
E = self.lookup_eids(eids) | |
if self.use_gpu: | |
E = E.cuda() | |
return (Q.unsqueeze(0) @ E.unsqueeze(2)).squeeze(-1).T | |
def generate_candidates(self, config, Q): | |
ncells = config.ncells | |
assert isinstance(self.ivf, StridedTensor) | |
Q = Q.squeeze(0) | |
if self.use_gpu: | |
Q = Q.cuda().half() | |
assert Q.dim() == 2 | |
pids, centroid_scores = self.generate_candidate_pids(Q, ncells) | |
sorter = pids.sort() | |
pids = sorter.values | |
pids, pids_counts = torch.unique_consecutive(pids, return_counts=True) | |
if self.use_gpu: | |
pids, pids_counts = pids.cuda(), pids_counts.cuda() | |
return pids, centroid_scores | |