Draken007's picture
Upload 7228 files
2a0bc63 verified
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from multiprocessing.pool import ThreadPool
import faiss
from typing import List, Tuple
from . import rpc
############################################################
# Server implementation
############################################################
class SearchServer(rpc.Server):
    """RPC server object that exposes a faiss index to remote clients.

    set_nprobe and get_ntotal are implemented here; any other attribute
    lookup that fails on this object is forwarded to the wrapped index.
    """

    def __init__(self, s: int, index: faiss.Index):
        # s is passed unchanged to rpc.Server.__init__ (it comes from the
        # factory in run_index_server; presumably a connection socket --
        # TODO confirm the annotation).
        rpc.Server.__init__(self, s)
        self.index = index
        # IVF part of the index, kept separately so nprobe can be adjusted.
        self.index_ivf = faiss.extract_index_ivf(index)

    def set_nprobe(self, nprobe: int) -> None:
        """Set the nprobe field of the underlying IVF index."""
        self.index_ivf.nprobe = nprobe

    def get_ntotal(self) -> int:
        """Return the number of vectors stored in the wrapped index."""
        return self.index.ntotal

    def __getattr__(self, f):
        # Only reached when normal attribute lookup fails: forward all other
        # functions/attributes to the index. ``index`` itself must not be
        # resolved through this method -- if it is missing (e.g. before
        # __init__ assigned it), ``self.index`` would re-enter __getattr__
        # and recurse until RecursionError.
        if f == "index":
            raise AttributeError(f)
        return getattr(self.index, f)
def run_index_server(index: faiss.Index, port: int, v6: bool = False) -> None:
    """Serve requests for ``index`` on ``port`` forever.

    rpc.run_server is given a factory that wraps ``index`` in a SearchServer
    for every ``s`` it is called with (presumably one call per accepted
    connection -- confirm in rpc). ``v6`` is forwarded to rpc.run_server
    unchanged.
    """
    rpc.run_server(
        lambda s: SearchServer(s, index),
        port, v6=v6)
############################################################
# Client implementation
############################################################
class ClientIndex:
    """Manages a set of distant sub-indexes reached over RPC.

    Each sub-index searches a subset of the inverted lists; searches are
    fanned out to all of them in parallel and the results merged afterwards.
    """

    def __init__(self, machine_ports: List[Tuple[str, int]], v6: bool = False):
        """Connect to a series of (host, port) pairs.

        Raises ValueError if ``machine_ports`` is empty.
        """
        self.sub_indexes = [
            rpc.Client(machine, port, v6) for machine, port in machine_ports
        ]
        self.ni = len(self.sub_indexes)
        if self.ni == 0:
            # ThreadPool(0) would fail too, but with an unhelpful message.
            raise ValueError("ClientIndex needs at least one (host, port) pair")
        # One worker thread per sub-index so that all RPCs can be in flight
        # at the same time.
        self.pool = ThreadPool(self.ni)
        try:
            # test connection (and cache the total number of vectors)
            self.ntotal = self.get_ntotal()
        except BaseException:
            # don't leak the worker threads if the servers are unreachable
            self.pool.terminate()
            raise
        self.verbose = False

    def set_nprobe(self, nprobe: int) -> None:
        """Set nprobe on every sub-index."""
        self.pool.map(
            lambda idx: idx.set_nprobe(nprobe),
            self.sub_indexes
        )

    def set_omp_num_threads(self, nt: int) -> None:
        """Forward set_omp_num_threads(nt) to every sub-index."""
        self.pool.map(
            lambda idx: idx.set_omp_num_threads(nt),
            self.sub_indexes
        )

    def get_ntotal(self) -> int:
        """Return the total number of vectors summed over all sub-indexes."""
        return sum(self.pool.map(
            lambda idx: idx.get_ntotal(),
            self.sub_indexes
        ))

    def search(self, x, k: int, keep_max: bool = False):
        """Search all sub-indexes in parallel and merge their results.

        x: query matrix; ``x.shape[0]`` is the number of queries.
        k: number of results to keep per query.
        keep_max: keep the k largest scores instead of the k smallest
            distances. Set this when the sub-indexes use a similarity
            metric (e.g. inner product), where larger is better.

        Returns the (D, I) arrays of the merged faiss.ResultHeap.
        """
        nq = x.shape[0]
        if keep_max:
            rh = faiss.ResultHeap(nq, k, keep_max=True)
        else:
            # Two-argument form kept for the default so the call is
            # unchanged for existing users.
            rh = faiss.ResultHeap(nq, k)
        # imap yields results in order as they become available, so merging
        # overlaps with the sub-indexes that are still answering.
        for Di, Ii in self.pool.imap(lambda idx: idx.search(x, k), self.sub_indexes):
            rh.add_result(Di, Ii)
        rh.finalize()
        return rh.D, rh.I

    def close(self) -> None:
        """Shut down the worker thread pool.

        The objects in ``sub_indexes`` are left as-is; their connection
        handling belongs to rpc.Client.
        """
        self.pool.close()
        self.pool.join()