Spaces:
Sleeping
Sleeping
from pathlib import Path | |
import numpy as np | |
from esm.utils.types import PathLike | |
class LSHTable: | |
def __init__(self, n_bits: int, dim: int, hyperplanes: np.ndarray | None = None): | |
if hyperplanes is None: | |
hyperplanes = np.random.randn(n_bits, dim) | |
hyperplanes = hyperplanes / np.linalg.norm( | |
hyperplanes, axis=-1, keepdims=True | |
) | |
else: | |
assert hyperplanes.shape == (n_bits, dim), ( | |
hyperplanes.shape, | |
(n_bits, dim), | |
) | |
assert hyperplanes is not None | |
self.hyperplanes: np.ndarray = hyperplanes | |
self.values = 1 << np.arange(n_bits) | |
def __call__(self, array, tokenize: bool = True): | |
similarity = self.hyperplanes @ array.T | |
bits = np.where(similarity >= 0, 1, 0) | |
if tokenize: | |
tokens = bits.T @ self.values | |
return tokens | |
else: | |
return bits.T | |
class LSHTokenized: | |
def __init__( | |
self, | |
n_bits: int, | |
dim: int, | |
num_tables: int = 1, | |
filepath: PathLike | None = None, | |
allow_create_hyperplanes: bool = False, # set this if you want the lsh to allow creation of hyperplanes | |
): | |
table_hyperplanes = None | |
if filepath is not None: | |
filepath = Path(filepath) | |
if not filepath.exists(): | |
raise FileNotFoundError(filepath) | |
table_hyperplanes = np.load(filepath) # type: ignore | |
for i in range(num_tables): | |
assert str(i) in table_hyperplanes, f"Missing hyperplane for table {i}" | |
elif not allow_create_hyperplanes: | |
raise RuntimeError( | |
"Not allowed to create hyperplanes but no filepath provided" | |
) | |
self.tables = [ | |
LSHTable( | |
n_bits, | |
dim, | |
table_hyperplanes[str(i)] if table_hyperplanes is not None else None, | |
) | |
for i in range(num_tables) | |
] | |
def write_hyperplanes(self, filepath: PathLike): | |
hyperplanes: dict[str, np.ndarray] = { # type: ignore | |
str(i): table.hyperplanes for i, table in enumerate(self.tables) | |
} | |
np.savez(filepath, **hyperplanes) | |
def __call__(self, array): | |
tokens = np.stack([table(array) for table in self.tables], 1) | |
return tokens | |
class LSHBitstream: | |
def __init__( | |
self, | |
n_bits: int, | |
dim: int, | |
filepath: PathLike | None = None, | |
allow_create_hyperplanes: bool = False, # set this if you want the lsh to allow creation of hyperplanes | |
): | |
table_hyperplanes = None | |
if filepath is not None: | |
filepath = Path(filepath) | |
if not filepath.exists(): | |
raise FileNotFoundError(filepath) | |
table_hyperplanes = np.load(filepath) | |
elif not allow_create_hyperplanes: | |
raise RuntimeError( | |
"Not allowed to create hyperplanes but no filepath provided" | |
) | |
self.table = LSHTable( | |
n_bits, dim, table_hyperplanes if table_hyperplanes is not None else None | |
) | |
def write_hyperplanes(self, filepath: PathLike): | |
np.save(filepath, self.table.hyperplanes) | |
def __call__(self, array): | |
return self.table(array, tokenize=False) | |