import os
import time
import torch
import ujson
import numpy as np

import itertools
import threading
import queue

from colbert.modeling.inference import ModelInference
from colbert.evaluation.loaders import load_colbert
from colbert.utils.utils import print_message
from colbert.indexing.index_manager import IndexManager

class CollectionEncoder():
    def __init__(self, args, process_idx, num_processes):
        self.args = args
        self.collection = args.collection
        self.process_idx = process_idx
        self.num_processes = num_processes

        # Each saved chunk targets args.chunksize GiB on disk, assuming two
        # bytes (fp16) per embedding dimension.
        assert 0.5 <= args.chunksize <= 128.0
        max_bytes_per_file = args.chunksize * (1024*1024*1024)

        max_bytes_per_doc = (self.args.doc_maxlen * self.args.dim * 2.0)

        # Determine subset sizes for output
        minimum_subset_size = 10_000
        maximum_subset_size = max_bytes_per_file / max_bytes_per_doc
        maximum_subset_size = max(minimum_subset_size, maximum_subset_size)
        self.possible_subset_sizes = [int(maximum_subset_size)]

        self.print_main("#> Local args.bsize =", args.bsize)
        self.print_main("#> args.index_root =", args.index_root)
        self.print_main(f"#> self.possible_subset_sizes = {self.possible_subset_sizes}")

        self._load_model()
        self.indexmgr = IndexManager(args.dim)
        self.iterator = self._initialize_iterator()
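
    # Worked example of the sizing math in __init__ (illustrative values, not
    # from this file): with doc_maxlen=180 and dim=128, max_bytes_per_doc is
    # 180 * 128 * 2 = 46,080 bytes, so a chunksize of 6.0 GiB yields chunks
    # of roughly 6 * 2**30 / 46,080 ≈ 139,810 passages each.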

    def _initialize_iterator(self):
        return open(self.collection)

    def _saver_thread(self):
        # Consume (batch_idx, embs, offset, doclens) tuples until the None
        # sentinel arrives, writing each batch to disk off the main thread.
        for args in iter(self.saver_queue.get, None):
            self._save_batch(*args)

    def _load_model(self):
        # Load the ColBERT checkpoint, move it to the GPU, and wrap it for
        # batched inference (optionally with mixed precision).
        self.colbert, self.checkpoint = load_colbert(self.args, do_print=(self.process_idx == 0))
        self.colbert = self.colbert.cuda()
        self.colbert.eval()

        self.inference = ModelInference(self.colbert, amp=self.args.amp)

    def encode(self):
        self.saver_queue = queue.Queue(maxsize=3)
        thread = threading.Thread(target=self._saver_thread)
        thread.start()

        t0 = time.time()
        local_docs_processed = 0

        for batch_idx, (offset, lines, owner) in enumerate(self._batch_passages(self.iterator)):
            if owner != self.process_idx:
                continue

            t1 = time.time()
            batch = self._preprocess_batch(offset, lines)
            embs, doclens = self._encode_batch(batch_idx, batch)

            t2 = time.time()
            self.saver_queue.put((batch_idx, embs, offset, doclens))

            t3 = time.time()
            local_docs_processed += len(lines)

            overall_throughput = compute_throughput(local_docs_processed, t0, t3)
            this_encoding_throughput = compute_throughput(len(lines), t1, t2)
            this_saving_throughput = compute_throughput(len(lines), t2, t3)

            self.print(f'#> Completed batch #{batch_idx} (starting at passage #{offset}) \t\t'
                       f'Passages/min: {overall_throughput} (overall), '
                       f'{this_encoding_throughput} (this encoding), '
                       f'{this_saving_throughput} (this saving)')

        self.saver_queue.put(None)

        self.print("#> Joining saver thread.")
        thread.join()
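
    # Design note: encode() overlaps GPU encoding with disk I/O. The bounded
    # queue (maxsize=3) applies backpressure, so at most a few encoded
    # batches are held in memory while the saver thread writes them out.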

    def _batch_passages(self, fi):
        """
        Must use the same seed across processes!
        """
        np.random.seed(0)

        offset = 0
        for owner in itertools.cycle(range(self.num_processes)):
            batch_size = np.random.choice(self.possible_subset_sizes)

            L = [line for _, line in zip(range(batch_size), fi)]

            if len(L) == 0:
                break  # EOF

            yield (offset, L, owner)
            offset += len(L)

            if len(L) < batch_size:
                break  # EOF

        self.print("[NOTE] Done with local share.")

        return
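
    # Illustration: with num_processes=2, itertools.cycle yields owners
    # 0, 1, 0, 1, ...  Every process reads the full collection, and because
    # all processes seed numpy identically they draw the same batch sizes and
    # agree on batch boundaries; each process then encodes only the batches
    # whose owner equals its own process_idx (see the check in encode()).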

    def _preprocess_batch(self, offset, lines):
        endpos = offset + len(lines)

        batch = []

        for line_idx, line in zip(range(offset, endpos), lines):
            line_parts = line.strip().split('\t')

            pid, passage, *other = line_parts

            assert len(passage) >= 1

            if len(other) >= 1:
                title, *_ = other
                passage = title + ' | ' + passage

            batch.append(passage)

            # assert pid == 'id' or int(pid) == line_idx

        return batch
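
    # Each collection line is expected to be "pid \t passage" or
    # "pid \t passage \t title" (tab-separated); when a title is present it
    # is prepended to the passage as "title | passage".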

    def _encode_batch(self, batch_idx, batch):
        with torch.no_grad():
            embs = self.inference.docFromText(batch, bsize=self.args.bsize, keep_dims=False)
            assert type(embs) is list
            assert len(embs) == len(batch)

            local_doclens = [d.size(0) for d in embs]
            embs = torch.cat(embs)

        return embs, local_doclens
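
    # With keep_dims=False, docFromText returns one (num_tokens x dim) tensor
    # per passage. Concatenating them into a single matrix and recording each
    # passage's token count in doclens lets a saved chunk be split back into
    # per-passage embeddings later.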

    def _save_batch(self, batch_idx, embs, offset, doclens):
        start_time = time.time()

        output_path = os.path.join(self.args.index_path, "{}.pt".format(batch_idx))
        output_sample_path = os.path.join(self.args.index_path, "{}.sample".format(batch_idx))
        doclens_path = os.path.join(self.args.index_path, 'doclens.{}.json'.format(batch_idx))

        # Save the embeddings.
        self.indexmgr.save(embs, output_path)
        # Also save a ~5% sample of the embedding rows (drawn with replacement).
        self.indexmgr.save(embs[torch.randint(0, high=embs.size(0), size=(embs.size(0) // 20,))], output_sample_path)

        # Save the doclens.
        with open(doclens_path, 'w') as output_doclens:
            ujson.dump(doclens, output_doclens)

        throughput = compute_throughput(len(doclens), start_time, time.time())
        self.print_main("#> Saved batch #{} to {} \t\t".format(batch_idx, output_path),
                        "Saving Throughput =", throughput, "passages per minute.\n")

    def print(self, *args):
        print_message("[" + str(self.process_idx) + "]", "\t\t", *args)

    def print_main(self, *args):
        if self.process_idx == 0:
            self.print(*args)


def compute_throughput(size, t0, t1):
    # Express throughput in passages per minute, formatted as '{x}M' or '{x}k'.
    throughput = size / (t1 - t0) * 60

    if throughput > 1000 * 1000:
        throughput = throughput / (1000*1000)
        throughput = round(throughput, 1)
        return '{}M'.format(throughput)

    throughput = throughput / 1000
    throughput = round(throughput, 1)
    return '{}k'.format(throughput)
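
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file). In ColBERT
# the `args` object comes from the repository's argument parser, and
# load_colbert() additionally expects checkpoint-related fields that are
# omitted here; the attribute names below are only those this file itself
# accesses.
#
#     from argparse import Namespace
#
#     args = Namespace(
#         collection='collection.tsv',   # TSV: pid \t passage [\t title]
#         chunksize=6.0,                 # target GiB per saved chunk
#         doc_maxlen=180, dim=128,       # per-token embedding shape
#         bsize=128, amp=True,
#         index_root='/indexes',
#         index_path='/indexes/my_index',
#     )
#
#     encoder = CollectionEncoder(args, process_idx=0, num_processes=1)
#     encoder.encode()
# ---------------------------------------------------------------------------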