# Source path: geochatbot/llm/Lib/site-packages/faiss/array_conversions.py
# (uploaded by Draken007, "Upload 7228 files", commit 2a0bc63, verified)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# @nolint
# not linting this file because it imports * from swigfaiss, which
# causes a ton of useless warnings.
import numpy as np
import array
import warnings
from faiss.loader import *
###########################################
# Utility to add a deprecation warning to
# classes from the SWIG interface
###########################################
def _make_deprecated_swig_class(deprecated_name, base_name):
    """
    Dynamically construct deprecated classes as wrappers around renamed ones.

    The generated class subclasses ``base_name`` and overrides ``__new__`` so
    that constructing an instance emits a DeprecationWarning (how often it is
    actually shown is governed by the active ``warnings`` filters). Apart
    from the warning, it behaves like the base class, including forwarding
    constructor arguments to the base ``__init__``.

    This is done in Python because the base classes are defined in the SWIG
    interface, making it cumbersome to add the deprecation there.

    Parameters
    ----------
    deprecated_name : string
        Name of the class to be deprecated; _not_ present in SWIG interface.
    base_name : string
        Name of the class that is replacing deprecated_name; must already be
        present in this module's global namespace (KeyError otherwise).

    Returns
    -------
    None
        However, the deprecated class is added to this module's global
        namespace under ``deprecated_name``.
    """
    base_class = globals()[base_name]

    def new_meth(cls, *args, **kwargs):
        msg = f"The class faiss.{deprecated_name} is deprecated in favour of faiss.{base_name}!"
        # stacklevel=2 attributes the warning to the code constructing the
        # instance rather than to this helper.
        warnings.warn(msg, DeprecationWarning, stacklevel=2)
        parent_new = super(base_class, cls).__new__
        # object.__new__ raises TypeError ("takes exactly one argument") when
        # given extra arguments on a class that overrides __new__, as this
        # one does. Only forward constructor arguments to a parent that
        # defines its own __new__; __init__ still receives them through the
        # normal instantiation call.
        if parent_new is object.__new__:
            return parent_new(cls)
        return parent_new(cls, *args, **kwargs)

    # three-argument version of "type" uses (name, tuple-of-bases, dict-of-attributes)
    klazz = type(deprecated_name, (base_class,), {"__new__": new_meth})
    # Register the class in this module's namespace; the warning text above
    # presumes the package re-exports it as faiss.<deprecated_name>.
    globals()[deprecated_name] = klazz
###########################################
# numpy array / std::vector conversions
###########################################
# Byte width of the platform's C ``long``, as reported by the stdlib array
# module; it is platform dependent, which is why the 'Long' entry below
# branches on it.
sizeof_long = array.array('l').itemsize

# Deprecated SWIG vector-class name prefixes, each mapped to the prefix of
# the class that replaces it (deprecated=replacement).
deprecated_name_map = dict(
    Float='Float32',
    Double='Float64',
    Char='Int8',
    Int='Int32',
    # follows the width of C ``long`` measured above
    Long='Int32' if sizeof_long == 4 else 'Int64',
    LongLong='Int64',
    Byte='UInt8',
    # previously misspelled variant
    Uint64='UInt64',
)
# Register a deprecated alias class for every renamed vector prefix.
for depr_prefix, base_prefix in deprecated_name_map.items():
    # <Old>Vector -> <New>Vector for every deprecated prefix
    _make_deprecated_swig_class(f"{depr_prefix}Vector", f"{base_prefix}Vector")
    # only these three prefixes also get a legacy *VectorVector alias
    if depr_prefix not in ('Float', 'Long', 'Byte'):
        continue
    _make_deprecated_swig_class(f"{depr_prefix}VectorVector",
                                f"{base_prefix}VectorVector")
# Mapping from the vector-class name prefixes in swigfaiss.swig to numpy
# dtype names. For the current names the dtype name is just the lower-cased
# prefix (e.g. 'UInt8' -> 'uint8'); each deprecated prefix maps to the
# lower-cased name of its replacement.
# TODO: once deprecated classes are removed, remove the dict and just use .lower() below
vector_name_map = {
    prefix: prefix.lower()
    for prefix in ('Float32', 'Float64', 'Int8', 'Int16', 'Int32', 'Int64',
                   'UInt8', 'UInt16', 'UInt32', 'UInt64')
}
vector_name_map.update(
    (old_prefix, new_prefix.lower())
    for old_prefix, new_prefix in deprecated_name_map.items()
)
def vector_to_array(v):
    """ convert a C++ vector to a numpy array

    ``AlignedTable*`` objects are delegated to ``AlignedTable_to_array``.
    Otherwise the class name must end in ``Vector``; the preceding prefix is
    looked up in ``vector_name_map`` (KeyError if unknown) to pick the dtype.
    The returned array is a fresh copy of the vector's contents.
    """
    cls_name = v.__class__.__name__
    if cls_name.startswith('AlignedTable'):
        return AlignedTable_to_array(v)

    assert cls_name.endswith('Vector')
    dtype = np.dtype(vector_name_map[cls_name[:-len('Vector')]])
    n = v.size()
    out = np.empty(n, dtype=dtype)
    if n > 0:
        # raw byte copy from the C++ buffer into the new array
        memcpy(swig_ptr(out), v.data(), out.nbytes)
    return out
def vector_float_to_array(v):
    """Forward to ``vector_to_array`` (presumably kept for older callers).

    Despite the name, any vector type accepted by ``vector_to_array`` works.
    """
    converted = vector_to_array(v)
    return converted
def copy_array_to_vector(a, v):
    """ copy a numpy array to a vector

    *a* must be 1-D (ValueError from unpacking ``a.shape`` otherwise), and
    its dtype must equal the dtype that ``vector_name_map`` associates with
    the class-name prefix of *v* (AssertionError otherwise; KeyError for an
    unknown prefix). *v* is resized to ``len(a)`` and overwritten.

    Strided (non-C-contiguous) arrays are accepted: they are first made
    contiguous, because the data is moved with a single flat memcpy.
    """
    n, = a.shape
    classname = v.__class__.__name__
    assert classname.endswith('Vector')
    dtype = np.dtype(vector_name_map[classname[:-6]])
    assert dtype == a.dtype, (
        'cannot copy a %s array to a %s (should be %s)' % (
            a.dtype, classname, dtype))
    # memcpy reads a.nbytes consecutive bytes from the array's data pointer,
    # which is only correct for C-contiguous data. ascontiguousarray returns
    # the input unchanged when it already is, and keeps the dtype.
    a = np.ascontiguousarray(a)
    v.resize(n)
    if n > 0:
        memcpy(v.data(), swig_ptr(a), a.nbytes)
# same for AlignedTable
def copy_array_to_AlignedTable(a, v):
    """Copy the 1-D numpy array *a* into the AlignedTable *v*, resizing it.

    Only the element size is validated (``v.itemsize() == a.itemsize``,
    AssertionError otherwise); the dtype itself is not compared. A non-1-D
    *a* raises ValueError from unpacking ``a.shape``. Strided arrays are
    first made C-contiguous, because the data is moved with one flat memcpy.
    """
    n, = a.shape
    # TODO check class name
    assert v.itemsize() == a.itemsize
    # memcpy copies a.nbytes consecutive bytes, so the source must be
    # C-contiguous; this is a no-op when it already is.
    a = np.ascontiguousarray(a)
    v.resize(n)
    if n > 0:
        memcpy(v.get(), swig_ptr(a), a.nbytes)
def array_to_AlignedTable(a):
    """Build a new AlignedTable holding a copy of the numpy array *a*.

    Only uint16 and uint8 arrays are supported, mapped to
    ``AlignedTableUint16`` and ``AlignedTableUint8`` respectively.

    Raises AssertionError for any other dtype. The error is raised
    explicitly rather than with ``assert`` so it still fires under
    ``python -O``, where the old ``assert False`` was stripped and the
    function failed later with an UnboundLocalError on ``v``.
    """
    if a.dtype == 'uint16':
        v = AlignedTableUint16(a.size)
    elif a.dtype == 'uint8':
        v = AlignedTableUint8(a.size)
    else:
        # AssertionError kept for compatibility with the former `assert False`
        raise AssertionError(
            'cannot convert a %s array to an AlignedTable '
            '(only uint8 and uint16 are supported)' % a.dtype)
    copy_array_to_AlignedTable(a, v)
    return v
def AlignedTable_to_array(v):
    """ convert an AlignedTable to a numpy array

    The dtype is the lower-cased class-name suffix after ``AlignedTable``
    (``AlignedTableUint8`` -> ``uint8``). The result is a fresh copy.
    """
    table_type = v.__class__.__name__
    prefix = 'AlignedTable'
    assert table_type.startswith(prefix)
    out = np.empty(v.size(), dtype=table_type[len(prefix):].lower())
    if out.size:
        # raw byte copy from the table's buffer into the new array
        memcpy(swig_ptr(out), v.data(), out.nbytes)
    return out