import os
import time
import torch
import ujson
import numpy as np
import itertools
import threading
import queue
from colbert.modeling.inference import ModelInference
from colbert.evaluation.loaders import load_colbert
from colbert.utils.utils import print_message
from colbert.indexing.index_manager import IndexManager
class CollectionEncoder:
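    """
    Encodes a collection of passages into ColBERT document embeddings and
    writes them to disk in fixed-size chunks, alongside per-document lengths.
    Work is split round-robin across `num_processes` cooperating processes.
    """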
def __init__(self, args, process_idx, num_processes):
self.args = args
self.collection = args.collection
self.process_idx = process_idx
self.num_processes = num_processes
assert 0.5 <= args.chunksize <= 128.0
max_bytes_per_file = args.chunksize * (1024*1024*1024)
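        # Worst-case bytes per document: doc_maxlen vectors of `dim` values at 2 bytes each (half precision).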
max_bytes_per_doc = (self.args.doc_maxlen * self.args.dim * 2.0)
        # Determine how many passages each saved output chunk may hold.
minimum_subset_size = 10_000
maximum_subset_size = max_bytes_per_file / max_bytes_per_doc
maximum_subset_size = max(minimum_subset_size, maximum_subset_size)
self.possible_subset_sizes = [int(maximum_subset_size)]
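        # A single size today, but kept as a list so _batch_passages can sample one per batch.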
self.print_main("#> Local args.bsize =", args.bsize)
self.print_main("#> args.index_root =", args.index_root)
self.print_main(f"#> self.possible_subset_sizes = {self.possible_subset_sizes}")
self._load_model()
self.indexmgr = IndexManager(args.dim)
self.iterator = self._initialize_iterator()
def _initialize_iterator(self):
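        # The collection is a TSV file: one passage per line, formatted "pid \t passage [\t title ...]".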
return open(self.collection)
def _saver_thread(self):
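        # Consume (batch_idx, embs, offset, doclens) tuples until the None sentinel, saving each to disk.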
for args in iter(self.saver_queue.get, None):
self._save_batch(*args)
def _load_model(self):
self.colbert, self.checkpoint = load_colbert(self.args, do_print=(self.process_idx == 0))
self.colbert = self.colbert.cuda()
self.colbert.eval()
self.inference = ModelInference(self.colbert, amp=self.args.amp)
def encode(self):
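        # A bounded queue (maxsize=3) applies backpressure: encoding waits if the
        # background saver thread falls more than three batches behind, which keeps
        # GPU encoding and disk I/O overlapped without unbounded memory growth.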
self.saver_queue = queue.Queue(maxsize=3)
thread = threading.Thread(target=self._saver_thread)
thread.start()
t0 = time.time()
local_docs_processed = 0
for batch_idx, (offset, lines, owner) in enumerate(self._batch_passages(self.iterator)):
if owner != self.process_idx:
continue
t1 = time.time()
batch = self._preprocess_batch(offset, lines)
embs, doclens = self._encode_batch(batch_idx, batch)
t2 = time.time()
self.saver_queue.put((batch_idx, embs, offset, doclens))
t3 = time.time()
local_docs_processed += len(lines)
overall_throughput = compute_throughput(local_docs_processed, t0, t3)
this_encoding_throughput = compute_throughput(len(lines), t1, t2)
this_saving_throughput = compute_throughput(len(lines), t2, t3)
self.print(f'#> Completed batch #{batch_idx} (starting at passage #{offset}) \t\t'
f'Passages/min: {overall_throughput} (overall), ',
f'{this_encoding_throughput} (this encoding), ',
f'{this_saving_throughput} (this saving)')
self.saver_queue.put(None)
self.print("#> Joining saver thread.")
thread.join()
def _batch_passages(self, fi):
"""
Must use the same seed across processes!
"""
np.random.seed(0)
offset = 0
for owner in itertools.cycle(range(self.num_processes)):
batch_size = np.random.choice(self.possible_subset_sizes)
L = [line for _, line in zip(range(batch_size), fi)]
if len(L) == 0:
break # EOF
yield (offset, L, owner)
offset += len(L)
if len(L) < batch_size:
break # EOF
self.print("[NOTE] Done with local share.")
return
def _preprocess_batch(self, offset, lines):
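        # Parse each TSV line; if a title column is present, prepend it to the passage text.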
endpos = offset + len(lines)
batch = []
for line_idx, line in zip(range(offset, endpos), lines):
line_parts = line.strip().split('\t')
pid, passage, *other = line_parts
assert len(passage) >= 1
if len(other) >= 1:
title, *_ = other
passage = title + ' | ' + passage
batch.append(passage)
# assert pid == 'id' or int(pid) == line_idx
return batch
def _encode_batch(self, batch_idx, batch):
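        # With keep_dims=False, inference returns one (doclen x dim) tensor per passage;
        # record each passage's length, then concatenate into one matrix for storage.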
with torch.no_grad():
embs = self.inference.docFromText(batch, bsize=self.args.bsize, keep_dims=False)
assert type(embs) is list
assert len(embs) == len(batch)
local_doclens = [d.size(0) for d in embs]
embs = torch.cat(embs)
return embs, local_doclens
def _save_batch(self, batch_idx, embs, offset, doclens):
start_time = time.time()
output_path = os.path.join(self.args.index_path, "{}.pt".format(batch_idx))
output_sample_path = os.path.join(self.args.index_path, "{}.sample".format(batch_idx))
doclens_path = os.path.join(self.args.index_path, 'doclens.{}.json'.format(batch_idx))
# Save the embeddings.
self.indexmgr.save(embs, output_path)
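        # Also save a ~5% random sample of the embeddings (ratio fixed at 1/20 here),
        # presumably for training a downstream nearest-neighbor (e.g., FAISS) index.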
self.indexmgr.save(embs[torch.randint(0, high=embs.size(0), size=(embs.size(0) // 20,))], output_sample_path)
# Save the doclens.
with open(doclens_path, 'w') as output_doclens:
ujson.dump(doclens, output_doclens)
throughput = compute_throughput(len(doclens), start_time, time.time())
self.print_main("#> Saved batch #{} to {} \t\t".format(batch_idx, output_path),
"Saving Throughput =", throughput, "passages per minute.\n")
def print(self, *args):
print_message("[" + str(self.process_idx) + "]", "\t\t", *args)
def print_main(self, *args):
if self.process_idx == 0:
self.print(*args)
def compute_throughput(size, t0, t1):
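    # Format throughput as passages per minute: '<x>M' above one million, '<x>k' otherwise.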
throughput = size / (t1 - t0) * 60
if throughput > 1000 * 1000:
throughput = throughput / (1000*1000)
throughput = round(throughput, 1)
return '{}M'.format(throughput)
throughput = throughput / (1000)
throughput = round(throughput, 1)
return '{}k'.format(throughput)
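
# Example usage (a sketch; assumes an `args` namespace exposing the fields read
# above: collection, chunksize, doc_maxlen, dim, bsize, amp, index_root, index_path):
#
#     encoder = CollectionEncoder(args, process_idx=0, num_processes=1)
#     encoder.encode()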