browser-backend / app.py
atwang's picture
Revert "update code to download dataset files from separate repo"
effda21
import gradio as gr
import torch
import numpy as np
import h5py
import faiss
from PIL import Image
import io
import pickle
import random
import click
def getRandID():
    """Pick one entry of ``index_to_id_dict`` at random.

    A position is drawn uniformly from ``0 .. len(index_to_id_dict) - 1`` and
    looked up as a key of ``index_to_id_dict``.

    Returns:
        A ``(sample_id, position)`` tuple: the ID stored under the drawn
        position, followed by the position itself.
    """
    position = random.randrange(0, len(index_to_id_dict))
    sample_id = index_to_id_dict[position]
    return sample_id, position
def get_image_index(indexType):
    """Return the image embedding index that matches ``indexType``.

    Args:
        indexType: Name of the index flavour. ``"FlatIP(default)"`` is the
            only one currently served; ``"FlatL2"``, ``"HNSWFlat"``,
            ``"IVFFlat"`` and ``"LSH"`` are recognised but not available.

    Returns:
        The module-level ``image_index_IP`` index for ``"FlatIP(default)"``.

    Raises:
        NotImplementedError: If ``indexType`` is a recognised flavour that is
            not available.
        ValueError: If ``indexType`` is not a recognised flavour at all.
    """
    if indexType == "FlatIP(default)":
        return image_index_IP
    if indexType in ("FlatL2", "HNSWFlat", "IVFFlat", "LSH"):
        # These flavours used to be followed by unreachable ``return``
        # statements naming indexes that this file does not load.
        raise NotImplementedError(f"Image index type not available: {indexType}")
    # Fail loudly instead of returning None, which would only surface later
    # as an AttributeError in the caller.
    raise ValueError(f"Invalid index type: {indexType}")
def get_dna_index(indexType):
    """Return the DNA embedding index that matches ``indexType``.

    Args:
        indexType: Name of the index flavour. ``"FlatIP(default)"`` is the
            only one currently served; ``"FlatL2"``, ``"HNSWFlat"``,
            ``"IVFFlat"`` and ``"LSH"`` are recognised but not available.

    Returns:
        The module-level ``dna_index_IP`` index for ``"FlatIP(default)"``.

    Raises:
        NotImplementedError: If ``indexType`` is a recognised flavour that is
            not available.
        ValueError: If ``indexType`` is not a recognised flavour at all.
    """
    if indexType == "FlatIP(default)":
        return dna_index_IP
    if indexType in ("FlatL2", "HNSWFlat", "IVFFlat", "LSH"):
        # These flavours used to be followed by unreachable ``return``
        # statements naming indexes that this file does not load.
        raise NotImplementedError(f"DNA index type not available: {indexType}")
    # Fail loudly instead of returning None, which would only surface later
    # as an AttributeError in the caller.
    raise ValueError(f"Invalid index type: {indexType}")
def searchEmbeddings(id, key_type, query_type, index_type, num_results: int = 10):
    """Return the IDs of the stored embeddings closest to a given sample.

    The embedding stored for ``id`` is read back from the index selected by
    ``query_type`` and then used as the query vector for a search in the
    index selected by ``key_type``.

    Args:
        id: Sample ID to search with; must be a key of ``id_to_index_dict``.
            The name shadows the builtin and is kept only so that existing
            callers keep working.
        key_type: ``"Image"`` or ``"DNA"``; selects the index that is searched.
        query_type: ``"Image"`` or ``"DNA"``; selects the index the query
            embedding is reconstructed from.
        index_type: Index flavour, forwarded to ``get_image_index`` and
            ``get_dna_index``.
        num_results: Number of neighbours to request from the index.

    Returns:
        A list of IDs in the order the index search returned them. It can be
        shorter than ``num_results`` when the index holds fewer vectors.

    Raises:
        ValueError: If ``query_type`` or ``key_type`` is neither ``"Image"``
            nor ``"DNA"``.
        KeyError: If ``id`` is not present in ``id_to_index_dict``.
    """
    image_index = get_image_index(index_type)
    dna_index = get_dna_index(index_type)

    # Pick the index that holds the query embedding.
    if query_type == "Image":
        source_index = image_index
    elif query_type == "DNA":
        source_index = dna_index
    else:
        raise ValueError(f"Invalid query type: {query_type}")

    # Rebuild the stored vector and shape it as a batch of one for the search.
    query = source_index.reconstruct(id_to_index_dict[id])
    query = np.expand_dims(query.astype(np.float32), axis=0)

    # Pick the index that is searched.
    if key_type == "Image":
        index = image_index
    elif key_type == "DNA":
        index = dna_index
    else:
        raise ValueError(f"Invalid key type: {key_type}")

    # The UI may hand over a non-int number; coerce it before calling FAISS.
    _, neighbour_rows = index.search(query, int(num_results))

    # FAISS pads the label row with -1 when it finds fewer than k neighbours;
    # those are not real positions and must not be looked up.
    return [index_to_id_dict[row] for row in neighbour_rows[0] if row >= 0]
# Build the Gradio UI. The indexes and ID lookup tables loaded here are module
# globals read by getRandID, get_image_index, get_dna_index and searchEmbeddings.
with gr.Blocks() as demo:
    # for hf: change all file paths, indx_to_id_dict as well
    # load indexes
    image_index_IP = faiss.read_index("bioscan_5m_image_IndexFlatIP.index")
    # image_index_L2 = faiss.read_index("big_image_index_FlatL2.index")
    # image_index_HNSW = faiss.read_index("big_image_index_HNSWFlat.index")
    # image_index_IVF = faiss.read_index("big_image_index_IVFFlat.index")
    # image_index_LSH = faiss.read_index("big_image_index_LSH.index")
    dna_index_IP = faiss.read_index("bioscan_5m_dna_IndexFlatIP.index")
    # dna_index_L2 = faiss.read_index("big_dna_index_FlatL2.index")
    # dna_index_HNSW = faiss.read_index("big_dna_index_HNSWFlat.index")
    # dna_index_IVF = faiss.read_index("big_dna_index_IVFFlat.index")
    # dna_index_LSH = faiss.read_index("big_dna_index_LSH.index")

    # with open("dataset_processid_list.pickle", "rb") as f:
    #     dataset_processid_list = pickle.load(f)
    # with open("processid_to_index.pickle", "rb") as f:
    #     processid_to_index = pickle.load(f)

    # NOTE: pickle.load can execute arbitrary code from the file; this must
    # only ever point at a trusted, locally shipped pickle.
    with open("big_indx_to_id_dict.pickle", "rb") as f:
        index_to_id_dict = pickle.load(f)
    # Reverse lookup (ID -> position), built once at startup.
    id_to_index_dict = {v: k for k, v in index_to_id_dict.items()}

    # NOTE(review): the nesting of rows/columns below was reconstructed from
    # the statement order; confirm it against the intended layout.
    with gr.Column():
        with gr.Row():
            with gr.Column():
                rand_id = gr.Textbox(label="Random ID:")
                rand_id_indx = gr.Textbox(label="Index:")
                id_btn = gr.Button("Get Random ID")
            with gr.Column():
                # key_type selects the index that is searched and query_type
                # the index the query embedding is read from (see
                # searchEmbeddings).
                # NOTE(review): confirm the "Search From"/"Search To" labels
                # match that meaning.
                key_type = gr.Radio(choices=["Image", "DNA"], label="Search From:", value="Image")
                query_type = gr.Radio(choices=["Image", "DNA"], label="Search To:", value="Image")
                # NOTE(review): every choice except "FlatIP(default)" makes
                # get_image_index / get_dna_index raise NotImplementedError.
                index_type = gr.Radio(
                    choices=["FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat", "LSH"],
                    label="Index:",
                    value="FlatIP(default)",
                )
                num_results = gr.Number(label="Number of Results:", value=10, precision=0)

        process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
        # The number of matches follows "Number of Results", so the label
        # must not promise a fixed count.
        process_id_list = gr.Textbox(label="Closest matches:")
        search_btn = gr.Button("Search")

    id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])
    # Input order must match the positional parameters of searchEmbeddings.
    search_btn.click(
        fn=searchEmbeddings,
        inputs=[process_id, key_type, query_type, index_type, num_results],
        outputs=[process_id_list],
    )

demo.launch()