Spaces:
Running
Running
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
# @nolint | |
# not linting this file because it imports * from swigfaiss, which | |
# causes a ton of useless warnings. | |
import numpy as np | |
import array | |
import warnings | |
from faiss.loader import * | |
########################################### | |
# Utility to add a deprecation warning to | |
# classes from the SWIG interface | |
########################################### | |
def _make_deprecated_swig_class(deprecated_name, base_name):
    """
    Dynamically construct a deprecated class as a subclass of a renamed one.

    The generated subclass overrides ``__new__`` so that a DeprecationWarning
    is issued each time an instance is constructed; whether (and how often)
    the user actually sees it is decided by the active warning filters.
    We do this in Python because the base classes are defined in the SWIG
    interface, making it cumbersome to add the deprecation there.

    Parameters
    ----------
    deprecated_name : string
        Name of the class to be deprecated; _not_ present in SWIG interface.
    base_name : string
        Name of the class that is replacing deprecated_name; must already be
        present in this module's globals (a KeyError is raised otherwise).

    Returns
    -------
    None
        However, the deprecated class gets added to this module's globals
        under ``deprecated_name``.
    """
    base_class = globals()[base_name]

    def new_meth(cls, *args, **kwargs):
        msg = f"The class faiss.{deprecated_name} is deprecated in favour of faiss.{base_name}!"
        # stacklevel=2 attributes the warning to the code constructing the
        # deprecated class rather than to this wrapper
        warnings.warn(msg, DeprecationWarning, stacklevel=2)
        parent_new = super(base_class, cls).__new__
        if parent_new is object.__new__:
            # object.__new__ rejects extra arguments as soon as a subclass
            # overrides __new__ (which this generated class does); the
            # constructor arguments still reach __init__ afterwards.
            return parent_new(cls)
        return parent_new(cls, *args, **kwargs)

    # three-argument version of "type" uses (name, tuple-of-bases, dict-of-attributes)
    klazz = type(deprecated_name, (base_class,), {"__new__": new_meth})

    # registering the class in the module globals makes it reachable as an
    # attribute of this module, and therefore through star-imports of it
    globals()[deprecated_name] = klazz
########################################### | |
# numpy array / std::vector conversions | |
########################################### | |
# size in bytes of a C long, as reported by the array module
sizeof_long = array.array('l').itemsize

# deprecated vector-name prefix -> prefix of the class that replaces it
deprecated_name_map = dict(
    Float='Float32',
    Double='Float64',
    Char='Int8',
    Int='Int32',
    # the replacement for Long follows the platform's C long width
    Long='Int32' if sizeof_long == 4 else 'Int64',
    LongLong='Int64',
    Byte='UInt8',
    # previously misspelled variant
    Uint64='UInt64',
)

for depr_prefix, base_prefix in deprecated_name_map.items():
    _make_deprecated_swig_class(f"{depr_prefix}Vector", f"{base_prefix}Vector")

    # same for the three legacy *VectorVector classes
    if depr_prefix in ('Float', 'Long', 'Byte'):
        _make_deprecated_swig_class(
            f"{depr_prefix}VectorVector", f"{base_prefix}VectorVector")
# mapping from the vector name prefixes used in swigfaiss.swig to numpy dtype names
# TODO: once deprecated classes are removed, remove the dict and just use .lower() below
vector_name_map = {
    prefix: prefix.lower()
    for prefix in (
        'Float32', 'Float64',
        'Int8', 'Int16', 'Int32', 'Int64',
        'UInt8', 'UInt16', 'UInt32', 'UInt64',
    )
}
# deprecated prefixes resolve to the dtype of the class that replaces them
vector_name_map.update(
    (old_prefix, new_prefix.lower())
    for old_prefix, new_prefix in deprecated_name_map.items()
)
def vector_to_array(v):
    """ convert a C++ vector (or an AlignedTable) to a new numpy array

    The numpy dtype is derived from the class name of v: AlignedTable*
    objects are delegated to AlignedTable_to_array, anything else must be
    a *Vector class whose prefix is a key of vector_name_map.
    """
    classname = v.__class__.__name__
    if classname.startswith('AlignedTable'):
        return AlignedTable_to_array(v)

    assert classname.endswith('Vector')
    # strip the "Vector" suffix to get the element-type prefix
    prefix = classname[:-len('Vector')]
    out = np.empty(v.size(), dtype=np.dtype(vector_name_map[prefix]))
    if v.size() > 0:
        # skip the raw copy entirely for empty vectors
        memcpy(swig_ptr(out), v.data(), out.nbytes)
    return out
def vector_float_to_array(v):
    """ convert a C++ vector to a numpy array; delegates to vector_to_array """
    return vector_to_array(v)
def copy_array_to_vector(a, v):
    """ copy a 1D numpy array into a vector, resizing the vector to fit

    The dtype of a must equal the dtype that vector_name_map associates
    with the class of v.
    """
    # unpacking fails for anything that is not one-dimensional
    (n,) = a.shape
    classname = v.__class__.__name__
    assert classname.endswith('Vector')
    expected = np.dtype(vector_name_map[classname[:-6]])
    assert expected == a.dtype, (
        'cannot copy a %s array to a %s (should be %s)' % (
            a.dtype, classname, expected))
    v.resize(n)
    if n == 0:
        # nothing to copy once the vector has been resized to empty
        return
    memcpy(v.data(), swig_ptr(a), a.nbytes)
# same for AlignedTable | |
def copy_array_to_AlignedTable(a, v):
    """ copy a 1D numpy array into an AlignedTable, resizing the table to fit

    The item size of the table must match the item size of the array.
    """
    # unpacking fails for anything that is not one-dimensional
    (n,) = a.shape
    # TODO check class name
    assert v.itemsize() == a.itemsize
    v.resize(n)
    if n == 0:
        # nothing to copy once the table has been resized to empty
        return
    memcpy(v.get(), swig_ptr(a), a.nbytes)
def array_to_AlignedTable(a):
    """ build a new AlignedTable holding a copy of a 1D numpy array

    Only uint8 and uint16 arrays are supported; the table class is chosen
    from the dtype of a.

    Raises
    ------
    AssertionError
        If the dtype of a is neither uint8 nor uint16.
    """
    if a.dtype == 'uint16':
        v = AlignedTableUint16(a.size)
    elif a.dtype == 'uint8':
        v = AlignedTableUint8(a.size)
    else:
        # raised explicitly (instead of "assert False") so the check is not
        # stripped under "python -O", where the code would otherwise fall
        # through to an unbound v
        raise AssertionError(
            f"cannot convert a {a.dtype} array to an AlignedTable "
            "(only uint8 and uint16 are supported)")
    copy_array_to_AlignedTable(a, v)
    return v
def AlignedTable_to_array(v):
    """ convert an AlignedTable to a new numpy array

    The numpy dtype is the lower-cased suffix of the class name of v,
    e.g. AlignedTableUint16 -> uint16.
    """
    classname = v.__class__.__name__
    assert classname.startswith('AlignedTable')
    suffix = classname[len('AlignedTable'):]
    out = np.empty(v.size(), dtype=suffix.lower())
    if out.size > 0:
        # skip the raw copy entirely for empty tables
        memcpy(swig_ptr(out), v.data(), out.nbytes)
    return out