Spaces:
Runtime error
Runtime error
File size: 4,019 Bytes
89816ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import gradio as gr
import torch
import numpy as np
import h5py
import faiss
from PIL import Image
import io
import pickle
import random
def getRandID():
    """Pick a random sample from the loaded dataset.

    Returns:
        ``(sample_id, indx)`` where ``indx`` is a uniformly random integer in
        ``[0, len(indx_to_id_dict))`` and ``sample_id`` is
        ``indx_to_id_dict[indx]``.
    """
    # Derive the range from the loaded mapping instead of hard-coding the
    # dataset size, so swapping in a different (e.g. smaller) dictionary
    # cannot produce out-of-range keys. Assumes the keys are 0..len-1, like
    # the FAISS row positions looked up in searchEmbeddings — TODO confirm.
    indx = random.randrange(len(indx_to_id_dict))
    return indx_to_id_dict[indx], indx
def chooseImageIndex(indexType):
    """Return the loaded image-embedding FAISS index for a UI index label.

    Args:
        indexType: One of the labels offered by the "Index:" radio:
            "FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat" or "LSH".

    Returns:
        The matching module-level ``image_index_*`` object.

    Raises:
        ValueError: if ``indexType`` is not one of the labels above, instead
            of returning None and failing later inside ``index.search``.
    """
    # Built per call because the image_index_* globals are assigned only
    # when the UI script further down the module runs.
    indexes = {
        "FlatIP(default)": image_index_IP,
        "FlatL2": image_index_L2,
        "HNSWFlat": image_index_HNSW,
        "IVFFlat": image_index_IVF,
        "LSH": image_index_LSH,
    }
    try:
        return indexes[indexType]
    except KeyError:
        raise ValueError(f"Unknown image index type: {indexType!r}") from None
def chooseDNAIndex(indexType):
    """Return the loaded DNA-embedding FAISS index for a UI index label.

    Args:
        indexType: One of the labels offered by the "Index:" radio:
            "FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat" or "LSH".

    Returns:
        The matching module-level ``dna_index_*`` object.

    Raises:
        ValueError: if ``indexType`` is not one of the labels above, instead
            of returning None and failing later inside ``index.search``.
    """
    # Built per call because the dna_index_* globals are assigned only
    # when the UI script further down the module runs.
    indexes = {
        "FlatIP(default)": dna_index_IP,
        "FlatL2": dna_index_L2,
        "HNSWFlat": dna_index_HNSW,
        "IVFFlat": dna_index_IVF,
        "LSH": dna_index_LSH,
    }
    try:
        return indexes[indexType]
    except KeyError:
        raise ValueError(f"Unknown DNA index type: {indexType!r}") from None
def searchEmbeddings(id, mod1, mod2, indexType):
    """Return the sample IDs nearest to sample ``id`` across modalities.

    The query embedding is taken from the ``mod1`` ("Search From") modality
    and searched against the ``mod2`` ("Search To") FAISS index selected by
    ``indexType``.

    Args:
        id: Sample ID entered in the UI, used as a key into
            ``id_to_image_emb_dict`` / ``id_to_dna_emb_dict``. (Name kept for
            interface compatibility even though it shadows the builtin.)
        mod1: Query modality, ``"Image"`` or ``"DNA"``.
        mod2: Modality of the index to search, ``"Image"`` or ``"DNA"``.
        indexType: Index label forwarded to ``chooseImageIndex`` /
            ``chooseDNAIndex``.

    Returns:
        Up to 10 sample IDs in the order FAISS returns them; empty result
        slots (label ``-1``) are skipped.

    Raises:
        ValueError: if ``mod1`` or ``mod2`` is not "Image"/"DNA" (e.g. the
            radio was left unselected).
        KeyError: if ``id`` has no embedding for ``mod1``.
    """
    num_neighbors = 10

    # Target modality -> which index to search in.
    if mod2 == "Image":
        index = chooseImageIndex(indexType)
    elif mod2 == "DNA":
        index = chooseDNAIndex(indexType)
    else:
        raise ValueError(f"Please choose a 'Search To' modality (got {mod2!r})")

    # Source modality -> which embedding to use as the query.
    if mod1 == "Image":
        query = id_to_image_emb_dict[id]
    elif mod1 == "DNA":
        query = id_to_dna_emb_dict[id]
    else:
        raise ValueError(f"Please choose a 'Search From' modality (got {mod1!r})")

    # FAISS expects a C-contiguous float32 matrix of shape (n_queries, dim);
    # atleast_2d also accepts an embedding stored as a flat vector.
    query = np.ascontiguousarray(np.atleast_2d(query), dtype=np.float32)
    _, I = index.search(query, num_neighbors)

    # FAISS pads with -1 when it finds fewer than num_neighbors results.
    return [indx_to_id_dict[int(indx)] for indx in I[0] if indx != -1]
with gr.Blocks() as demo:
    # for hf: change all file paths, indx_to_id_dict as well
    # Load the prebuilt FAISS indexes, one per (modality, index type) pair.
    # These are module-level names read by chooseImageIndex/chooseDNAIndex.
    image_index_IP = faiss.read_index("big_image_index_FlatIP.index")
    image_index_L2 = faiss.read_index("big_image_index_FlatL2.index")
    image_index_HNSW = faiss.read_index("big_image_index_HNSWFlat.index")
    image_index_IVF = faiss.read_index("big_image_index_IVFFlat.index")
    image_index_LSH = faiss.read_index("big_image_index_LSH.index")
    dna_index_IP = faiss.read_index("big_dna_index_FlatIP.index")
    dna_index_L2 = faiss.read_index("big_dna_index_FlatL2.index")
    dna_index_HNSW = faiss.read_index("big_dna_index_HNSWFlat.index")
    dna_index_IVF = faiss.read_index("big_dna_index_IVFFlat.index")
    dna_index_LSH = faiss.read_index("big_dna_index_LSH.index")

    # NOTE: pickle.load can execute arbitrary code; these must stay local,
    # trusted artifacts shipped with the app — never load user-supplied files.
    # dataset_processid_list / processid_to_index are not referenced in this
    # file; presumably kept for other consumers — verify before removing.
    with open("dataset_processid_list.pickle", "rb") as f:
        dataset_processid_list = pickle.load(f)
    with open("processid_to_index.pickle", "rb") as f:
        processid_to_index = pickle.load(f)
    # Maps FAISS row position -> sample ID (used by getRandID/searchEmbeddings).
    with open("big_indx_to_id_dict.pickle", "rb") as f:
        indx_to_id_dict = pickle.load(f)
    # initialize both possible dicts (sample ID -> embedding per modality)
    with open("big_id_to_image_emb_dict.pickle", "rb") as f:
        id_to_image_emb_dict = pickle.load(f)
    with open("big_id_to_dna_emb_dict.pickle", "rb") as f:
        id_to_dna_emb_dict = pickle.load(f)

    with gr.Column():
        with gr.Row():
            # Left column: draw a random sample ID to try.
            with gr.Column():
                rand_id = gr.Textbox(label="Random ID:")
                rand_id_indx = gr.Textbox(label="Index:")
                id_btn = gr.Button("Get Random ID")
            # Right column: cross-modal nearest-neighbour search.
            with gr.Column():
                mod1 = gr.Radio(choices=["DNA", "Image"], label="Search From:")
                mod2 = gr.Radio(choices=["DNA", "Image"], label="Search To:")
                indexType = gr.Radio(choices=["FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat", "LSH"], label="Index:", value="FlatIP(default)")
                process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
                process_id_list = gr.Textbox(label="Closest 10 matches:")
                search_btn = gr.Button("Search")

    id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])
    search_btn.click(fn=searchEmbeddings, inputs=[process_id, mod1, mod2, indexType],
                     outputs=[process_id_list])

if __name__ == "__main__":
    demo.launch()