from pathlib import Path import numpy as np from esm.utils.types import PathLike class LSHTable: def __init__(self, n_bits: int, dim: int, hyperplanes: np.ndarray | None = None): if hyperplanes is None: hyperplanes = np.random.randn(n_bits, dim) hyperplanes = hyperplanes / np.linalg.norm( hyperplanes, axis=-1, keepdims=True ) else: assert hyperplanes.shape == (n_bits, dim), ( hyperplanes.shape, (n_bits, dim), ) assert hyperplanes is not None self.hyperplanes: np.ndarray = hyperplanes self.values = 1 << np.arange(n_bits) def __call__(self, array, tokenize: bool = True): similarity = self.hyperplanes @ array.T bits = np.where(similarity >= 0, 1, 0) if tokenize: tokens = bits.T @ self.values return tokens else: return bits.T class LSHTokenized: def __init__( self, n_bits: int, dim: int, num_tables: int = 1, filepath: PathLike | None = None, allow_create_hyperplanes: bool = False, # set this if you want the lsh to allow creation of hyperplanes ): table_hyperplanes = None if filepath is not None: filepath = Path(filepath) if not filepath.exists(): raise FileNotFoundError(filepath) table_hyperplanes = np.load(filepath) # type: ignore for i in range(num_tables): assert str(i) in table_hyperplanes, f"Missing hyperplane for table {i}" elif not allow_create_hyperplanes: raise RuntimeError( "Not allowed to create hyperplanes but no filepath provided" ) self.tables = [ LSHTable( n_bits, dim, table_hyperplanes[str(i)] if table_hyperplanes is not None else None, ) for i in range(num_tables) ] def write_hyperplanes(self, filepath: PathLike): hyperplanes: dict[str, np.ndarray] = { # type: ignore str(i): table.hyperplanes for i, table in enumerate(self.tables) } np.savez(filepath, **hyperplanes) def __call__(self, array): tokens = np.stack([table(array) for table in self.tables], 1) return tokens class LSHBitstream: def __init__( self, n_bits: int, dim: int, filepath: PathLike | None = None, allow_create_hyperplanes: bool = False, # set this if you want the lsh to allow creation of hyperplanes ): table_hyperplanes = None if filepath is not None: filepath = Path(filepath) if not filepath.exists(): raise FileNotFoundError(filepath) table_hyperplanes = np.load(filepath) elif not allow_create_hyperplanes: raise RuntimeError( "Not allowed to create hyperplanes but no filepath provided" ) self.table = LSHTable( n_bits, dim, table_hyperplanes if table_hyperplanes is not None else None ) def write_hyperplanes(self, filepath: PathLike): np.save(filepath, self.table.hyperplanes) def __call__(self, array): return self.table(array, tokenize=False)