Spaces:
Runtime error
Runtime error
File size: 1,823 Bytes
58627fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
import torch
from colbert.infra.run import Run
from colbert.utils.utils import print_message, batch
class CollectionEncoder:
    """Encodes a collection of passages into token-level embeddings with a ColBERT checkpoint."""

    def __init__(self, config, checkpoint):
        """
        Args:
            config: run configuration; must expose `bsize` (encoding batch size)
                and `total_visible_gpus` (int, used to decide GPU vs CPU mode).
            checkpoint: model checkpoint exposing `docFromText(...)`.
        """
        self.config = config
        self.checkpoint = checkpoint
        self.use_gpu = self.config.total_visible_gpus > 0

    def encode_passages(self, passages):
        """Encode `passages` (list of strings) into embeddings.

        Returns:
            (embs, doclens): `embs` is a single tensor of all token embeddings
            concatenated along dim 0 (keep_dims='flatten'); `doclens` is a list
            with the token count of each passage. Returns (None, None) when
            `passages` is empty.
        """
        Run().print(f"#> Encoding {len(passages)} passages..")

        if not passages:
            return None, None

        with torch.inference_mode():
            embs, doclens = [], []

            # Batch here to avoid OOM from storing intermediate embeddings on GPU.
            # Storing on the GPU helps with speed of masking, etc.
            # But ideally this batching happens internally inside docFromText.
            for passages_batch in batch(passages, self.config.bsize * 50):
                embs_, doclens_ = self.checkpoint.docFromText(
                    passages_batch,
                    bsize=self.config.bsize,
                    keep_dims='flatten',
                    # Progress bars are only useful in the (slow) CPU path.
                    showprogress=(not self.use_gpu),
                )
                embs.append(embs_)
                doclens.extend(doclens_)

            embs = torch.cat(embs)

        return embs, doclens
|