File size: 4,352 Bytes
db66d62
 
 
 
 
 
b383c02
db66d62
 
df89a31
db66d62
b383c02
db66d62
 
 
 
b383c02
db66d62
b383c02
db66d62
b383c02
 
db66d62
b383c02
 
db66d62
b383c02
 
db66d62
b383c02
 
db66d62
 
b383c02
db66d62
b383c02
db66d62
b383c02
 
db66d62
b383c02
 
db66d62
b383c02
 
db66d62
b383c02
 
db66d62
 
 
 
 
 
 
 
 
 
 
 
b383c02
db66d62
b383c02
db66d62
 
 
b383c02
db66d62
b383c02
db66d62
 
 
 
 
 
 
 
 
b383c02
db66d62
 
b383c02
db66d62
 
 
 
 
b383c02
 
 
 
 
 
 
 
 
 
 
db66d62
142b745
 
 
 
db66d62
 
 
 
 
 
df89a31
 
 
db66d62
 
 
 
 
 
 
 
73ff1bb
 
db66d62
73ff1bb
b383c02
 
db66d62
b383c02
 
db66d62
 
73ff1bb
b383c02
db66d62
b383c02
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import gradio as gr
import torch
import numpy as np
import h5py
import faiss
from PIL import Image
import io
import pickle
import random
import click


def getRandID():
    """Pick a random index position and return ``(sample_id, position)``.

    The position is drawn uniformly from ``[0, 396503)`` and translated to a
    sample ID through the module-level ``indx_to_id_dict``.
    """
    # NOTE(review): the upper bound is hard-coded; presumably it equals the
    # number of entries in indx_to_id_dict -- confirm when the data changes.
    position = random.randrange(396503)
    sample_id = indx_to_id_dict[position]
    return sample_id, position


def chooseImageIndex(indexType):
    """Return the image index selected by ``indexType``.

    Only ``"FlatIP(default)"`` is backed by a loaded index (the module-level
    ``image_index_IP``). The other index types offered by the UI are
    recognised but have no index loaded.

    Args:
        indexType: One of ``"FlatIP(default)"``, ``"FlatL2"``, ``"HNSWFlat"``,
            ``"IVFFlat"`` or ``"LSH"``.

    Returns:
        The index object to search.

    Raises:
        NotImplementedError: If ``indexType`` is a recognised type whose
            index is not loaded.
        ValueError: If ``indexType`` is not a recognised type (previously
            this silently returned ``None``).
    """
    if indexType == "FlatIP(default)":
        return image_index_IP
    if indexType in ("FlatL2", "HNSWFlat", "IVFFlat", "LSH"):
        raise NotImplementedError(
            f"Image index type {indexType!r} is not available yet"
        )
    raise ValueError(f"Unknown image index type: {indexType!r}")


def chooseDNAIndex(indexType):
    """Return the DNA index selected by ``indexType``.

    Only ``"FlatIP(default)"`` is backed by a loaded index (the module-level
    ``dna_index_IP``). The other index types offered by the UI are
    recognised but have no index loaded.

    Args:
        indexType: One of ``"FlatIP(default)"``, ``"FlatL2"``, ``"HNSWFlat"``,
            ``"IVFFlat"`` or ``"LSH"``.

    Returns:
        The index object to search.

    Raises:
        NotImplementedError: If ``indexType`` is a recognised type whose
            index is not loaded.
        ValueError: If ``indexType`` is not a recognised type (previously
            this silently returned ``None``).
    """
    if indexType == "FlatIP(default)":
        return dna_index_IP
    if indexType in ("FlatL2", "HNSWFlat", "IVFFlat", "LSH"):
        raise NotImplementedError(
            f"DNA index type {indexType!r} is not available yet"
        )
    raise ValueError(f"Unknown DNA index type: {indexType!r}")


def searchEmbeddings(id, mod1, mod2, indexType):
    """Return the IDs of the 10 nearest neighbours of sample ``id``.

    The query embedding is looked up by ``id`` in the ``mod1`` embedding
    table and searched against the ``mod2`` index of type ``indexType``.

    Args:
        id: Sample ID used as key into the query-embedding table.
        mod1: Modality to search *from* (``"Image"`` or ``"DNA"``).
        mod2: Modality to search *to* (``"Image"`` or ``"DNA"``).
        indexType: Index type name, passed to ``chooseImageIndex`` /
            ``chooseDNAIndex``.

    Returns:
        List of sample IDs, in the order the index returned them.

    Raises:
        ValueError: If ``mod1`` or ``mod2`` is not a known modality (for
            example nothing selected in the UI), or no index was resolved.
        NotImplementedError: If the embedding table for ``mod1`` is not
            loaded, or the requested index type is unavailable.
        KeyError: If ``id`` is not present in the embedding table.
    """
    num_neighbors = 10
    modalities = ("Image", "DNA")

    # Validate up front: an unselected radio button arrives as None and used
    # to surface as UnboundLocalError / KeyError(-1) further down.
    if mod1 not in modalities:
        raise ValueError(f"'Search From' must be one of {modalities}, got {mod1!r}")
    if mod2 not in modalities:
        raise ValueError(f"'Search To' must be one of {modalities}, got {mod2!r}")

    embeddings = id_to_image_emb_dict if mod1 == "Image" else id_to_dna_emb_dict
    if embeddings is None:
        # The table for this modality was never loaded.
        raise NotImplementedError(f"{mod1} query embeddings are not loaded")

    if mod2 == "Image":
        index = chooseImageIndex(indexType)
    else:
        index = chooseDNAIndex(indexType)
    if index is None:
        raise ValueError(f"No index available for index type {indexType!r}")

    # FAISS expects a float32 matrix of shape (n_queries, dim); promote a
    # single 1-D vector to one row.
    query = np.atleast_2d(np.asarray(embeddings[id], dtype=np.float32))
    _, neighbors = index.search(query, num_neighbors)

    # FAISS pads with -1 when fewer than k neighbours exist; skip those.
    return [indx_to_id_dict[indx] for indx in neighbors[0] if indx >= 0]


# Build the Gradio UI and load the search data. The data loading below runs
# once, when this module is executed, and binds the module-level names that
# getRandID / choose*Index / searchEmbeddings read.
with gr.Blocks() as demo:

    # for hf: change all file paths, indx_to_id_dict as well

    # load indexes
    # Only the FlatIP indexes are loaded; the other index types are commented
    # out, and the choose*Index helpers raise NotImplementedError for them.
    image_index_IP = faiss.read_index("bioscan_5m_image_IndexFlatIP.index")
    # image_index_L2 = faiss.read_index("big_image_index_FlatL2.index")
    # image_index_HNSW = faiss.read_index("big_image_index_HNSWFlat.index")
    # image_index_IVF = faiss.read_index("big_image_index_IVFFlat.index")
    # image_index_LSH = faiss.read_index("big_image_index_LSH.index")

    dna_index_IP = faiss.read_index("bioscan_5m_dna_IndexFlatIP.index")
    # dna_index_L2 = faiss.read_index("big_dna_index_FlatL2.index")
    # dna_index_HNSW = faiss.read_index("big_dna_index_HNSWFlat.index")
    # dna_index_IVF = faiss.read_index("big_dna_index_IVFFlat.index")
    # dna_index_LSH = faiss.read_index("big_dna_index_LSH.index")

    # NOTE(security): pickle.load can execute arbitrary code while unpickling;
    # these files must come from a trusted source.
    # with open("dataset_processid_list.pickle", "rb") as f:
    #     dataset_processid_list = pickle.load(f)
    # with open("processid_to_index.pickle", "rb") as f:
    #     processid_to_index = pickle.load(f)
    # Maps an index position (as returned by index.search) to a sample ID.
    with open("big_indx_to_id_dict.pickle", "rb") as f:
        indx_to_id_dict = pickle.load(f)

    # initialize both possible dicts
    with open("big_id_to_image_emb_dict.pickle", "rb") as f:
        id_to_image_emb_dict = pickle.load(f)
    # with open("big_id_to_dna_emb_dict.pickle", "rb") as f:
    #     id_to_dna_emb_dict = pickle.load(f)
    # DNA query embeddings are not loaded, so searching *from* "DNA" cannot
    # look up a query vector in this table.
    id_to_dna_emb_dict = None

    # Layout: top row has the random-ID helper (left) and the from/to
    # modality selectors (right); below it the index selector, the ID input,
    # the result box and the search button.
    with gr.Column():
        with gr.Row():
            with gr.Column():
                rand_id = gr.Textbox(label="Random ID:")
                rand_id_indx = gr.Textbox(label="Index:")
                id_btn = gr.Button("Get Random ID")
            with gr.Column():
                # Neither radio has a default value set here.
                key_type = gr.Radio(choices=["DNA", "Image"], label="Search From:")
                query_type = gr.Radio(choices=["DNA", "Image"], label="Search To:")

        index_type = gr.Radio(
            choices=["FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat", "LSH"], label="Index:", value="FlatIP(default)"
        )
        process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
        process_id_list = gr.Textbox(label="Closest 10 matches:")
        search_btn = gr.Button("Search")
        # getRandID returns (id, index); the two values fill the two textboxes.
        id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])

    # Inputs are passed positionally as (id, mod1, mod2, indexType).
    search_btn.click(fn=searchEmbeddings, inputs=[process_id, key_type, query_type, index_type], outputs=[process_id_list])


# Start the Gradio app; this runs whenever the module is executed or imported.
demo.launch()