File size: 3,384 Bytes
224a33f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from pathlib import Path

import numpy as np

from esm.utils.types import PathLike


class LSHTable:
    def __init__(self, n_bits: int, dim: int, hyperplanes: np.ndarray | None = None):
        if hyperplanes is None:
            hyperplanes = np.random.randn(n_bits, dim)
            hyperplanes = hyperplanes / np.linalg.norm(
                hyperplanes, axis=-1, keepdims=True
            )
        else:
            assert hyperplanes.shape == (n_bits, dim), (
                hyperplanes.shape,
                (n_bits, dim),
            )
        assert hyperplanes is not None
        self.hyperplanes: np.ndarray = hyperplanes
        self.values = 1 << np.arange(n_bits)

    def __call__(self, array, tokenize: bool = True):
        similarity = self.hyperplanes @ array.T
        bits = np.where(similarity >= 0, 1, 0)
        if tokenize:
            tokens = bits.T @ self.values
            return tokens
        else:
            return bits.T


class LSHTokenized:
    """Multi-table random-hyperplane LSH producing one integer token per table.

    Holds ``num_tables`` independent ``LSHTable`` instances, each with ``n_bits``
    hyperplanes of width ``dim``. Hyperplanes are read from an ``.npz`` archive
    keyed ``"0"``, ``"1"``, ... (the layout written by ``write_hyperplanes``),
    or drawn at random when no file is given and ``allow_create_hyperplanes``
    is True.
    """

    def __init__(
        self,
        n_bits: int,
        dim: int,
        num_tables: int = 1,
        filepath: PathLike | None = None,
        allow_create_hyperplanes: bool = False,  # set this if you want the lsh to allow creation of hyperplanes
    ):
        """Build the tables, loading hyperplanes from ``filepath`` if given.

        Args:
            n_bits: hyperplanes (bits) per table.
            dim: width of the vectors being hashed.
            num_tables: number of independent tables.
            filepath: ``.npz`` archive holding one ``(n_bits, dim)`` array per
                table under keys ``"0"`` .. ``str(num_tables - 1)``.
            allow_create_hyperplanes: permit random hyperplanes when no
                ``filepath`` is supplied.

        Raises:
            FileNotFoundError: ``filepath`` is given but does not exist.
            RuntimeError: no ``filepath`` and ``allow_create_hyperplanes`` is False.
            AssertionError: the archive lacks an entry for some table, or an
                entry's shape is not ``(n_bits, dim)`` (checked by ``LSHTable``).
        """
        table_hyperplanes: dict[str, np.ndarray] | None = None
        if filepath is not None:
            filepath = Path(filepath)
            if not filepath.exists():
                raise FileNotFoundError(filepath)
            # np.load on an .npz returns an NpzFile that keeps the file open until
            # it is closed. Copy the needed arrays out inside a ``with`` block so
            # the handle is released deterministically instead of whenever the
            # NpzFile happens to be garbage-collected.
            with np.load(filepath) as archive:
                for i in range(num_tables):
                    assert str(i) in archive, f"Missing hyperplane for table {i}"
                table_hyperplanes = {
                    str(i): archive[str(i)] for i in range(num_tables)
                }
        elif not allow_create_hyperplanes:
            raise RuntimeError(
                "Not allowed to create hyperplanes but no filepath provided"
            )

        # Passing None lets LSHTable draw its own random hyperplanes.
        self.tables = [
            LSHTable(
                n_bits,
                dim,
                table_hyperplanes[str(i)] if table_hyperplanes is not None else None,
            )
            for i in range(num_tables)
        ]

    def write_hyperplanes(self, filepath: PathLike):
        """Save every table's hyperplanes to one ``.npz`` archive keyed by table index.

        NOTE(review): ``np.savez`` appends ``.npz`` to names lacking it, so pass a
        path with that suffix if it will be reloaded via ``filepath`` verbatim.
        """
        hyperplanes: dict[str, np.ndarray] = {  # type: ignore
            str(i): table.hyperplanes for i, table in enumerate(self.tables)
        }
        np.savez(filepath, **hyperplanes)

    def __call__(self, array):
        """Return each table's tokens for ``array``, stacked along axis 1.

        For a presumed ``(N, dim)`` input this gives ``(N, num_tables)`` tokens
        (TODO confirm caller shapes).
        """
        tokens = np.stack([table(array) for table in self.tables], 1)
        return tokens


class LSHBitstream:
    """Single-table random-hyperplane LSH that returns raw sign bits.

    Wraps one ``LSHTable`` and always calls it with ``tokenize=False``, so the
    result is the 0/1 bit pattern rather than a packed integer. Hyperplanes are
    loaded with ``np.load`` from ``filepath`` (e.g. a file produced by
    ``write_hyperplanes``) or, when permitted, generated at random.
    """

    def __init__(
        self,
        n_bits: int,
        dim: int,
        filepath: PathLike | None = None,
        allow_create_hyperplanes: bool = False,  # set this if you want the lsh to allow creation of hyperplanes
    ):
        """Build the single table.

        Args:
            n_bits: number of hyperplanes (bits per vector).
            dim: width of the vectors being hashed.
            filepath: file whose ``np.load`` result is used as the hyperplanes.
            allow_create_hyperplanes: permit random hyperplanes when no
                ``filepath`` is supplied.

        Raises:
            FileNotFoundError: ``filepath`` is given but does not exist.
            RuntimeError: no ``filepath`` and ``allow_create_hyperplanes`` is False.
        """
        loaded = None
        if filepath is None:
            # Without a file we may only continue if random creation is allowed.
            if not allow_create_hyperplanes:
                raise RuntimeError(
                    "Not allowed to create hyperplanes but no filepath provided"
                )
        else:
            path = Path(filepath)
            if not path.exists():
                raise FileNotFoundError(path)
            loaded = np.load(path)

        # A None argument makes LSHTable draw fresh random hyperplanes.
        self.table = LSHTable(n_bits, dim, loaded)

    def write_hyperplanes(self, filepath: PathLike):
        """Persist this table's hyperplane matrix via ``np.save``."""
        matrix = self.table.hyperplanes
        np.save(filepath, matrix)

    def __call__(self, array):
        """Return the 0/1 hyperplane bits for ``array`` without integer packing."""
        bitstream = self.table(array, tokenize=False)
        return bitstream