# Draken007's picture
# Upload 7228 files
# 2a0bc63 verified
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import time
import pickle
import os
import logging
from multiprocessing.pool import ThreadPool
import threading
import _thread
from queue import Queue
import traceback
import datetime
import numpy as np
import faiss
from faiss.contrib.inspect_tools import get_invlist
class BigBatchSearcher:
    """
    Object that manages all the data related to the computation
    except the actual within-bucket matching and the organization of the
    computation (parallel or not).

    The ``decode_func`` and ``by_residual`` attributes are not set by the
    constructor: the caller must assign them (big_batch_search copies them
    from a BlockComputer) before ``prepare_bucket`` is called.
    """

    def __init__(
            self,
            index, xq, k,
            verbose=0,
            use_float16=False):
        # verbosity: 0 = silent, 1 = phase timings only,
        # 2 = throttled per-list progress, >2 = one line per list
        self.verbose = verbose
        # (name, start time) of the phase currently timed by tic/toc
        self.tictoc = []
        self.xq = xq
        self.index = index
        self.use_float16 = use_float16
        keep_max = faiss.is_similarity_metric(index.metric_type)
        self.rh = faiss.ResultHeap(len(xq), k, keep_max=keep_max)
        # accumulated times (s): 0 prep queries, 1 prep database,
        # 2 compute, 3 add results, 4 wait for input, 5 wait for output
        self.t_accu = [0] * 6
        self.t_display = self.t0 = time.time()

    def start_t_accu(self):
        """Start the timer read by the next ``stop_t_accu`` call."""
        self.t_accu_t0 = time.time()

    def stop_t_accu(self, n):
        """Add the time elapsed since ``start_t_accu`` to counter ``n``."""
        self.t_accu[n] += time.time() - self.t_accu_t0

    def tic(self, name):
        """Start timing the phase ``name`` (its name is shown if verbose > 0)."""
        self.tictoc = (name, time.time())
        if self.verbose > 0:
            print(name, end="\r", flush=True)

    def toc(self):
        """Finish the phase started by ``tic`` and return its duration in s."""
        name, t0 = self.tictoc
        dt = time.time() - t0
        if self.verbose > 0:
            print(f"{name}: {dt:.3f} s")
        return dt

    def report(self, l):
        """Print a progress line for inverted list ``l``.

        Nothing is printed when verbose <= 1 (the default is silent, like
        tic/toc). With verbose == 2 the line is overwritten in place and,
        beyond the first 1000 lists, printed at most once per second; with
        verbose > 2 every call prints a new line.
        """
        if self.verbose <= 1 or (
            self.verbose == 2 and
            l > 1000 and time.time() < self.t_display + 1.0
        ):
            return
        t = time.time() - self.t0
        print(
            f"[{t:.1f} s] list {l}/{self.index.nlist} "
            f"times prep q {self.t_accu[0]:.3f} prep b {self.t_accu[1]:.3f} "
            f"comp {self.t_accu[2]:.3f} res {self.t_accu[3]:.3f} "
            f"wait in {self.t_accu[4]:.3f} "
            f"wait out {self.t_accu[5]:.3f} "
            f"eta {datetime.timedelta(seconds=t*self.index.nlist/(l+1)-t)} "
            f"mem {faiss.get_mem_usage_kb()}",
            end="\r" if self.verbose <= 2 else "\n",
            flush=True,
        )
        self.t_display = time.time()

    def coarse_quantization(self):
        """Assign each query to its ``nprobe`` nearest centroids.

        The quantizer is queried in batches of 65536 queries to bound the
        temporary memory. The result is stored in ``self.q_assign`` as an
        (nq, nprobe) int32 matrix.
        """
        self.tic("coarse quantization")
        bs = 65536
        nq = len(self.xq)
        q_assign = np.empty((nq, self.index.nprobe), dtype='int32')
        for i0 in range(0, nq, bs):
            i1 = min(nq, i0 + bs)
            _, q_assign_i = self.index.quantizer.search(
                self.xq[i0:i1], self.index.nprobe)
            q_assign[i0:i1] = q_assign_i
        self.toc()
        self.q_assign = q_assign

    def reorder_assign(self):
        """Bucket-sort the queries by the inverted list they are assigned to.

        Afterwards, the ids of the queries assigned to list ``l`` are
        ``self.query_ids[self.bucket_lims[l]:self.bucket_lims[l + 1]]``.
        Assignments equal to -1 go to an extra leading bucket that is
        dropped.

        Note: ``self.q_assign`` is sorted in place and then deleted, so an
        assignment array provided by the caller is overwritten.
        """
        self.tic("bucket sort")
        q_assign = self.q_assign
        q_assign += 1   # move -1 -> 0
        self.bucket_lims = faiss.matrix_bucket_sort_inplace(
            q_assign, nbucket=self.index.nlist + 1, nt=16)
        self.query_ids = q_assign.ravel()
        if self.verbose > 0:
            print(' number of -1s:', self.bucket_lims[1])
        self.bucket_lims = self.bucket_lims[1:]  # shift back to ignore -1s
        del self.q_assign   # inplace so let's forget about the old version...
        self.toc()

    def prepare_bucket(self, l):
        """Prepare the queries and database items for bucket ``l``.

        Returns (q_subset, xq_l, list_ids, xb_l):
        - q_subset: ids of the queries assigned to list l
        - xq_l: those query vectors, minus the list centroid when the
          index encodes residuals
        - list_ids: the ids stored in inverted list l
        - xb_l: the list's vectors decoded with ``decode_func``, or the
          flattened raw codes when ``decode_func`` is None
        Both xq_l and xb_l are cast to float16 if ``use_float16`` is set.
        """
        t0 = time.time()
        index = self.index
        # prepare queries
        i0, i1 = self.bucket_lims[l], self.bucket_lims[l + 1]
        q_subset = self.query_ids[i0:i1]
        xq_l = self.xq[q_subset]
        if self.by_residual:
            xq_l = xq_l - index.quantizer.reconstruct(l)
        t1 = time.time()
        # prepare database side
        list_ids, xb_l = get_invlist(index.invlists, l)

        if self.decode_func is None:
            xb_l = xb_l.ravel()
        else:
            xb_l = self.decode_func(xb_l)

        if self.use_float16:
            xb_l = xb_l.astype('float16')
            xq_l = xq_l.astype('float16')

        t2 = time.time()
        self.t_accu[0] += t1 - t0
        self.t_accu[1] += t2 - t1
        return q_subset, xq_l, list_ids, xb_l

    def add_results_to_heap(self, q_subset, D, list_ids, I):
        """Add the results of one bucket to the result heap.

        D is None when the bucket was empty (nothing to add). I contains
        positions within the inverted list, mapped to ids through list_ids;
        when I is None, the columns of D correspond to list_ids in order.
        """
        if D is None:
            return
        t0 = time.time()
        if I is None:
            I = list_ids
        else:
            I = list_ids[I]
        self.rh.add_result_subset(q_subset, D, I)
        self.t_accu[3] += time.time() - t0

    def sizes_in_checkpoint(self):
        """Sizes a checkpoint must match to be reusable by this search."""
        return (self.xq.shape, self.index.nprobe, self.index.nlist)

    def write_checkpoint(self, fname, completed):
        """Save the result heap and the set of completed lists to ``fname``.

        The data is written to a temporary file which is then renamed over
        ``fname``, so an interrupted write does not corrupt an existing
        checkpoint.
        """
        tmpname = fname + ".tmp"
        with open(tmpname, "wb") as f:
            pickle.dump(
                {
                    "sizes": self.sizes_in_checkpoint(),
                    "completed": completed,
                    "rh": (self.rh.D, self.rh.I),
                }, f, pickle.HIGHEST_PROTOCOL)
        os.replace(tmpname, fname)

    def read_checkpoint(self, fname):
        """Restore the result heap from checkpoint ``fname``.

        Returns the set of inverted lists already completed.
        Raises ValueError if the checkpoint was written for different sizes
        (an assert would be skipped under ``python -O`` and silently load a
        mismatched heap).

        NOTE: the file is read with pickle, so it must come from a trusted
        source (normally a previous ``write_checkpoint`` call).
        """
        with open(fname, "rb") as f:
            ckp = pickle.load(f)
        expected = self.sizes_in_checkpoint()
        if ckp["sizes"] != expected:
            raise ValueError(
                f"checkpoint {fname} was written for sizes {ckp['sizes']}, "
                f"expected {expected}")
        self.rh.D[:] = ckp["rh"][0]
        self.rh.I[:] = ckp["rh"][1]
        return ckp["completed"]
class BlockComputer:
    """ computation within one bucket """

    def __init__(
            self,
            index,
            method="knn_function",
            pairwise_distances=faiss.pairwise_distances,
            knn=faiss.knn):
        """
        index: an IndexIVFFlat, IndexIVFPQ or IndexIVFScalarQuantizer
            (the exact class is checked, subclasses are not accepted)
        method: "index", "pairwise_distances" or "knn_function"
            (see block_search)
        pairwise_distances, knn: functions used by the corresponding methods

        Raises ValueError for an unknown method and RuntimeError for an
        unsupported index type.
        """
        # validate up-front: an unknown method would otherwise only fail
        # later in block_search with an UnboundLocalError
        if method not in ("index", "pairwise_distances", "knn_function"):
            raise ValueError(f"unknown method {method!r}")
        self.index = index

        if index.__class__ == faiss.IndexIVFFlat:
            index_help = faiss.IndexFlat(index.d, index.metric_type)
            # IVFFlat codes are the raw vectors: reinterpret the bytes
            decode_func = lambda x: x.view("float32")
            by_residual = False
        elif index.__class__ == faiss.IndexIVFPQ:
            index_help = faiss.IndexPQ(
                index.d, index.pq.M, index.pq.nbits, index.metric_type)
            index_help.pq = index.pq
            decode_func = index_help.pq.decode
            index_help.is_trained = True
            by_residual = index.by_residual
        elif index.__class__ == faiss.IndexIVFScalarQuantizer:
            index_help = faiss.IndexScalarQuantizer(
                index.d, index.sq.qtype, index.metric_type)
            index_help.sq = index.sq
            decode_func = index_help.sq.decode
            index_help.is_trained = True
            by_residual = index.by_residual
        else:
            raise RuntimeError(f"index type {index.__class__} not supported")
        self.index_help = index_help
        # with method "index", index_help searches the codes directly, so
        # they must not be decoded beforehand
        self.decode_func = None if method == "index" else decode_func
        self.by_residual = by_residual
        self.method = method
        self.pairwise_distances = pairwise_distances
        self.knn = knn

    def block_search(self, xq_l, xb_l, list_ids, k, **extra_args):
        """Search the queries xq_l among the database items xb_l of one list.

        Returns (D, I), or (None, None) when either side is empty.
        I holds positions within the inverted list (not ids). For
        "pairwise_distances", D is the full query-by-item distance matrix
        and I is None. extra_args are forwarded to the knn function only.

        NOTE(review): method "index" overwrites the shared index_help
        contents, so concurrent calls with that method are presumably not
        safe.
        """
        metric_type = self.index.metric_type
        if xq_l.size == 0 or xb_l.size == 0:
            D = I = None
        elif self.method == "index":
            faiss.copy_array_to_vector(xb_l, self.index_help.codes)
            self.index_help.ntotal = len(list_ids)
            D, I = self.index_help.search(xq_l, k)
        elif self.method == "pairwise_distances":
            # TODO implement blockwise to avoid mem blowup
            D = self.pairwise_distances(xq_l, xb_l, metric=metric_type)
            I = None
        elif self.method == "knn_function":
            D, I = self.knn(xq_l, xb_l, k, metric=metric_type, **extra_args)
        else:
            raise ValueError(f"unknown method {self.method!r}")
        return D, I
def big_batch_search(
        index, xq, k,
        method="knn_function",
        pairwise_distances=faiss.pairwise_distances,
        knn=faiss.knn,
        verbose=0,
        threaded=0,
        use_float16=False,
        prefetch_threads=1,
        computation_threads=1,
        q_assign=None,
        checkpoint=None,
        checkpoint_freq=7200,
        start_list=0,
        end_list=None,
        crash_at=-1
):
    """
    Search queries xq in the IVF index, with a search function that collects
    batches of query vectors per inverted list. This can be faster than the
    regular index search.
    Supports IVFFlat, IVFPQ and IVFScalarQuantizer.

    Supports three computation methods:
    method = "index":
        build a flat index and populate it separately for each inverted list
    method = "pairwise_distances":
        decompress codes and compute all pairwise distances for the queries
        and index and add result to heap
    method = "knn_function":
        decompress codes and compute knn results for the queries

    threaded=0: sequential execution
    threaded=1: prefetch next bucket while computing the current one
    threaded=2: prefetch prefetch_threads buckets at a time.

    computation_threads>1: the knn function gets an additional thread_id
    keyword argument that tells which worker should handle this.

    In threaded mode, the computation is tiled with the bucket preparation
    and the writeback of results (useful to maximize GPU utilization).

    use_float16: convert all matrices to float16 (faster for GPU gemm)

    q_assign: override coarse assignment, should be a matrix of size
        nq * nprobe. Note that this array is bucket-sorted in place.

    checkpointing (checkpoints are only written when threaded > 1):
    checkpoint: file where the checkpoints are stored. If it exists, the
        inverted lists it records as completed are skipped.
    checkpoint_freq: minimum number of seconds between two checkpoints

    start_list, end_list: process only a subset of invlists

    crash_at: raise an exception when the results of this inverted list
        reach the main thread (threaded > 1 only; used to test checkpointing)

    Returns D, I: the distances and ids of the k nearest neighbors of each
    query.
    """
    nprobe = index.nprobe

    if method not in ("index", "pairwise_distances", "knn_function"):
        raise ValueError(f"unknown method {method!r}")

    mem_queries = xq.nbytes
    mem_assign = len(xq) * nprobe * np.dtype('int32').itemsize
    mem_res = len(xq) * k * (
        np.dtype('int64').itemsize
        + np.dtype('float32').itemsize
    )
    mem_tot = mem_queries + mem_assign + mem_res
    if verbose > 0:
        logging.info(
            f"memory: queries {mem_queries} assign {mem_assign} "
            f"result {mem_res} total {mem_tot} = {mem_tot / (1<<30):.3f} GiB"
        )

    bbs = BigBatchSearcher(
        index, xq, k,
        verbose=verbose,
        use_float16=use_float16
    )

    comp = BlockComputer(
        index,
        method=method,
        pairwise_distances=pairwise_distances,
        knn=knn
    )

    bbs.decode_func = comp.decode_func
    bbs.by_residual = comp.by_residual

    if q_assign is None:
        bbs.coarse_quantization()
    else:
        bbs.q_assign = q_assign
    bbs.reorder_assign()

    if end_list is None:
        end_list = index.nlist

    completed = set()
    if checkpoint is not None:
        if (start_list, end_list) != (0, index.nlist):
            raise ValueError(
                "checkpointing requires processing all inverted lists")
        if os.path.exists(checkpoint):
            logging.info(f"recovering checkpoint: {checkpoint}")
            completed = bbs.read_checkpoint(checkpoint)
            logging.info(f" already completed: {len(completed)}")
        else:
            logging.info("no checkpoint: starting from scratch")

    if threaded < 2:
        # lists still to process: those recovered from a checkpoint are
        # skipped, otherwise their results would be added to the heap twice
        todo = [l for l in range(start_list, end_list) if l not in completed]

    if threaded == 0:
        # simple sequential version
        for l in todo:
            bbs.report(l)
            q_subset, xq_l, list_ids, xb_l = bbs.prepare_bucket(l)
            t0i = time.time()
            D, I = comp.block_search(xq_l, xb_l, list_ids, k)
            bbs.t_accu[2] += time.time() - t0i
            bbs.add_results_to_heap(q_subset, D, list_ids, I)
    elif threaded == 1:

        # parallel version with granularity 1

        def add_results_and_prefetch(to_add, next_l):
            """ perform the addition for the previous bucket and
            prefetch bucket next_l (if not None) """
            if to_add is not None:
                bbs.add_results_to_heap(*to_add)
            if next_l is not None:
                return bbs.prepare_bucket(next_l)
            return None

        # an empty range must not prepare (or add results for) any bucket
        if todo:
            prefetched_bucket = bbs.prepare_bucket(todo[0])
            to_add = None
            with ThreadPool(1) as pool:
                # pair each list with the next one to prefetch (None at end)
                for l, next_l in zip(todo, todo[1:] + [None]):
                    bbs.report(l)
                    prefetched_bucket_a = pool.apply_async(
                        add_results_and_prefetch, (to_add, next_l))
                    q_subset, xq_l, list_ids, xb_l = prefetched_bucket
                    bbs.start_t_accu()
                    D, I = comp.block_search(xq_l, xb_l, list_ids, k)
                    bbs.stop_t_accu(2)
                    to_add = q_subset, D, list_ids, I
                    bbs.start_t_accu()
                    prefetched_bucket = prefetched_bucket_a.get()
                    bbs.stop_t_accu(4)
            bbs.add_results_to_heap(*to_add)
    else:

        def task_manager_thread(
                task,
                pool_size,
                start_task,
                end_task,
                completed,
                output_queue,
                input_queue,
        ):
            """Run task(i, output_queue, input_queue) for every i in
            [start_task, end_task) not in completed on pool_size threads,
            then put the end-of-stream marker None on output_queue."""
            try:
                with ThreadPool(pool_size) as pool:
                    res = [pool.apply_async(
                        task,
                        args=(i, output_queue, input_queue))
                        for i in range(start_task, end_task)
                        if i not in completed]
                    for r in res:
                        r.get()
                    pool.close()
                    pool.join()
                output_queue.put(None)
            except BaseException:
                # wake up the main thread, which may be blocked on a queue
                traceback.print_exc()
                _thread.interrupt_main()
                raise

        def task_manager(*args):
            """Start task_manager_thread(*args) in a daemon thread."""
            manager = threading.Thread(
                target=task_manager_thread,
                args=args,
            )
            manager.daemon = True
            manager.start()
            return manager

        def prepare_task(task_id, output_queue, input_queue=None):
            """Prepare bucket task_id and send it to the compute tasks."""
            try:
                logging.info(f"Prepare start: {task_id}")
                q_subset, xq_l, list_ids, xb_l = bbs.prepare_bucket(task_id)
                output_queue.put((task_id, q_subset, xq_l, list_ids, xb_l))
                logging.info(f"Prepare end: {task_id}")
            except BaseException:
                traceback.print_exc()
                _thread.interrupt_main()
                raise

        def compute_task(task_id, output_queue, input_queue):
            """Worker task_id: search prepared buckets until None arrives."""
            try:
                logging.info(f"Compute start: {task_id}")
                t_wait_out = 0
                while True:
                    t0 = time.time()
                    logging.info(f'Compute input: task {task_id}')
                    input_value = input_queue.get()
                    t_wait_in = time.time() - t0
                    if input_value is None:
                        # signal for other compute tasks
                        input_queue.put(None)
                        break
                    centroid, q_subset, xq_l, list_ids, xb_l = input_value
                    logging.info(f'Compute work: task {task_id}, centroid {centroid}')
                    t0 = time.time()
                    if computation_threads > 1:
                        D, I = comp.block_search(
                            xq_l, xb_l, list_ids, k, thread_id=task_id
                        )
                    else:
                        D, I = comp.block_search(xq_l, xb_l, list_ids, k)
                    t_compute = time.time() - t0
                    logging.info(f'Compute output: task {task_id}, centroid {centroid}')
                    t0 = time.time()
                    output_queue.put(
                        (centroid, t_wait_in, t_wait_out, t_compute, q_subset, D, list_ids, I)
                    )
                    t_wait_out = time.time() - t0
                logging.info(f"Compute end: {task_id}")
            except BaseException:
                traceback.print_exc()
                _thread.interrupt_main()
                raise

        prepare_to_compute_queue = Queue(2)

        compute_to_main_queue = Queue(2)

        compute_task_manager = task_manager(
            compute_task,
            computation_threads,
            0,
            computation_threads,
            set(),
            compute_to_main_queue,
            prepare_to_compute_queue,
        )

        prepare_task_manager = task_manager(
            prepare_task,
            prefetch_threads,
            start_list,
            end_list,
            completed,
            prepare_to_compute_queue,
            None,
        )

        t_checkpoint = time.time()
        while True:
            logging.info("Waiting for result")
            value = compute_to_main_queue.get()
            if value is None:
                # all compute tasks are done
                break
            centroid, t_wait_in, t_wait_out, t_compute, q_subset, D, list_ids, I = value
            # to test checkpointing
            if centroid == crash_at:
                1 / 0
            bbs.t_accu[2] += t_compute
            bbs.t_accu[4] += t_wait_in
            bbs.t_accu[5] += t_wait_out
            logging.info(f"Adding to heap start: centroid {centroid}")
            bbs.add_results_to_heap(q_subset, D, list_ids, I)
            logging.info(f"Adding to heap end: centroid {centroid}")
            completed.add(centroid)
            bbs.report(centroid)
            if checkpoint is not None:
                if time.time() - t_checkpoint > checkpoint_freq:
                    logging.info("writing checkpoint")
                    bbs.write_checkpoint(checkpoint, completed)
                    t_checkpoint = time.time()

        prepare_task_manager.join()
        compute_task_manager.join()

    bbs.tic("finalize heap")
    bbs.rh.finalize()
    bbs.toc()

    return bbs.rh.D, bbs.rh.I