import pickle
import random

import gradio as gr
import numpy as np

from data import load_indexes_local, load_indexes_hf, load_index_pickle


def getRandID():
    """Return a randomly chosen sample ID together with its index position.

    Reads the module-level ``index_to_id_dict`` that is populated when the
    UI is built below.

    Returns:
        A ``(sample_id, position)`` tuple, matching the two output textboxes
        wired to the "Get Random ID" button.
    """
    # NOTE(review): assumes the keys of index_to_id_dict are the contiguous
    # integers 0..len-1 -- confirm against how the pickle is produced.
    indx = random.randrange(0, len(index_to_id_dict))
    return index_to_id_dict[indx], indx


def get_image_index(indexType):
    """Return the loaded image index registered under ``indexType``.

    Raises:
        KeyError: If no image index was loaded under that name.
    """
    try:
        return image_indexes[indexType]
    except KeyError as err:
        raise KeyError(f"Tried to load an image index that is not supported: {indexType}") from err


def get_dna_index(indexType):
    """Return the loaded DNA index registered under ``indexType``.

    Raises:
        KeyError: If no DNA index was loaded under that name.
    """
    try:
        return dna_indexes[indexType]
    except KeyError as err:
        raise KeyError(f"Tried to load a DNA index that is not supported: {indexType}") from err


def searchEmbeddings(id, key_type, query_type, index_type, num_results: int = 10):
    """Find the sample IDs whose embeddings are closest to a given sample.

    The query vector is reconstructed from the ``query_type`` index at the
    position of ``id`` and then searched for in the ``key_type`` index.

    Args:
        id: Sample ID to use as the query; must be a key of
            ``id_to_index_dict``. (The name shadows the builtin but is kept
            so the function's interface stays unchanged.)
        key_type: Modality of the index that is searched: "Image" or "DNA".
        query_type: Modality the query embedding is taken from: "Image" or
            "DNA".
        index_type: Name of the index variant, e.g. "FlatIP(default)".
        num_results: Number of nearest neighbours to request (>= 1).

    Returns:
        A list of the closest sample IDs, in the order the index returned
        them. It may hold fewer than ``num_results`` entries when the index
        cannot supply that many neighbours.

    Raises:
        KeyError: If ``index_type`` is not loaded or ``id`` is unknown.
        ValueError: If ``query_type``/``key_type`` is invalid or
            ``num_results`` is not a positive integer.
    """
    image_index = get_image_index(index_type)
    dna_index = get_dna_index(index_type)

    # Pick the index that holds the query embedding.
    if query_type == "Image":
        query_index = image_index
    elif query_type == "DNA":
        query_index = dna_index
    else:
        raise ValueError(f"Invalid query type: {query_type}")

    try:
        row = id_to_index_dict[id]
    except KeyError:
        # Raise the same exception type as before, but say what went wrong.
        raise KeyError(f"Unknown sample ID: {id!r}") from None

    # The search call below is given a 2-D float32 array holding one query row.
    query = np.expand_dims(query_index.reconstruct(row).astype(np.float32), axis=0)

    # Pick the index that is searched.
    if key_type == "Image":
        index = image_index
    elif key_type == "DNA":
        index = dna_index
    else:
        raise ValueError(f"Invalid key type: {key_type}")

    # A cleared gr.Number field arrives as None; zero/negative k is meaningless.
    if num_results is None or int(num_results) < 1:
        raise ValueError(f"Number of results must be a positive integer, got: {num_results}")

    _, labels = index.search(query, int(num_results))

    # FAISS pads the label row with -1 when fewer than k neighbours exist;
    # those entries have no ID and previously caused a KeyError.
    return [index_to_id_dict[int(label)] for label in labels[0] if label >= 0]


with gr.Blocks() as demo:
    # for hf: change all file paths, indx_to_id_dict as well

    # load indexes
    image_indexes = load_indexes_hf(
        {"FlatIP(default)": "bioscan_5m_image_IndexFlatIP.index"}, repo_name="bioscan-ml/bioscan-clibd"
    )
    dna_indexes = load_indexes_hf(
        {"FlatIP(default)": "bioscan_5m_dna_IndexFlatIP.index"}, repo_name="bioscan-ml/bioscan-clibd"
    )
    index_to_id_dict = load_index_pickle("big_indx_to_id_dict.pickle", repo_name="bioscan-ml/bioscan-clibd")
    # Reverse mapping used to locate a sample's position from its ID.
    id_to_index_dict = {v: k for k, v in index_to_id_dict.items()}

    with gr.Column():
        with gr.Row():
            with gr.Column():
                rand_id = gr.Textbox(label="Random ID:")
                rand_id_indx = gr.Textbox(label="Index:")
                id_btn = gr.Button("Get Random ID")
            with gr.Column():
                query_type = gr.Radio(choices=["Image", "DNA"], label="Query:", value="Image")
                key_type = gr.Radio(choices=["Image", "DNA"], label="Key:", value="Image")

                # NOTE(review): only "FlatIP(default)" is requested from
                # load_indexes_hf above; the other choices presumably raise
                # KeyError in get_image_index/get_dna_index -- confirm and
                # either load them or drop them from this list.
                index_type = gr.Radio(
                    choices=["FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat", "LSH"],
                    label="Index:",
                    value="FlatIP(default)",
                )
                num_results = gr.Number(label="Number of Results:", value=10, precision=0)

        process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
        process_id_list = gr.Textbox(label="Closest matches:")
        search_btn = gr.Button("Search")

    id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])
    search_btn.click(
        fn=searchEmbeddings,
        inputs=[process_id, key_type, query_type, index_type, num_results],
        outputs=[process_id_list],
    )

demo.launch()