# Page header left over from the web scrape (not part of the script): "Spaces: Running".
from pathlib import Path | |
import click | |
import faiss | |
import h5py | |
# Modalities whose embeddings can be indexed; each name selects the
# "encoded_<key_type>_feature" dataset inside the HDF5 feature files.
ALL_KEY_TYPES = ["dna", "image"]
# FAISS index classes this script knows how to build.
ALL_INDEX_TYPES = ["IndexFlatIP", "IndexFlatL2", "IndexIVFFlat", "IndexHNSWFlat", "IndexLSH"]
# Dimensionality passed to every FAISS index constructor.
EMBEDDING_SIZE = 768
# Feature files are named "extracted_features_of_<split>.hdf5". The order here
# is the order in which the splits are added to the index.
_SPLIT_NAMES = ("all_keys", "seen_test", "unseen_test", "seen_val", "unseen_val")


def _load_features(input: Path, split: str, key_type: str):
    """Read the ``encoded_<key_type>_feature`` dataset of one split into memory."""
    # The context manager closes the HDF5 handle as soon as the data is copied out.
    with h5py.File(input / f"extracted_features_of_{split}.hdf5", "r", libver="latest") as features:
        return features[f"encoded_{key_type}_feature"][:]


def _build_index(index_type: str):
    """Create an empty FAISS index of the requested type, or raise ValueError."""
    if index_type == "IndexFlatIP":
        return faiss.IndexFlatIP(EMBEDDING_SIZE)
    if index_type == "IndexFlatL2":
        return faiss.IndexFlatL2(EMBEDDING_SIZE)
    if index_type == "IndexIVFFlat":
        # NOTE(review): the coarse quantizer uses inner product while the IVF
        # index itself is built with its default metric; confirm this pairing
        # is intentional.
        quantizer = faiss.IndexFlatIP(EMBEDDING_SIZE)
        # 128: number of inverted lists (coarse clusters).
        return faiss.IndexIVFFlat(quantizer, EMBEDDING_SIZE, 128)
    if index_type == "IndexHNSWFlat":
        # 16: connections for each vertex. efSearch: depth of search during
        # search. efConstruction: depth of search during build.
        index = faiss.IndexHNSWFlat(EMBEDDING_SIZE, 16)
        index.hnsw.efSearch = 32
        index.hnsw.efConstruction = 64
        return index
    if index_type == "IndexLSH":
        return faiss.IndexLSH(EMBEDDING_SIZE, EMBEDDING_SIZE * 2)
    raise ValueError(f"Index type {index_type} is not supported")


def process(input: Path, output: Path, key_type: str, index_type: str):
    """Build one FAISS index over every split's embeddings and save it.

    Args:
        input: Directory holding ``extracted_features_of_<split>.hdf5`` for the
            splits all_keys, seen_test, unseen_test, seen_val and unseen_val.
        output: Directory that receives
            ``bioscan_5m_<key_type>_<index_type>.index``.
        key_type: Selects the ``encoded_<key_type>_feature`` dataset to index.
        index_type: Name of the FAISS index class to build.

    Raises:
        ValueError: If ``index_type`` is not one of the supported index types.
    """
    # Build the (still empty) index first so that an unsupported index type is
    # rejected before any feature file is read.
    test_index = _build_index(index_type)

    # Load and add one split at a time, so only one array is resident at once.
    for split in _SPLIT_NAMES:
        features = _load_features(input, split, key_type)
        # Only indexes that need training (IVF here) report is_trained == False,
        # and only until the first train() call, so training happens exactly
        # once, on the first split ("all_keys").
        if not test_index.is_trained:
            test_index.train(features)
        test_index.add(features)

    index_path = output / f"bioscan_5m_{key_type}_{index_type}.index"
    faiss.write_index(test_index, str(index_path))
    print("Saved index to", index_path)
def main(input, output, key_type, index_type):
    """Build and save an index for every requested key/index type combination.

    ``output`` is created if it does not exist. Passing ``"all"`` as
    ``key_type`` or ``index_type`` expands to ``ALL_KEY_TYPES`` or
    ``ALL_INDEX_TYPES`` respectively; any other value is used as the single
    type to process.
    """
    output.mkdir(parents=True, exist_ok=True)

    selected_keys = ALL_KEY_TYPES if key_type == "all" else [key_type]
    selected_indexes = ALL_INDEX_TYPES if index_type == "all" else [index_type]

    # Key types form the outer loop, index types the inner loop.
    for key in selected_keys:
        for kind in selected_indexes:
            process(input, output, key, kind)
if __name__ == "__main__":
    # main() needs four arguments, so calling it bare raises TypeError. Wrap it
    # in a click command that parses them from the command line; main() itself
    # stays a plain function that can still be called directly.
    cli = click.Command(
        name="build-faiss-index",
        callback=main,
        params=[
            click.Option(
                ["--input"],
                type=click.Path(exists=True, file_okay=False, path_type=Path),
                required=True,
                help="Directory containing the extracted_features_of_*.hdf5 files.",
            ),
            click.Option(
                ["--output"],
                type=click.Path(file_okay=False, path_type=Path),
                required=True,
                help="Directory the .index files are written to (created if missing).",
            ),
            click.Option(
                ["--key-type"],
                type=click.Choice(["all", *ALL_KEY_TYPES]),
                default="all",
                show_default=True,
                help="Embedding modality to index, or 'all' for every modality.",
            ),
            click.Option(
                ["--index-type"],
                type=click.Choice(["all", *ALL_INDEX_TYPES]),
                default="all",
                show_default=True,
                help="FAISS index class to build, or 'all' for every supported class.",
            ),
        ],
        help="Build FAISS indexes from extracted embedding features.",
    )
    cli()