import os
import time
import torch
import ujson
import numpy as np
import itertools
import threading
import queue
from colbert.modeling.inference import ModelInference
from colbert.evaluation.loaders import load_colbert
from colbert.utils.utils import print_message
from colbert.indexing.index_manager import IndexManager
class CollectionEncoder:
    """Encode a tab-separated passage collection into embedding files on disk.

    The collection is cut into batches that are assigned round-robin to
    ``num_processes`` workers. This instance (``process_idx``) encodes only the
    batches it owns. It hands each encoded batch to a background saver thread
    that writes it under ``args.index_path``.
    """

    def __init__(self, args, process_idx, num_processes):
        self.args = args
        self.collection = args.collection
        self.process_idx = process_idx
        self.num_processes = num_processes

        # chunksize is expressed in GiB (multiplied by 1024**3 below).
        assert 0.5 <= args.chunksize <= 128.0
        max_bytes_per_file = args.chunksize * (1024 * 1024 * 1024)

        # Upper bound on one passage's size: doc_maxlen vectors of `dim` values
        # at 2 bytes each (presumably fp16 -- confirm against IndexManager).
        max_bytes_per_doc = (self.args.doc_maxlen * self.args.dim * 2.0)

        # Passages per batch file: enough to fill ~chunksize GiB, but never
        # fewer than 10,000.
        minimum_subset_size = 10_000
        maximum_subset_size = max_bytes_per_file / max_bytes_per_doc
        maximum_subset_size = max(minimum_subset_size, maximum_subset_size)
        self.possible_subset_sizes = [int(maximum_subset_size)]

        self.print_main("#> Local args.bsize =", args.bsize)
        self.print_main("#> args.index_root =", args.index_root)
        self.print_main(f"#> self.possible_subset_sizes = {self.possible_subset_sizes}")

        self.indexmgr = IndexManager(args.dim)
        self.iterator = self._initialize_iterator()

    def _initialize_iterator(self):
        """Open the collection file; encode() consumes it line by line."""
        # Explicit encoding: the platform default is locale-dependent.
        return open(self.collection, encoding='utf-8')

    def _saver_thread(self):
        """Save queued batches until a ``None`` sentinel arrives.

        If a save fails, the exception is stored in ``self._saver_error`` and
        remaining items are drained without saving. This keeps the producer
        from blocking forever on the bounded queue. encode() re-raises the
        stored error after joining this thread.
        """
        for item in iter(self.saver_queue.get, None):
            if getattr(self, '_saver_error', None) is not None:
                continue  # a previous save failed; just drain the queue
            try:
                self._save_batch(*item)
            except Exception as err:
                self._saver_error = err

    def _load_model(self):
        """Load the checkpoint, move it to the GPU, and wrap it for inference."""
        self.colbert, self.checkpoint = load_colbert(self.args, do_print=(self.process_idx == 0))
        self.colbert = self.colbert.cuda()
        # Inference only: switch off training-mode layers such as dropout.
        # This is harmless if load_colbert already did it. It assumes the
        # model is a torch.nn.Module.
        self.colbert.eval()
        self.inference = ModelInference(self.colbert, amp=self.args.amp)

    def encode(self):
        """Encode every batch owned by this process and save it to disk.

        Encoding runs on the calling thread. Saving runs on a background thread
        fed by a queue that holds at most 3 pending batches, so encoding and
        disk writes overlap. The saver thread is always stopped and joined,
        even if encoding raises.

        Raises:
            RuntimeError: if saving a batch failed in the saver thread.
        """
        if getattr(self, 'inference', None) is None:
            self._load_model()

        self._saver_error = None
        self.saver_queue = queue.Queue(maxsize=3)
        thread = threading.Thread(target=self._saver_thread)
        thread.start()

        t0 = time.time()
        local_docs_processed = 0

        try:
            for batch_idx, (offset, lines, owner) in enumerate(self._batch_passages(self.iterator)):
                if owner != self.process_idx:
                    continue  # another process encodes this batch

                if self._saver_error is not None:
                    break  # saving already failed; stop and report below

                t1 = time.time()
                batch = self._preprocess_batch(offset, lines)
                embs, doclens = self._encode_batch(batch_idx, batch)

                t2 = time.time()
                self.saver_queue.put((batch_idx, embs, offset, doclens))

                t3 = time.time()
                local_docs_processed += len(lines)
                overall_throughput = compute_throughput(local_docs_processed, t0, t3)
                this_encoding_throughput = compute_throughput(len(lines), t1, t2)
                this_saving_throughput = compute_throughput(len(lines), t2, t3)

                self.print(f'#> Completed batch #{batch_idx} (starting at passage #{offset}) \t\t'
                           f'Passages/min: {overall_throughput} (overall), ',
                           f'{this_encoding_throughput} (this encoding), ',
                           f'{this_saving_throughput} (this saving)')
        finally:
            # Sentinel: makes _saver_thread exit once everything queued is handled.
            self.saver_queue.put(None)
            self.print("#> Joining saver thread.")
            thread.join()

        if self._saver_error is not None:
            raise RuntimeError("Saving an encoded batch failed.") from self._saver_error

    def _batch_passages(self, fi):
        """Split the lines of ``fi`` into batches assigned round-robin to processes.

        Must use the same seed across processes! Every process must compute
        identical batch boundaries. Batch sizes therefore come from a local,
        fixed-seed RNG, which also leaves the global NumPy RNG untouched.

        Yields:
            ``(offset, lines, owner)``: the collection index of the batch's
            first line, the batch's raw lines, and the owning process index.
        """
        rng = np.random.RandomState(0)
        fi = iter(fi)  # a plain list would otherwise restart on every slice

        offset = 0
        for owner in itertools.cycle(range(self.num_processes)):
            batch_size = int(rng.choice(self.possible_subset_sizes))

            lines = list(itertools.islice(fi, batch_size))
            if not lines:
                break  # EOF exactly on a batch boundary

            yield (offset, lines, owner)
            offset += len(lines)

            if len(lines) < batch_size:
                break  # a short batch means EOF

        self.print("[NOTE] Done with local share.")

    def _preprocess_batch(self, offset, lines):
        """Convert raw TSV lines into the passage strings fed to the encoder.

        Each line is ``pid<TAB>passage[<TAB>title...]``. When a third column
        is present, the result is ``"title | passage"``.

        Args:
            offset: collection index of ``lines[0]``; used in error messages.
            lines: raw collection lines.

        Raises:
            ValueError: if a line has no passage column or an empty passage.
        """
        batch = []

        for line_idx, line in enumerate(lines, start=offset):
            pid, *fields = line.strip().split('\t')
            if not fields or not fields[0]:
                raise ValueError(f"Collection line {line_idx} (pid {pid!r}) has no passage text.")

            passage, *other = fields
            if other:
                passage = other[0] + ' | ' + passage

            batch.append(passage)

        return batch

    def _encode_batch(self, batch_idx, batch):
        """Encode passages without gradients.

        Returns:
            ``(embs, doclens)``: the per-passage outputs concatenated with
            ``torch.cat``, and ``doclens[i]`` = ``size(0)`` of passage ``i``'s
            output. ``batch_idx`` is unused.
        """
        with torch.no_grad():
            embs = self.inference.docFromText(batch, bsize=self.args.bsize, keep_dims=False)
            assert isinstance(embs, list)
            assert len(embs) == len(batch)

            local_doclens = [d.size(0) for d in embs]
            embs = torch.cat(embs)

        return embs, local_doclens

    def _save_batch(self, batch_idx, embs, offset, doclens):
        """Write one encoded batch under ``args.index_path``.

        Writes three files:
        - ``{batch_idx}.pt``: all embeddings.
        - ``{batch_idx}.sample``: ``N // 20`` rows drawn at random with
          replacement.
        - ``doclens.{batch_idx}.json``: the per-passage lengths.

        ``offset`` is unused.
        """
        start_time = time.time()

        output_path = os.path.join(self.args.index_path, "{}.pt".format(batch_idx))
        output_sample_path = os.path.join(self.args.index_path, "{}.sample".format(batch_idx))
        doclens_path = os.path.join(self.args.index_path, 'doclens.{}.json'.format(batch_idx))

        self.indexmgr.save(embs, output_path)
        sample_rows = torch.randint(0, high=embs.size(0), size=(embs.size(0) // 20,))
        self.indexmgr.save(embs[sample_rows], output_sample_path)

        with open(doclens_path, 'w') as output_doclens:
            ujson.dump(doclens, output_doclens)

        throughput = compute_throughput(len(doclens), start_time, time.time())
        self.print_main("#> Saved batch #{} to {} \t\t".format(batch_idx, output_path),
                        "Saving Throughput =", throughput, "passages per minute.\n")

    def print(self, *args):
        """Log ``args`` prefixed with this process's index."""
        print_message("[" + str(self.process_idx) + "]", "\t\t", *args)

    def print_main(self, *args):
        """Log ``args`` only from process 0, to avoid duplicated output."""
        if self.process_idx == 0:
            self.print(*args)
def compute_throughput(size, t0, t1):
    """Format a processing rate as items per minute for log output.

    Args:
        size: number of items processed.
        t0: start timestamp, in seconds.
        t1: end timestamp, in seconds.

    Returns:
        The rate divided by one million with an ``'M'`` suffix when it exceeds
        one million per minute; otherwise the rate divided by one thousand
        with a ``'k'`` suffix. Both are rounded to one decimal. Returns
        ``'n/a'`` when the elapsed time is zero or negative, which can happen
        with coarse or non-monotonic clocks; the division would otherwise fail
        or be meaningless.
    """
    elapsed = t1 - t0
    if elapsed <= 0:
        return 'n/a'

    throughput = size / elapsed * 60

    if throughput > 1000 * 1000:
        throughput = round(throughput / (1000 * 1000), 1)
        return '{}M'.format(throughput)

    throughput = round(throughput / 1000, 1)
    return '{}k'.format(throughput)