Spaces:
Running
Running
File size: 3,540 Bytes
db66d62 effda21 db66d62 effda21 db66d62 effda21 db66d62 effda21 db66d62 effda21 db66d62 effda21 db66d62 effda21 db66d62 effda21 db66d62 effda21 db66d62 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
from pathlib import Path

import click
import faiss
import h5py
import numpy as np
# Key types that can be indexed; each selects the `encoded_{key_type}_feature`
# dataset in the feature files. The CLI value "all" expands to this list.
ALL_KEY_TYPES = ["dna", "image"]

# faiss index classes that process() knows how to build.
ALL_INDEX_TYPES = [
    "IndexFlatIP",
    "IndexFlatL2",
    "IndexIVFFlat",
    "IndexHNSWFlat",
    "IndexLSH",
]

# Vector dimensionality passed as `d` to every faiss index; it must match the
# width of the stored feature arrays.
EMBEDDING_SIZE = 768
# Feature splits, in the order their vectors are added to every index. faiss numbers
# vectors sequentially as they are added, so this order fixes the id ranges
# (all_keys first, then seen_test, unseen_test, seen_val, unseen_val).
_SPLITS = ("all_keys", "seen_test", "unseen_test", "seen_val", "unseen_val")


def _load_split(input_dir: Path, split: str, key_type: str):
    """Read `encoded_{key_type}_feature` from one split's HDF5 file into memory."""
    path = input_dir / f"extracted_features_of_{split}.hdf5"
    # The context manager closes the file handle; `[:]` copies the dataset into
    # an in-memory array before the file is closed.
    with h5py.File(path, "r", libver="latest") as h5:
        return h5[f"encoded_{key_type}_feature"][:]


def _build_index(index_type: str, splits):
    """Create an empty faiss index of `index_type`, training it when required.

    `splits` (the loaded feature arrays) is only consumed by IndexIVFFlat, which
    needs training data before vectors can be added.

    Raises:
        ValueError: if `index_type` is not one of the supported names.
    """
    if index_type == "IndexFlatIP":
        return faiss.IndexFlatIP(EMBEDDING_SIZE)
    if index_type == "IndexFlatL2":
        return faiss.IndexFlatL2(EMBEDDING_SIZE)
    if index_type == "IndexIVFFlat":
        # 128 inverted lists over an inner-product coarse quantizer.
        # NOTE(review): IndexIVFFlat defaults to METRIC_L2 for its own distances,
        # while the quantizer uses inner product; confirm this mix is intended.
        quantizer = faiss.IndexFlatIP(EMBEDDING_SIZE)
        index = faiss.IndexIVFFlat(quantizer, EMBEDDING_SIZE, 128)
        # Train once on all splits together. One train() call per split does not
        # accumulate: after the first call the quantizer already holds its
        # centroids, faiss skips re-training, and later splits were ignored.
        index.train(np.concatenate(splits, axis=0))
        return index
    if index_type == "IndexHNSWFlat":
        # 16: connections for each vertex. efSearch: depth of search during search.
        # efConstruction: depth of search during build (must be set before add()).
        index = faiss.IndexHNSWFlat(EMBEDDING_SIZE, 16)
        index.hnsw.efSearch = 32
        index.hnsw.efConstruction = 64
        return index
    if index_type == "IndexLSH":
        # nbits = 2 * EMBEDDING_SIZE.
        return faiss.IndexLSH(EMBEDDING_SIZE, EMBEDDING_SIZE * 2)
    raise ValueError(f"Index type {index_type} is not supported")


def process(input: Path, output: Path, key_type: str, index_type: str):
    """Build one faiss index over every feature split and write it to disk.

    Args:
        input: Directory containing the `extracted_features_of_<split>.hdf5` files.
        output: Directory the `.index` file is written to. It is not created
            here; the caller must create it.
        key_type: Selects the `encoded_{key_type}_feature` dataset to index.
        index_type: Name of the faiss index class to build (see ALL_INDEX_TYPES).

    Raises:
        ValueError: if `index_type` is not supported.
    """
    splits = [_load_split(input, split, key_type) for split in _SPLITS]
    index = _build_index(index_type, splits)
    for embeddings in splits:
        index.add(embeddings)
    index_path = output / f"bioscan_5m_{key_type}_{index_type}.index"
    faiss.write_index(index, str(index_path))
    print("Saved index to", index_path)
@click.command()
@click.option(
    "--input",
    # Validate up front so a missing directory is reported as a usage error
    # before the output directory is created.
    type=click.Path(exists=True, file_okay=False, path_type=Path),
    default="bioscan-clip-scripts/extracted_features",
    help="Path to extracted features",
)
@click.option(
    "--output",
    type=click.Path(file_okay=False, path_type=Path),
    default="bioscan-clip-scripts/index",
    help="Path to save the index",
)
@click.option(
    "--key-type",
    "key_type",
    type=click.Choice(["all", *ALL_KEY_TYPES]),
    default="all",
    help="Type of key to use",
)
@click.option(
    "--index-type",
    "index_type",
    type=click.Choice(["all", *ALL_INDEX_TYPES]),
    default="all",
    help="Type of index to use",
)
def main(input, output, key_type, index_type):
    """Build faiss indexes from the extracted feature files.

    One index is written per (key type, index type) pair; passing "all" for
    either option iterates over every supported value.
    """
    # process() writes into the output directory but does not create it.
    output.mkdir(parents=True, exist_ok=True)
    key_types = ALL_KEY_TYPES if key_type == "all" else [key_type]
    index_types = ALL_INDEX_TYPES if index_type == "all" else [index_type]
    # Separate loop names so the option values are not rebound mid-loop.
    for current_key_type in key_types:
        for current_index_type in index_types:
            process(input, output, current_key_type, current_index_type)


if __name__ == "__main__":
    main()
|