Spaces:
Running
Running
File size: 4,519 Bytes
07356cd effda21 07356cd effda21 db66d62 b383c02 db66d62 1f788b3 db66d62 b383c02 1f788b3 effda21 db66d62 b383c02 1f788b3 effda21 db66d62 1f788b3 db66d62 51e3825 1f788b3 51e3825 1f788b3 db66d62 51e3825 1f788b3 51e3825 1f788b3 db66d62 1f788b3 db66d62 1f788b3 b383c02 1f788b3 db66d62 b383c02 db66d62 effda21 1f788b3 db66d62 effda21 db66d62 73ff1bb b383c02 1f788b3 db66d62 effda21 b383c02 db66d62 1f788b3 b383c02 db66d62 b383c02 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import gradio as gr
import torch
import numpy as np
import h5py
import faiss
from PIL import Image
import io
import pickle
import random
import click
def getRandID():
    """Pick a random entry from the module-level ``index_to_id_dict``.

    An integer position is drawn uniformly from
    ``[0, len(index_to_id_dict))`` and used directly as a dictionary key, so
    the mapping is expected to be keyed by exactly those integers.

    Returns:
        A ``(sample_id, position)`` tuple: the ID stored under the drawn
        position, followed by the position itself.
    """
    position = random.randrange(len(index_to_id_dict))
    sample_id = index_to_id_dict[position]
    return sample_id, position
def get_image_index(indexType):
    """Return the image embedding index that backs ``indexType``.

    Only ``"FlatIP(default)"`` is currently served; it maps to the
    module-level ``image_index_IP``. The other index types are recognised but
    their index objects are not loaded, so requesting them fails explicitly.

    Args:
        indexType: Index type label, e.g. ``"FlatIP(default)"``.

    Returns:
        The module-level ``image_index_IP`` object.

    Raises:
        NotImplementedError: If ``indexType`` is ``"FlatL2"``, ``"HNSWFlat"``,
            ``"IVFFlat"`` or ``"LSH"``.
        ValueError: If ``indexType`` is not a recognised label. Failing here
            avoids handing ``None`` back to a caller that would only crash
            later when it tries to use the index.
    """
    if indexType == "FlatIP(default)":
        return image_index_IP
    if indexType in ("FlatL2", "HNSWFlat", "IVFFlat", "LSH"):
        # Known index types whose index files are not loaded at startup.
        raise NotImplementedError(f"Index type not available yet: {indexType}")
    raise ValueError(f"Invalid index type: {indexType}")
def get_dna_index(indexType):
    """Return the DNA embedding index that backs ``indexType``.

    Only ``"FlatIP(default)"`` is currently served; it maps to the
    module-level ``dna_index_IP``. The other index types are recognised but
    their index objects are not loaded, so requesting them fails explicitly.

    Args:
        indexType: Index type label, e.g. ``"FlatIP(default)"``.

    Returns:
        The module-level ``dna_index_IP`` object.

    Raises:
        NotImplementedError: If ``indexType`` is ``"FlatL2"``, ``"HNSWFlat"``,
            ``"IVFFlat"`` or ``"LSH"``.
        ValueError: If ``indexType`` is not a recognised label. Failing here
            avoids handing ``None`` back to a caller that would only crash
            later when it tries to use the index.
    """
    if indexType == "FlatIP(default)":
        return dna_index_IP
    if indexType in ("FlatL2", "HNSWFlat", "IVFFlat", "LSH"):
        # Known index types whose index files are not loaded at startup.
        raise NotImplementedError(f"Index type not available yet: {indexType}")
    raise ValueError(f"Invalid index type: {indexType}")
def searchEmbeddings(id, key_type, query_type, index_type, num_results: int = 10):
    """Find the samples whose embeddings are closest to a given sample.

    The stored embedding of sample ``id`` is read back from the index selected
    by ``query_type`` and used as the query vector for a nearest-neighbour
    search in the index selected by ``key_type``.

    Args:
        id: Sample ID to search for; must be a key of the module-level
            ``id_to_index_dict``. (The name shadows the builtin but is kept so
            the signature stays unchanged.)
        key_type: ``"Image"`` or ``"DNA"`` - which index is searched.
        query_type: ``"Image"`` or ``"DNA"`` - which index supplies the query
            vector.
        index_type: Index type label passed to ``get_image_index`` and
            ``get_dna_index``.
        num_results: Maximum number of neighbours to return; must convert to
            an integer >= 1.

    Returns:
        A list of sample IDs in the order returned by the index search. It can
        be shorter than ``num_results`` when the index holds fewer vectors.

    Raises:
        ValueError: If ``query_type``, ``key_type`` or ``num_results`` is
            invalid. These are checked before any index is touched.
        KeyError: If ``id`` is not a known sample ID.
        NotImplementedError: Propagated from the index getters for index
            types that are not available.
    """
    modalities = ("Image", "DNA")
    if query_type not in modalities:
        raise ValueError(f"Invalid query type: {query_type}")
    if key_type not in modalities:
        raise ValueError(f"Invalid key type: {key_type}")

    # UI number widgets may hand over a float or None; the search needs an int.
    try:
        k = int(num_results)
    except (TypeError, ValueError, OverflowError):
        raise ValueError(f"Invalid number of results: {num_results!r}") from None
    if k < 1:
        raise ValueError(f"Invalid number of results: {num_results!r}")

    indexes = {
        "Image": get_image_index(index_type),
        "DNA": get_dna_index(index_type),
    }

    try:
        row = id_to_index_dict[id]
    except KeyError:
        raise KeyError(f"Unknown sample ID: {id!r}") from None

    # Rebuild the stored vector and shape it as a single-row float32 batch.
    query = indexes[query_type].reconstruct(row)
    query = np.expand_dims(query.astype(np.float32), axis=0)

    _, neighbours = indexes[key_type].search(query, k)

    # FAISS pads the result row with -1 when fewer than k neighbours exist;
    # those slots have no corresponding sample ID and are skipped.
    return [index_to_id_dict[int(pos)] for pos in neighbours[0] if pos >= 0]
# Build the Gradio UI. The indexes and ID lookup tables loaded inside this
# block are bound as module-level names, which is how getRandID,
# get_image_index, get_dna_index and searchEmbeddings find them.
with gr.Blocks() as demo:
    # for hf: change all file paths, indx_to_id_dict as well
    # NOTE(review): confirm the nesting of the layout containers below matches
    # the intended page layout.

    # load indexes
    # Only the FlatIP indexes are loaded; the other index types are commented
    # out and currently raise NotImplementedError in the index getters.
    image_index_IP = faiss.read_index("bioscan_5m_image_IndexFlatIP.index")
    # image_index_L2 = faiss.read_index("big_image_index_FlatL2.index")
    # image_index_HNSW = faiss.read_index("big_image_index_HNSWFlat.index")
    # image_index_IVF = faiss.read_index("big_image_index_IVFFlat.index")
    # image_index_LSH = faiss.read_index("big_image_index_LSH.index")
    dna_index_IP = faiss.read_index("bioscan_5m_dna_IndexFlatIP.index")
    # dna_index_L2 = faiss.read_index("big_dna_index_FlatL2.index")
    # dna_index_HNSW = faiss.read_index("big_dna_index_HNSWFlat.index")
    # dna_index_IVF = faiss.read_index("big_dna_index_IVFFlat.index")
    # dna_index_LSH = faiss.read_index("big_dna_index_LSH.index")

    # with open("dataset_processid_list.pickle", "rb") as f:
    # dataset_processid_list = pickle.load(f)
    # with open("processid_to_index.pickle", "rb") as f:
    # processid_to_index = pickle.load(f)

    # NOTE: pickle.load can execute arbitrary code contained in the file; only
    # ship a pickle that comes from a trusted source.
    with open("big_indx_to_id_dict.pickle", "rb") as f:
        index_to_id_dict = pickle.load(f)
    # Inverse mapping (ID -> index position) used to look up a sample's vector.
    id_to_index_dict = {v: k for k, v in index_to_id_dict.items()}

    with gr.Column():
        with gr.Row():
            # Left column: helper that shows a random ID and its index.
            with gr.Column():
                rand_id = gr.Textbox(label="Random ID:")
                rand_id_indx = gr.Textbox(label="Index:")
                id_btn = gr.Button("Get Random ID")
            # Right column: search options forwarded to searchEmbeddings.
            with gr.Column():
                key_type = gr.Radio(choices=["Image", "DNA"], label="Search From:", value="Image")
                query_type = gr.Radio(choices=["Image", "DNA"], label="Search To:", value="Image")
                index_type = gr.Radio(
                    choices=["FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat", "LSH"], label="Index:", value="FlatIP(default)"
                )
                num_results = gr.Number(label="Number of Results:", value=10, precision=0)

        process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
        # NOTE(review): the label says "10" but the number of matches shown is
        # whatever num_results is set to - confirm whether the label should change.
        process_id_list = gr.Textbox(label="Closest 10 matches:")
        search_btn = gr.Button("Search")

        # getRandID returns (id, index), filling the two textboxes in that order.
        id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])

    # Input order matches searchEmbeddings' positional parameters:
    # (id, key_type, query_type, index_type, num_results).
    search_btn.click(
        fn=searchEmbeddings,
        inputs=[process_id, key_type, query_type, index_type, num_results],
        outputs=[process_id_list],
    )

# Runs at import time: starts the Gradio app.
demo.launch()
|