# Hugging Face Space app (status: Running)
import pickle | |
import random | |
import gradio as gr | |
import numpy as np | |
from data import load_indexes_local, load_indexes_hf, load_index_pickle | |
def getRandID():
    """Pick a uniformly random sample from the loaded index.

    Positions are drawn from ``range(len(index_to_id_dict))``, so this relies on
    the mapping being keyed by consecutive integer positions starting at 0.

    Returns:
        ``(sample_id, position)``: the sample ID stored at the drawn position,
        and the position itself.
    """
    position = random.randrange(0, len(index_to_id_dict))
    sample_id = index_to_id_dict[position]
    return sample_id, position
def _lookup_index(indexes, index_type, description):
    """Return ``indexes[index_type]``, raising a descriptive KeyError if it is missing."""
    try:
        return indexes[index_type]
    except KeyError:
        # `from None` drops the redundant "During handling of the above
        # exception..." traceback of the bare lookup failure.
        raise KeyError(f"Tried to load {description} index that is not supported: {index_type}") from None


def get_image_index(indexType):
    """Return the loaded image index registered under ``indexType``.

    Raises:
        KeyError: If no image index was loaded under that name.
    """
    return _lookup_index(image_indexes, indexType, "an image")


def get_dna_index(indexType):
    """Return the loaded DNA index registered under ``indexType``.

    Raises:
        KeyError: If no DNA index was loaded under that name.
    """
    return _lookup_index(dna_indexes, indexType, "a DNA")
def searchEmbeddings(id, key_type, query_type, index_type, num_results: int = 10):
    """Return the sample IDs closest to a given sample in embedding space.

    The stored embedding of sample ``id`` is reconstructed from the
    ``query_type`` index and used to search the ``key_type`` index. This also
    allows cross-modal (image <-> DNA) retrieval.

    Args:
        id: Sample ID to search for. The name shadows the builtin but is kept
            for interface compatibility. Surrounding whitespace is ignored.
        key_type: Modality to search in, ``"Image"`` or ``"DNA"``.
        query_type: Modality whose stored embedding is the query,
            ``"Image"`` or ``"DNA"``.
        index_type: Name of the loaded index to use, e.g. ``"FlatIP(default)"``.
        num_results: Maximum number of matches. It is coerced to ``int``
            because the UI number widget may deliver a float.

    Returns:
        List of sample IDs in the order the index returns them. It can be
        shorter than ``num_results`` when the index holds fewer vectors.

    Raises:
        KeyError: If ``index_type`` is not loaded or ``id`` is unknown.
        ValueError: If ``query_type``/``key_type`` is invalid or
            ``num_results`` is missing or below 1.
    """
    image_index = get_image_index(index_type)
    dna_index = get_dna_index(index_type)
    indexes_by_modality = {"Image": image_index, "DNA": dna_index}

    # Validate all cheap inputs before touching the index.
    if query_type not in indexes_by_modality:
        raise ValueError(f"Invalid query type: {query_type}")
    if key_type not in indexes_by_modality:
        raise ValueError(f"Invalid key type: {key_type}")
    if num_results is None:
        raise ValueError("Number of results must be given")
    k = int(num_results)
    if k < 1:
        raise ValueError(f"Number of results must be at least 1, got {num_results}")

    sample_id = id.strip() if isinstance(id, str) else id
    try:
        position = id_to_index_dict[sample_id]
    except KeyError:
        raise KeyError(f"Unknown sample ID: {sample_id!r}") from None

    # reconstruct() yields a single stored vector; search() takes a float32
    # batch of queries, so add a leading batch axis.
    query = indexes_by_modality[query_type].reconstruct(position)
    query = np.expand_dims(query.astype(np.float32), axis=0)

    _, labels = indexes_by_modality[key_type].search(query, k)
    # FAISS fills missing result slots with -1 when fewer than k vectors exist;
    # those labels have no sample ID and must be skipped.
    return [index_to_id_dict[int(label)] for label in labels[0] if label >= 0]
# NOTE: everything is loaded from the Hugging Face repo below. A local variant
# (cf. `load_indexes_local`) would need different file paths and its own
# index-to-ID pickle.
HF_REPO_NAME = "bioscan-ml/bioscan-clibd"
DEFAULT_INDEX_TYPE = "FlatIP(default)"

# Index-type name -> loaded index, one mapping per modality. These module-level
# names are read by get_image_index / get_dna_index.
image_indexes = load_indexes_hf(
    {DEFAULT_INDEX_TYPE: "bioscan_5m_image_IndexFlatIP.index"}, repo_name=HF_REPO_NAME
)
dna_indexes = load_indexes_hf(
    {DEFAULT_INDEX_TYPE: "bioscan_5m_dna_IndexFlatIP.index"}, repo_name=HF_REPO_NAME
)
# Vector position in the indexes -> sample ID, plus the reverse lookup used to
# find the stored vector of a user-supplied sample ID.
index_to_id_dict = load_index_pickle("big_indx_to_id_dict.pickle", repo_name=HF_REPO_NAME)
id_to_index_dict = {v: k for k, v in index_to_id_dict.items()}

# Offer only index types that are loaded for BOTH modalities. The radio used to
# also list FlatL2/HNSWFlat/IVFFlat/LSH, which are never loaded, so choosing
# any of them made every search fail with a KeyError.
available_index_types = [name for name in image_indexes if name in dna_indexes]

with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                rand_id = gr.Textbox(label="Random ID:")
                rand_id_indx = gr.Textbox(label="Index:")
                id_btn = gr.Button("Get Random ID")
            with gr.Column():
                query_type = gr.Radio(choices=["Image", "DNA"], label="Query:", value="Image")
                key_type = gr.Radio(choices=["Image", "DNA"], label="Key:", value="Image")
                index_type = gr.Radio(
                    choices=available_index_types, label="Index:", value=DEFAULT_INDEX_TYPE
                )
                num_results = gr.Number(label="Number of Results:", value=10, precision=0)
        process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
        process_id_list = gr.Textbox(label="Closest matches:")
        search_btn = gr.Button("Search")

    id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])
    search_btn.click(
        fn=searchEmbeddings,
        inputs=[process_id, key_type, query_type, index_type, num_results],
        outputs=[process_id_list],
    )

demo.launch()