File size: 3,540 Bytes
db66d62
 
 
 
 
 
 
 
 
 
 
effda21
db66d62
effda21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db66d62
 
 
effda21
db66d62
effda21
db66d62
 
effda21
 
 
 
 
 
 
db66d62
 
 
effda21
db66d62
 
 
 
effda21
db66d62
 
 
effda21
 
 
 
 
db66d62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
effda21
db66d62
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from pathlib import Path

import click
import faiss
import h5py

# Key types selectable from the command line; each one names the HDF5 dataset
# "encoded_<key_type>_feature" that is read from the feature files.
ALL_KEY_TYPES = ["dna", "image"]
# Names of the FAISS index classes this script knows how to build.
ALL_INDEX_TYPES = ["IndexFlatIP", "IndexFlatL2", "IndexIVFFlat", "IndexHNSWFlat", "IndexLSH"]
# Vector dimensionality passed to every FAISS index constructor below.
EMBEDDING_SIZE = 768


# Splits whose vectors are appended to the index after the "all_keys" vectors,
# in this order. Each split lives in "extracted_features_of_<split>.hdf5".
_EXTRA_SPLITS = ("seen_test", "unseen_test", "seen_val", "unseen_val")


def _load_features(input, split, key_type):
    """Read the ``encoded_<key_type>_feature`` dataset of one split into memory.

    The file is opened in a ``with`` block so the HDF5 handle is closed as soon
    as the dataset has been copied out with ``[:]``.
    """
    path = input / f"extracted_features_of_{split}.hdf5"
    with h5py.File(path, "r", libver="latest") as features:
        return features[f"encoded_{key_type}_feature"][:]


def process(input: Path, output: Path, key_type: str, index_type: str) -> None:
    """Build one FAISS index over all feature splits and write it to disk.

    Vectors are added in the order: all_keys, seen_test, unseen_test,
    seen_val, unseen_val. The index is saved as
    ``bioscan_5m_<key_type>_<index_type>.index`` inside ``output``.

    Args:
        input: Directory containing the ``extracted_features_of_*.hdf5`` files.
        output: Directory the index file is written to (must already exist).
        key_type: Selects the dataset ``encoded_<key_type>_feature``.
        index_type: Name of the FAISS index class to build.

    Raises:
        ValueError: If ``index_type`` is not one of the supported index names.
    """
    # Only the key embeddings are needed up front (for IVF training); the other
    # splits are loaded one at a time further down to keep peak memory low.
    all_keys = _load_features(input, "all_keys", key_type)

    if index_type == "IndexFlatIP":
        test_index = faiss.IndexFlatIP(EMBEDDING_SIZE)
    elif index_type == "IndexFlatL2":
        test_index = faiss.IndexFlatL2(EMBEDDING_SIZE)
    elif index_type == "IndexIVFFlat":
        # NOTE(review): the coarse quantizer uses inner product while the
        # IndexIVFFlat itself is built without an explicit metric (FAISS
        # defaults to L2, as far as I know) -- confirm this mix is intended.
        quantizer = faiss.IndexFlatIP(EMBEDDING_SIZE)
        test_index = faiss.IndexIVFFlat(quantizer, EMBEDDING_SIZE, 128)
        # Train once. The coarse quantizer is fitted by the first train() call;
        # the former extra train() calls on the other splits found it already
        # trained and therefore did not change the index.
        test_index.train(all_keys)
    elif index_type == "IndexHNSWFlat":
        # 16: connections for each vertex. efSearch: depth of search during
        # search. efConstruction: depth of search during build.
        test_index = faiss.IndexHNSWFlat(EMBEDDING_SIZE, 16)
        test_index.hnsw.efSearch = 32
        test_index.hnsw.efConstruction = 64
    elif index_type == "IndexLSH":
        test_index = faiss.IndexLSH(EMBEDDING_SIZE, EMBEDDING_SIZE * 2)
    else:
        raise ValueError(f"Index type {index_type} is not supported")

    test_index.add(all_keys)
    del all_keys  # release the largest array before loading the next split
    for split in _EXTRA_SPLITS:
        test_index.add(_load_features(input, split, key_type))

    index_path = output / f"bioscan_5m_{key_type}_{index_type}.index"
    faiss.write_index(test_index, str(index_path))
    print("Saved index to", index_path)


# Command-line entry point: builds and saves one index for every requested
# (key type, index type) combination. The value "all" for either option expands
# to every entry of the matching ALL_* list.
# This description is deliberately kept in comments rather than a docstring,
# because click would surface a docstring in the command's --help text.
@click.command()
@click.option(
    "--input",
    type=click.Path(path_type=Path),
    default="bioscan-clip-scripts/extracted_features",
    help="Path to extracted features",
)
@click.option(
    "--output",
    type=click.Path(path_type=Path),
    default="bioscan-clip-scripts/index",
    help="Path to save the index",
)
@click.option(
    "--key-type",
    "key_type",
    type=click.Choice(["all", *ALL_KEY_TYPES]),
    default="all",
    help="Type of key to use",
)
@click.option(
    "--index-type",
    "index_type",
    type=click.Choice(["all", *ALL_INDEX_TYPES]),
    default="all",
    help="Type of index to use",
)
def main(input, output, key_type, index_type):
    # Expand the "all" shortcut into the concrete lists to iterate over.
    selected_keys = ALL_KEY_TYPES if key_type == "all" else [key_type]
    selected_indexes = ALL_INDEX_TYPES if index_type == "all" else [index_type]

    # The output directory is created up front so process() can write into it.
    output.mkdir(parents=True, exist_ok=True)

    # Key type is the outer loop, index type the inner one.
    jobs = [(chosen_key, chosen_index) for chosen_key in selected_keys for chosen_index in selected_indexes]
    for chosen_key, chosen_index in jobs:
        process(input, output, chosen_key, chosen_index)


# Run the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()