# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import time
import pickle
import os
import logging
from multiprocessing.pool import ThreadPool
import threading
import _thread
from queue import Queue
import traceback
import datetime

import numpy as np

import faiss
from faiss.contrib.inspect_tools import get_invlist
class BigBatchSearcher:
    """
    Object that manages all the data related to the computation
    except the actual within-bucket matching and the organization of the
    computation (parallel or not)
    """

    def __init__(
            self,
            index, xq, k,
            verbose=0,
            use_float16=False):
        """Set up the result heap and timing state for a batched IVF search.

        Args:
            index: an IVF index; only its metric_type is read here, the
                rest is used by the other methods.
            xq: query vectors (one row per query).
            k: number of neighbors to collect per query.
            verbose: 0 = silent, higher values print progress (see report()).
            use_float16: if True, prepare_bucket() casts query/database
                matrices to float16.
        """
        # verbosity
        self.verbose = verbose
        self.tictoc = []
        self.xq = xq
        self.index = index
        self.use_float16 = use_float16
        # for similarity metrics (e.g. inner product) the heap keeps maxima
        keep_max = faiss.is_similarity_metric(index.metric_type)
        self.rh = faiss.ResultHeap(len(xq), k, keep_max=keep_max)
        # accumulated times, indexed as: 0 = query prep, 1 = bucket prep,
        # 2 = compute, 3 = result writeback, 4 = wait in, 5 = wait out
        # (this is the breakdown printed by report())
        self.t_accu = [0] * 6
        self.t_display = self.t0 = time.time()

    def start_t_accu(self):
        # start a stopwatch; paired with stop_t_accu(n)
        self.t_accu_t0 = time.time()

    def stop_t_accu(self, n):
        # add elapsed time since start_t_accu() to accumulator slot n
        self.t_accu[n] += time.time() - self.t_accu_t0

    def tic(self, name):
        # begin a named, printed timing section (ended by toc())
        self.tictoc = (name, time.time())
        if self.verbose > 0:
            print(name, end="\r", flush=True)

    def toc(self):
        # end the section opened by tic(); returns the elapsed time
        name, t0 = self.tictoc
        dt = time.time() - t0
        if self.verbose > 0:
            print(f"{name}: {dt:.3f} s")
        return dt

    def report(self, l):
        """Print a one-line progress report for inverted list l.

        At verbose == 1 nothing is printed here; at verbose == 2 the
        output is throttled to at most one line per second once l > 1000.
        """
        if self.verbose == 1 or (
            self.verbose == 2 and (
                l > 1000 and time.time() < self.t_display + 1.0
            )
        ):
            return
        t = time.time() - self.t0
        print(
            f"[{t:.1f} s] list {l}/{self.index.nlist} "
            f"times prep q {self.t_accu[0]:.3f} prep b {self.t_accu[1]:.3f} "
            f"comp {self.t_accu[2]:.3f} res {self.t_accu[3]:.3f} "
            f"wait in {self.t_accu[4]:.3f} "
            f"wait out {self.t_accu[5]:.3f} "
            f"eta {datetime.timedelta(seconds=t*self.index.nlist/(l+1)-t)} "
            f"mem {faiss.get_mem_usage_kb()}",
            end="\r" if self.verbose <= 2 else "\n",
            flush=True,
        )
        self.t_display = time.time()

    def coarse_quantization(self):
        """Assign all queries to their nprobe nearest centroids.

        Runs the coarse quantizer in batches of 65536 queries and stores
        the (nq, nprobe) int32 assignment matrix in self.q_assign.
        """
        self.tic("coarse quantization")
        bs = 65536
        nq = len(self.xq)
        q_assign = np.empty((nq, self.index.nprobe), dtype='int32')
        for i0 in range(0, nq, bs):
            i1 = min(nq, i0 + bs)
            q_dis_i, q_assign_i = self.index.quantizer.search(
                self.xq[i0:i1], self.index.nprobe)
            # q_dis[i0:i1] = q_dis_i
            q_assign[i0:i1] = q_assign_i
        self.toc()
        self.q_assign = q_assign

    def reorder_assign(self):
        """Bucket-sort the query ids by inverted list.

        After this call, self.query_ids[bucket_lims[l]:bucket_lims[l+1]]
        are the query indices assigned to inverted list l. The sort is
        done in place, so self.q_assign is consumed and deleted.
        """
        self.tic("bucket sort")
        q_assign = self.q_assign
        q_assign += 1  # move -1 -> 0
        # nlist + 1 buckets: bucket 0 collects the shifted -1 assignments
        self.bucket_lims = faiss.matrix_bucket_sort_inplace(
            self.q_assign, nbucket=self.index.nlist + 1, nt=16)
        self.query_ids = self.q_assign.ravel()
        if self.verbose > 0:
            print(' number of -1s:', self.bucket_lims[1])
        self.bucket_lims = self.bucket_lims[1:]  # shift back to ignore -1s
        del self.q_assign  # inplace so let's forget about the old version...
        self.toc()

    def prepare_bucket(self, l):
        """ prepare the queries and database items for bucket l

        Returns (q_subset, xq_l, list_ids, xb_l): the query ids routed to
        list l, the corresponding (possibly residual / float16) query
        matrix, the vector ids stored in list l, and the decoded (or raw
        code) database matrix.

        Note: self.by_residual and self.decode_func are not set in
        __init__; the caller is expected to copy them over from a
        BlockComputer before using this method.
        """
        t0 = time.time()
        index = self.index
        # prepare queries
        i0, i1 = self.bucket_lims[l], self.bucket_lims[l + 1]
        q_subset = self.query_ids[i0:i1]
        xq_l = self.xq[q_subset]
        if self.by_residual:
            # residual encoding: search relative to the list's centroid
            xq_l = xq_l - index.quantizer.reconstruct(l)
        t1 = time.time()
        # prepare database side
        list_ids, xb_l = get_invlist(index.invlists, l)
        if self.decode_func is None:
            # method == "index": keep raw codes, flattened
            xb_l = xb_l.ravel()
        else:
            xb_l = self.decode_func(xb_l)
        if self.use_float16:
            xb_l = xb_l.astype('float16')
            xq_l = xq_l.astype('float16')
        t2 = time.time()
        self.t_accu[0] += t1 - t0   # query preparation time
        self.t_accu[1] += t2 - t1   # database preparation time
        return q_subset, xq_l, list_ids, xb_l

    def add_results_to_heap(self, q_subset, D, list_ids, I):
        """add the bucket results to the heap structure

        D may be None (empty bucket: nothing to add). I may be None
        (pairwise-distance method), in which case D columns correspond
        directly to the list entries and list_ids is used as-is;
        otherwise I holds indices into the list and is translated to
        vector ids through list_ids.
        """
        if D is None:
            return
        t0 = time.time()
        if I is None:
            I = list_ids
        else:
            I = list_ids[I]
        self.rh.add_result_subset(q_subset, D, I)
        self.t_accu[3] += time.time() - t0  # result writeback time

    def sizes_in_checkpoint(self):
        # fingerprint of the problem size, used to validate checkpoints
        return (self.xq.shape, self.index.nprobe, self.index.nlist)

    def write_checkpoint(self, fname, completed):
        """Atomically dump the heap state and the set of completed lists."""
        # write to temp file then move to final file
        tmpname = fname + ".tmp"
        with open(tmpname, "wb") as f:
            pickle.dump(
                {
                    "sizes": self.sizes_in_checkpoint(),
                    "completed": completed,
                    "rh": (self.rh.D, self.rh.I),
                }, f, -1)
        os.replace(tmpname, fname)

    def read_checkpoint(self, fname):
        """Restore heap state from a checkpoint; returns the completed set.

        NOTE: uses pickle, so fname must come from a trusted source.
        """
        with open(fname, "rb") as f:
            ckp = pickle.load(f)
        assert ckp["sizes"] == self.sizes_in_checkpoint()
        self.rh.D[:] = ckp["rh"][0]
        self.rh.I[:] = ckp["rh"][1]
        return ckp["completed"]
class BlockComputer:
    """ computation within one bucket

    Wraps the three supported computation methods behind a single
    block_search() call:
      - "index": copy the bucket's codes into a helper flat/PQ/SQ index
        and search it
      - "pairwise_distances": decode codes and compute the full
        query-vs-bucket distance matrix
      - "knn_function": decode codes and run a knn function on them
    """

    def __init__(
            self,
            index,
            method="knn_function",
            pairwise_distances=faiss.pairwise_distances,
            knn=faiss.knn):
        """
        Args:
            index: IVFFlat, IVFPQ or IVFScalarQuantizer index.
            method: one of "index", "pairwise_distances", "knn_function".
            pairwise_distances: function used for the pairwise method.
            knn: function used for the knn method.

        Raises:
            ValueError: if method is not one of the three supported names.
            RuntimeError: if the index type is not supported.
        """
        # validate early: an unknown method would otherwise only fail
        # later inside block_search() with an obscure NameError
        if method not in ("index", "pairwise_distances", "knn_function"):
            raise ValueError(f"method {method} not supported")
        self.index = index
        if index.__class__ == faiss.IndexIVFFlat:
            index_help = faiss.IndexFlat(index.d, index.metric_type)
            decode_func = lambda x: x.view("float32")
            by_residual = False
        elif index.__class__ == faiss.IndexIVFPQ:
            index_help = faiss.IndexPQ(
                index.d, index.pq.M, index.pq.nbits, index.metric_type)
            index_help.pq = index.pq
            decode_func = index_help.pq.decode
            index_help.is_trained = True
            by_residual = index.by_residual
        elif index.__class__ == faiss.IndexIVFScalarQuantizer:
            index_help = faiss.IndexScalarQuantizer(
                index.d, index.sq.qtype, index.metric_type)
            index_help.sq = index.sq
            decode_func = index_help.sq.decode
            index_help.is_trained = True
            by_residual = index.by_residual
        else:
            raise RuntimeError(f"index type {index.__class__} not supported")
        self.index_help = index_help
        # the "index" method searches raw codes directly, so no decoding
        self.decode_func = None if method == "index" else decode_func
        self.by_residual = by_residual
        self.method = method
        self.pairwise_distances = pairwise_distances
        self.knn = knn

    def block_search(self, xq_l, xb_l, list_ids, k, **extra_args):
        """Match the bucket's queries against its database vectors.

        Args:
            xq_l: query matrix for this bucket.
            xb_l: decoded vectors (or raw codes for method "index").
            list_ids: ids of the vectors stored in the bucket.
            k: number of neighbors.
            **extra_args: forwarded to the knn function (e.g. thread_id).

        Returns:
            (D, I) where both are None for an empty bucket, and I is None
            for the pairwise_distances method (D then covers all entries).
        """
        metric_type = self.index.metric_type
        if xq_l.size == 0 or xb_l.size == 0:
            # empty bucket: nothing to match
            D = I = None
        elif self.method == "index":
            faiss.copy_array_to_vector(xb_l, self.index_help.codes)
            self.index_help.ntotal = len(list_ids)
            D, I = self.index_help.search(xq_l, k)
        elif self.method == "pairwise_distances":
            # TODO implement blockwise to avoid mem blowup
            D = self.pairwise_distances(xq_l, xb_l, metric=metric_type)
            I = None
        elif self.method == "knn_function":
            D, I = self.knn(xq_l, xb_l, k, metric=metric_type, **extra_args)
        else:
            # cannot be reached when __init__ validated method, but keeps
            # D, I from being silently unbound if self.method is mutated
            raise ValueError(f"method {self.method} not supported")
        return D, I
def big_batch_search(
        index, xq, k,
        method="knn_function",
        pairwise_distances=faiss.pairwise_distances,
        knn=faiss.knn,
        verbose=0,
        threaded=0,
        use_float16=False,
        prefetch_threads=1,
        computation_threads=1,
        q_assign=None,
        checkpoint=None,
        checkpoint_freq=7200,
        start_list=0,
        end_list=None,
        crash_at=-1
        ):
    """
    Search queries xq in the IVF index, with a search function that collects
    batches of query vectors per inverted list. This can be faster than the
    regular search indexes.
    Supports IVFFlat, IVFPQ and IVFScalarQuantizer.

    Supports three computation methods:
    method = "index":
        build a flat index and populate it separately for each index
    method = "pairwise_distances":
        decompress codes and compute all pairwise distances for the queries
        and index and add result to heap
    method = "knn_function":
        decompress codes and compute knn results for the queries

    threaded=0: sequential execution
    threaded=1: prefetch next bucket while computing the current one
    threaded=2: prefetch prefetch_threads buckets at a time.

    compute_threads>1: the knn function will get an additional thread_no that
    tells which worker should handle this.

    In threaded mode, the computation is tiled with the bucket perparation and
    the writeback of results (useful to maximize GPU utilization).

    use_float16: convert all matrices to float16 (faster for GPU gemm)

    q_assign: override coarse assignment, should be a matrix of size nq * nprobe

    checkpointing (only for threaded > 1):
    checkpoint: file where the checkpoints are stored
    checkpoint_freq: when to perform checkpoinging. Should be a multiple of threaded
    (the "completed" set recovered from a checkpoint is only honored by
    the threaded > 1 path)

    start_list, end_list: process only a subset of invlists

    Returns:
        (D, I): distance and id matrices of shape (nq, k).
    """
    nprobe = index.nprobe
    assert method in ("index", "pairwise_distances", "knn_function")
    # rough memory estimate, for logging only
    mem_queries = xq.nbytes
    mem_assign = len(xq) * nprobe * np.dtype('int32').itemsize
    mem_res = len(xq) * k * (
        np.dtype('int64').itemsize
        + np.dtype('float32').itemsize
    )
    mem_tot = mem_queries + mem_assign + mem_res
    if verbose > 0:
        logging.info(
            f"memory: queries {mem_queries} assign {mem_assign} "
            f"result {mem_res} total {mem_tot} = {mem_tot / (1<<30):.3f} GiB"
        )

    bbs = BigBatchSearcher(
        index, xq, k,
        verbose=verbose,
        use_float16=use_float16
    )
    comp = BlockComputer(
        index,
        method=method,
        pairwise_distances=pairwise_distances,
        knn=knn
    )
    # prepare_bucket() needs these, they depend on the computation method
    bbs.decode_func = comp.decode_func
    bbs.by_residual = comp.by_residual

    if q_assign is None:
        bbs.coarse_quantization()
    else:
        bbs.q_assign = q_assign
    bbs.reorder_assign()

    if end_list is None:
        end_list = index.nlist

    completed = set()
    if checkpoint is not None:
        # checkpoints record the full-range state, so partial ranges are
        # not supported together with checkpointing
        assert (start_list, end_list) == (0, index.nlist)
        if os.path.exists(checkpoint):
            logging.info(f"recovering checkpoint: {checkpoint}")
            completed = bbs.read_checkpoint(checkpoint)
            logging.info(f"   already completed: {len(completed)}")
        else:
            logging.info("no checkpoint: starting from scratch")

    if threaded == 0:
        # simple sequential version
        for l in range(start_list, end_list):
            bbs.report(l)
            q_subset, xq_l, list_ids, xb_l = bbs.prepare_bucket(l)
            t0i = time.time()
            D, I = comp.block_search(xq_l, xb_l, list_ids, k)
            bbs.t_accu[2] += time.time() - t0i
            bbs.add_results_to_heap(q_subset, D, list_ids, I)
    elif threaded == 1:
        # parallel version with granularity 1: a single worker thread
        # overlaps result writeback + next-bucket prefetch with the
        # current bucket's computation

        def add_results_and_prefetch(to_add, l):
            """ perform the addition for the previous bucket and
            prefetch the next (if applicable) """
            if to_add is not None:
                bbs.add_results_to_heap(*to_add)
            # guard on end_list (not index.nlist): buckets past end_list
            # are never consumed, preparing them would be wasted work
            if l < end_list:
                return bbs.prepare_bucket(l)

        prefetched_bucket = bbs.prepare_bucket(start_list)
        to_add = None
        pool = ThreadPool(1)
        for l in range(start_list, end_list):
            bbs.report(l)
            prefetched_bucket_a = pool.apply_async(
                add_results_and_prefetch, (to_add, l + 1))
            q_subset, xq_l, list_ids, xb_l = prefetched_bucket
            bbs.start_t_accu()
            D, I = comp.block_search(xq_l, xb_l, list_ids, k)
            bbs.stop_t_accu(2)
            to_add = q_subset, D, list_ids, I
            bbs.start_t_accu()
            prefetched_bucket = prefetched_bucket_a.get()
            bbs.stop_t_accu(4)
        bbs.add_results_to_heap(*to_add)
        pool.close()
        pool.join()
    else:
        # pipelined version: prefetch_threads workers prepare buckets,
        # computation_threads workers run block_search, the main thread
        # adds results to the heap and handles checkpointing

        def task_manager_thread(
            task,
            pool_size,
            start_task,
            end_task,
            completed,
            output_queue,
            input_queue,
        ):
            try:
                with ThreadPool(pool_size) as pool:
                    res = [pool.apply_async(
                        task,
                        args=(i, output_queue, input_queue))
                        for i in range(start_task, end_task)
                        if i not in completed]
                    for r in res:
                        r.get()
                    pool.close()
                    pool.join()
                    # None = end-of-stream marker for the consumer
                    output_queue.put(None)
            except:
                # deliberately broad: any failure in a worker must abort
                # the whole search by interrupting the main thread
                traceback.print_exc()
                _thread.interrupt_main()
                raise

        def task_manager(*args):
            task_manager = threading.Thread(
                target=task_manager_thread,
                args=args,
            )
            task_manager.daemon = True
            task_manager.start()
            return task_manager

        def prepare_task(task_id, output_queue, input_queue=None):
            # task_id is the inverted list to prepare
            try:
                logging.info(f"Prepare start: {task_id}")
                q_subset, xq_l, list_ids, xb_l = bbs.prepare_bucket(task_id)
                output_queue.put((task_id, q_subset, xq_l, list_ids, xb_l))
                logging.info(f"Prepare end: {task_id}")
            except:
                traceback.print_exc()
                _thread.interrupt_main()
                raise

        def compute_task(task_id, output_queue, input_queue):
            # task_id is the worker number; buckets arrive on input_queue
            try:
                logging.info(f"Compute start: {task_id}")
                t_wait_out = 0
                while True:
                    t0 = time.time()
                    logging.info(f'Compute input: task {task_id}')
                    input_value = input_queue.get()
                    t_wait_in = time.time() - t0
                    if input_value is None:
                        # signal for other compute tasks
                        input_queue.put(None)
                        break
                    centroid, q_subset, xq_l, list_ids, xb_l = input_value
                    logging.info(f'Compute work: task {task_id}, centroid {centroid}')
                    t0 = time.time()
                    if computation_threads > 1:
                        # tell the knn function which worker it runs on
                        D, I = comp.block_search(
                            xq_l, xb_l, list_ids, k, thread_id=task_id
                        )
                    else:
                        D, I = comp.block_search(xq_l, xb_l, list_ids, k)
                    t_compute = time.time() - t0
                    logging.info(f'Compute output: task {task_id}, centroid {centroid}')
                    t0 = time.time()
                    output_queue.put(
                        (centroid, t_wait_in, t_wait_out, t_compute, q_subset, D, list_ids, I)
                    )
                    t_wait_out = time.time() - t0
                logging.info(f"Compute end: {task_id}")
            except:
                traceback.print_exc()
                _thread.interrupt_main()
                raise

        # bounded queues so preparation cannot run far ahead of compute
        prepare_to_compute_queue = Queue(2)
        compute_to_main_queue = Queue(2)

        compute_task_manager = task_manager(
            compute_task,
            computation_threads,
            0,
            computation_threads,
            set(),
            compute_to_main_queue,
            prepare_to_compute_queue,
        )
        prepare_task_manager = task_manager(
            prepare_task,
            prefetch_threads,
            start_list,
            end_list,
            completed,
            prepare_to_compute_queue,
            None,
        )

        t_checkpoint = time.time()
        while True:
            logging.info("Waiting for result")
            value = compute_to_main_queue.get()
            if not value:
                # None sentinel: all compute workers are done
                break
            centroid, t_wait_in, t_wait_out, t_compute, q_subset, D, list_ids, I = value
            # to test checkpointing
            if centroid == crash_at:
                1 / 0
            bbs.t_accu[2] += t_compute
            bbs.t_accu[4] += t_wait_in
            bbs.t_accu[5] += t_wait_out
            logging.info(f"Adding to heap start: centroid {centroid}")
            bbs.add_results_to_heap(q_subset, D, list_ids, I)
            logging.info(f"Adding to heap end: centroid {centroid}")
            completed.add(centroid)
            bbs.report(centroid)
            if checkpoint is not None:
                if time.time() - t_checkpoint > checkpoint_freq:
                    logging.info("writing checkpoint")
                    bbs.write_checkpoint(checkpoint, completed)
                    t_checkpoint = time.time()
        prepare_task_manager.join()
        compute_task_manager.join()

    bbs.tic("finalize heap")
    bbs.rh.finalize()
    bbs.toc()
    return bbs.rh.D, bbs.rh.I