# ir_chinese_medqa / colbert / indexing / collection_encoder.py
# Author: 欧卫
# Commit: 'add_app_files' (58627fa)
import torch
from colbert.infra.run import Run
from colbert.utils.utils import print_message, batch
class CollectionEncoder():
    """Encodes a collection of text passages into token-level embeddings
    using a ColBERT checkpoint.
    """

    def __init__(self, config, checkpoint):
        """
        Args:
            config: run configuration; must expose `bsize` (encoding batch size)
                and `total_visible_gpus` (int count of usable GPUs).
            checkpoint: model checkpoint exposing `docFromText(...)`.
        """
        self.config = config
        self.checkpoint = checkpoint
        # GPU is considered in use whenever at least one GPU is visible.
        self.use_gpu = self.config.total_visible_gpus > 0

    def encode_passages(self, passages):
        """Encode `passages` (a list of strings) into a single flattened
        embedding tensor plus per-passage token counts.

        Returns:
            (embs, doclens) where `embs` is a 2D tensor of all token embeddings
            concatenated across passages and `doclens` is a list of per-passage
            lengths — or `(None, None)` when `passages` is empty.
        """
        Run().print(f"#> Encoding {len(passages)} passages..")

        if not passages:
            return None, None

        with torch.inference_mode():
            embs, doclens = [], []

            # Batch here to avoid OOM from storing intermediate embeddings on GPU.
            # Storing on the GPU helps with speed of masking, etc.
            # But ideally this batching happens internally inside docFromText.
            for passages_batch in batch(passages, self.config.bsize * 50):
                # Progress bar only when on CPU; on GPU it adds noise/overhead.
                embs_, doclens_ = self.checkpoint.docFromText(passages_batch, bsize=self.config.bsize,
                                                              keep_dims='flatten', showprogress=(not self.use_gpu))
                embs.append(embs_)
                doclens.extend(doclens_)

            embs = torch.cat(embs)

        return embs, doclens