"""Build FAISS indexes over BIOSCAN-CLIP extracted features.

For each requested key type and index type, the ``encoded_{key_type}_feature``
vectors of every feature split are stacked, indexed, and written to
``<output>/bioscan_5m_{key_type}_{index_type}.index``.
"""

from contextlib import ExitStack
from pathlib import Path
from typing import Optional

import click
import faiss
import h5py
import numpy as np

ALL_KEY_TYPES = ["dna", "image"]
ALL_INDEX_TYPES = ["IndexFlatIP", "IndexFlatL2", "IndexIVFFlat", "IndexHNSWFlat", "IndexLSH"]
EMBEDDING_SIZE = 768

# Feature splits in the order their vectors are added to the index; this order
# fixes the sequential ids FAISS assigns to the stored vectors.
SPLITS = ["all_keys", "seen_test", "unseen_test", "seen_val", "unseen_val"]


def _load_embeddings(input_dir: Path, key_type: str) -> np.ndarray:
    """Stack the ``encoded_{key_type}_feature`` rows of every split into one float32 array.

    Rows appear in ``SPLITS`` order. The output array is allocated once and each
    split is read straight into its slice, so peak memory stays at one copy of
    the data. All HDF5 files are closed before returning.

    Raises:
        ValueError: if a dataset is not 2-D with ``EMBEDDING_SIZE`` columns.
    """
    dataset_name = f"encoded_{key_type}_feature"
    with ExitStack() as stack:
        datasets = []
        for split in SPLITS:
            path = input_dir / f"extracted_features_of_{split}.hdf5"
            h5_file = stack.enter_context(h5py.File(path, "r", libver="latest"))
            dataset = h5_file[dataset_name]
            if dataset.ndim != 2 or dataset.shape[1] != EMBEDDING_SIZE:
                raise ValueError(
                    f"{path}:{dataset_name} has shape {dataset.shape}, "
                    f"expected (N, {EMBEDDING_SIZE})"
                )
            datasets.append(dataset)

        total_rows = sum(dataset.shape[0] for dataset in datasets)
        # float32 is the element type FAISS indexes operate on.
        embeddings = np.empty((total_rows, EMBEDDING_SIZE), dtype=np.float32)
        offset = 0
        for dataset in datasets:
            n_rows = dataset.shape[0]
            if n_rows:  # nothing to read for an empty split
                dataset.read_direct(embeddings, dest_sel=np.s_[offset : offset + n_rows])
            offset += n_rows
    return embeddings


def _create_index(index_type: str, dim: int) -> faiss.Index:
    """Return an empty FAISS index of ``index_type`` for ``dim``-dimensional vectors.

    Raises:
        ValueError: if ``index_type`` is not supported.
    """
    if index_type == "IndexFlatIP":
        return faiss.IndexFlatIP(dim)
    if index_type == "IndexFlatL2":
        return faiss.IndexFlatL2(dim)
    if index_type == "IndexIVFFlat":
        # 128 inverted lists; the coarse centroids are learned by index.train().
        # NOTE(review): the quantizer ranks by inner product while IndexIVFFlat
        # defaults to METRIC_L2; these agree only for L2-normalised features —
        # confirm which metric is intended.
        quantizer = faiss.IndexFlatIP(dim)
        # faiss' Python wrapper keeps a reference to `quantizer` inside the IVF index.
        return faiss.IndexIVFFlat(quantizer, dim, 128)
    if index_type == "IndexHNSWFlat":
        # 16: connections for each vertex. efSearch: depth of search during search.
        # efConstruction: depth of search during build (so it is set before add()).
        index = faiss.IndexHNSWFlat(dim, 16)
        index.hnsw.efSearch = 32
        index.hnsw.efConstruction = 64
        return index
    if index_type == "IndexLSH":
        # 2 * dim hash bits per vector.
        return faiss.IndexLSH(dim, dim * 2)
    raise ValueError(f"Index type {index_type} is not supported")


def process(
    input: Path,
    output: Path,
    key_type: str,
    index_type: str,
    *,
    embeddings: Optional[np.ndarray] = None,
):
    """Build one ``index_type`` index over all ``key_type`` feature splits and save it.

    Args:
        input: Directory containing ``extracted_features_of_<split>.hdf5`` files.
        output: Existing directory the ``.index`` file is written to.
        key_type: Feature family selecting the ``encoded_{key_type}_feature`` dataset.
        index_type: One of ``ALL_INDEX_TYPES``.
        embeddings: Optional pre-loaded result of ``_load_embeddings(input, key_type)``;
            pass it to reuse one load across several index types.

    Raises:
        ValueError: for an unsupported ``index_type`` or unexpected feature shape.
    """
    # Validate the index type before doing any (expensive) loading.
    index = _create_index(index_type, EMBEDDING_SIZE)
    if embeddings is None:
        embeddings = _load_embeddings(input, key_type)

    # Train once on every split. Calling train() per split does not work for
    # IndexIVFFlat: once the coarse quantizer holds its centroids, later train()
    # calls skip it, so only the first split would ever be used.
    if not index.is_trained:
        index.train(embeddings)
    index.add(embeddings)

    index_path = output / f"bioscan_5m_{key_type}_{index_type}.index"
    faiss.write_index(index, str(index_path))
    print("Saved index to", index_path)


@click.command()
@click.option(
    "--input",
    type=click.Path(path_type=Path),
    default="bioscan-clip-scripts/extracted_features",
    help="Path to extracted features",
)
@click.option(
    "--output", type=click.Path(path_type=Path), default="bioscan-clip-scripts/index", help="Path to save the index"
)
@click.option(
    "--key-type", "key_type", type=click.Choice(["all", *ALL_KEY_TYPES]), default="all", help="Type of key to use"
)
@click.option(
    "--index-type",
    "index_type",
    type=click.Choice(["all", *ALL_INDEX_TYPES]),
    default="all",
    help="Type of index to use",
)
def main(input, output, key_type, index_type):
    """Build and save a FAISS index for every requested (key type, index type) pair."""
    output.mkdir(parents=True, exist_ok=True)
    key_types = ALL_KEY_TYPES if key_type == "all" else [key_type]
    index_types = ALL_INDEX_TYPES if index_type == "all" else [index_type]

    for current_key in key_types:
        # Load each key type's features once and reuse them for every index type.
        embeddings = _load_embeddings(input, current_key)
        for current_index in index_types:
            process(input, output, current_key, current_index, embeddings=embeddings)
        # Drop the array before loading the next key type to keep peak memory low.
        del embeddings


if __name__ == "__main__":
    main()