Spaces:
Running
Running
File size: 5,268 Bytes
2a0bc63 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# @nolint
# not linting this file because it imports * from swigfaiss, which
# causes a ton of useless warnings.
import numpy as np
import array
import warnings
from faiss.loader import *
###########################################
# Utility to add a deprecation warning to
# classes from the SWIG interface
###########################################
def _make_deprecated_swig_class(deprecated_name, base_name):
    """
    Dynamically construct a deprecated class as a wrapper around a renamed one.

    The deprecation warning added in the ``__new__``-method is issued upon
    construction of an instance of the class. Whether and how often it is
    actually displayed is decided by the active ``warnings`` filters.

    We do this here (in __init__.py) because the base classes are defined in
    the SWIG interface, making it cumbersome to add the deprecation there.

    Parameters
    ----------
    deprecated_name : string
        Name of the class to be deprecated; _not_ present in SWIG interface.
    base_name : string
        Name of the class that is replacing deprecated_name; must already be
        imported into the current namespace.

    Returns
    -------
    None
        However, the deprecated class gets added to the faiss namespace

    Raises
    ------
    KeyError
        If ``base_name`` is not defined in this module's namespace.
    """
    base_class = globals()[base_name]

    def new_meth(cls, *args, **kwargs):
        msg = (
            f"The class faiss.{deprecated_name} is deprecated "
            f"in favour of faiss.{base_name}!"
        )
        # stacklevel=2 attributes the warning to the code instantiating the
        # deprecated class rather than to this helper
        warnings.warn(msg, DeprecationWarning, stacklevel=2)
        parent_new = super(base_class, cls).__new__
        if parent_new is object.__new__:
            # object.__new__ raises TypeError when it receives constructor
            # arguments and __new__ is overridden (as it is here); the
            # arguments are still passed on to __init__ by type.__call__
            return parent_new(cls)
        return parent_new(cls, *args, **kwargs)

    # three-argument version of "type" uses
    # (name, tuple-of-bases, dict-of-attributes)
    klazz = type(deprecated_name, (base_class,), {"__new__": new_meth})

    # this ends up adding the class to the "faiss" namespace, in a way that it
    # is available both through "import faiss" and "from faiss import *"
    globals()[deprecated_name] = klazz
###########################################
# numpy array / std::vector conversions
###########################################
# item size, in bytes, of the array module's 'l' (signed long) typecode
sizeof_long = array.array('l').itemsize

# legacy vector-name prefix -> prefix of the class that replaces it
deprecated_name_map = dict(
    Float='Float32',
    Double='Float64',
    Char='Int8',
    Int='Int32',
    # "Long" follows the width of the platform's long
    Long='Int64' if sizeof_long != 4 else 'Int32',
    LongLong='Int64',
    Byte='UInt8',
    # previously misspelled variant
    Uint64='UInt64',
)
# register a deprecated "<old>Vector" alias for every renamed prefix
for depr_prefix, base_prefix in deprecated_name_map.items():
    _make_deprecated_swig_class(f"{depr_prefix}Vector", f"{base_prefix}Vector")

    # same for the three legacy *VectorVector classes
    if depr_prefix in ('Float', 'Long', 'Byte'):
        _make_deprecated_swig_class(
            f"{depr_prefix}VectorVector", f"{base_prefix}VectorVector")
# mapping from vector names in swigfaiss.swig and the numpy dtype names
# TODO: once deprecated classes are removed, remove the dict and just use .lower() below
# every current prefix maps to the numpy dtype name spelled the same way in
# lower case
vector_name_map = {
    prefix: prefix.lower()
    for prefix in (
        'Float32',
        'Float64',
        'Int8',
        'Int16',
        'Int32',
        'Int64',
        'UInt8',
        'UInt16',
        'UInt32',
        'UInt64',
    )
}
# deprecated prefixes resolve to the dtype of the class that replaces them
vector_name_map.update(
    (old, new.lower()) for old, new in deprecated_name_map.items()
)
def vector_to_array(v):
    """
    Convert a C++ vector to a numpy array.

    Parameters
    ----------
    v : SWIG-wrapped vector or AlignedTable
        The class name must either start with "AlignedTable" or end with
        "Vector" with a prefix listed in ``vector_name_map``.

    Returns
    -------
    numpy.ndarray
        Newly allocated 1D array of ``v.size()`` elements.
    """
    cls_name = v.__class__.__name__

    # AlignedTable objects are handled by their dedicated converter
    if cls_name.startswith('AlignedTable'):
        return AlignedTable_to_array(v)

    assert cls_name.endswith('Vector')
    # drop the trailing "Vector" to obtain the element-type prefix
    prefix = cls_name[:-len('Vector')]
    elem_dtype = np.dtype(vector_name_map[prefix])

    out = np.empty(v.size(), dtype=elem_dtype)
    # nothing to copy for an empty vector
    if v.size() > 0:
        memcpy(swig_ptr(out), v.data(), out.nbytes)
    return out
def vector_float_to_array(v):
    """ convert a C++ vector to a numpy array (delegates to vector_to_array) """
    converted = vector_to_array(v)
    return converted
def copy_array_to_vector(a, v):
    """
    Copy a numpy array into a C++ vector.

    Parameters
    ----------
    a : numpy.ndarray
        1D array whose dtype matches the element type of ``v``.
    v : SWIG-wrapped vector
        Destination; it is resized to ``len(a)`` before the copy.

    Returns
    -------
    None
    """
    # unpacking fails unless the array has exactly one dimension
    (n,) = a.shape

    cls_name = v.__class__.__name__
    assert cls_name.endswith('Vector')
    expected_dtype = np.dtype(vector_name_map[cls_name[:-len('Vector')]])
    assert expected_dtype == a.dtype, (
        'cannot copy a %s array to a %s (should be %s)' % (
            a.dtype, cls_name, expected_dtype))

    v.resize(n)
    if n <= 0:
        # empty array: the resize is all there is to do
        return
    memcpy(v.data(), swig_ptr(a), a.nbytes)
# same for AlignedTable
def copy_array_to_AlignedTable(a, v):
    """
    Copy a numpy array into an AlignedTable.

    Parameters
    ----------
    a : numpy.ndarray
        1D array whose item size equals ``v.itemsize()``.
    v : AlignedTable object
        Destination; it is resized to ``len(a)`` before the copy.

    Returns
    -------
    None
    """
    # unpacking fails unless the array has exactly one dimension
    (n,) = a.shape
    # TODO check class name
    assert v.itemsize() == a.itemsize

    v.resize(n)
    if n <= 0:
        # empty array: the resize is all there is to do
        return
    memcpy(v.get(), swig_ptr(a), a.nbytes)
def array_to_AlignedTable(a):
    """
    Build an AlignedTable holding a copy of a numpy array.

    Parameters
    ----------
    a : numpy.ndarray
        1D array of dtype uint8 or uint16.

    Returns
    -------
    AlignedTableUint8 or AlignedTableUint16
        Table of ``a.size`` elements filled with the contents of ``a``.

    Raises
    ------
    AssertionError
        If the dtype of ``a`` is neither uint8 nor uint16.
    """
    if a.dtype == 'uint16':
        v = AlignedTableUint16(a.size)
    elif a.dtype == 'uint8':
        v = AlignedTableUint8(a.size)
    else:
        # raised explicitly: an "assert False" is removed under "python -O",
        # which would leave v unbound and fail later with an unrelated error
        raise AssertionError(
            'cannot convert a %s array to an AlignedTable '
            '(should be uint8 or uint16)' % a.dtype)
    copy_array_to_AlignedTable(a, v)
    return v
def AlignedTable_to_array(v):
    """
    Convert an AlignedTable to a numpy array.

    Parameters
    ----------
    v : AlignedTable object
        The part of its class name after "AlignedTable", lower-cased, is
        used as the numpy dtype name.

    Returns
    -------
    numpy.ndarray
        Newly allocated 1D array of ``v.size()`` elements.
    """
    cls_name = v.__class__.__name__
    assert cls_name.startswith('AlignedTable')

    # e.g. "AlignedTableUint16" -> "uint16"
    dtype_name = cls_name[len('AlignedTable'):].lower()
    out = np.empty(v.size(), dtype=dtype_name)
    # nothing to copy for an empty table
    if out.size > 0:
        memcpy(swig_ptr(out), v.data(), out.nbytes)
    return out
|