File size: 4,519 Bytes
07356cd
effda21
07356cd
effda21
 
 
 
 
 
 
db66d62
b383c02
db66d62
1f788b3
 
db66d62
b383c02
1f788b3
effda21
 
 
 
 
 
 
 
 
 
 
 
 
 
db66d62
b383c02
1f788b3
effda21
 
 
 
 
 
 
 
 
 
 
 
 
 
db66d62
 
1f788b3
 
 
db66d62
 
51e3825
1f788b3
51e3825
1f788b3
 
 
 
 
db66d62
 
51e3825
1f788b3
51e3825
1f788b3
 
 
db66d62
1f788b3
 
 
db66d62
1f788b3
 
b383c02
1f788b3
db66d62
b383c02
db66d62
 
 
 
 
effda21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f788b3
db66d62
 
 
 
 
 
 
 
effda21
 
db66d62
73ff1bb
b383c02
 
1f788b3
 
db66d62
effda21
b383c02
db66d62
 
1f788b3
 
 
 
 
b383c02
db66d62
b383c02
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import gradio as gr
import torch
import numpy as np
import h5py
import faiss
from PIL import Image
import io
import pickle
import random
import click


def getRandID():
    """Pick a random sample and return it as ``(sample_id, position)``.

    ``position`` is drawn uniformly from ``range(len(index_to_id_dict))`` and
    used directly as a key into the module-level ``index_to_id_dict``; this
    presumably assumes the dict is keyed by contiguous integers starting at 0
    (verify against the pickle that produces it).
    """
    total = len(index_to_id_dict)
    position = random.randrange(0, total)
    sample_id = index_to_id_dict[position]
    return sample_id, position


def get_image_index(indexType):
    """Return the FAISS image-embedding index selected in the UI.

    Args:
        indexType: One of the labels offered by the index-type radio button.

    Returns:
        The module-level ``image_index_IP`` for ``"FlatIP(default)"``.

    Raises:
        NotImplementedError: For the other listed index types, whose image
            indexes are not loaded by this app.
        ValueError: For any unrecognised label, instead of silently returning
            ``None`` and failing later in the caller.
    """
    if indexType == "FlatIP(default)":
        return image_index_IP
    if indexType in ("FlatL2", "HNSWFlat", "IVFFlat", "LSH"):
        # Only the FlatIP image index is read from disk; these are placeholders.
        raise NotImplementedError(f"Image index type {indexType!r} is not available yet")
    raise ValueError(f"Invalid index type: {indexType}")


def get_dna_index(indexType):
    """Return the FAISS DNA-embedding index selected in the UI.

    Args:
        indexType: One of the labels offered by the index-type radio button.

    Returns:
        The module-level ``dna_index_IP`` for ``"FlatIP(default)"``.

    Raises:
        NotImplementedError: For the other listed index types, whose DNA
            indexes are not loaded by this app.
        ValueError: For any unrecognised label, instead of silently returning
            ``None`` and failing later in the caller.
    """
    if indexType == "FlatIP(default)":
        return dna_index_IP
    if indexType in ("FlatL2", "HNSWFlat", "IVFFlat", "LSH"):
        # Only the FlatIP DNA index is read from disk; these are placeholders.
        raise NotImplementedError(f"DNA index type {indexType!r} is not available yet")
    raise ValueError(f"Invalid index type: {indexType}")


def searchEmbeddings(id, key_type, query_type, index_type, num_results: int = 10):
    """Find the sample IDs whose embeddings are nearest to a given sample's.

    The stored ``query_type`` embedding of sample ``id`` is used as the query
    vector and searched against the ``key_type`` FAISS index.

    Args:
        id: Sample ID to use as the query; surrounding whitespace is ignored.
        key_type: Modality to search in, ``"Image"`` or ``"DNA"``.
        query_type: Modality of the query embedding, ``"Image"`` or ``"DNA"``.
        index_type: Index label forwarded to ``get_image_index`` /
            ``get_dna_index``.
        num_results: Number of neighbours to request; coerced to ``int``
            because UI number widgets may deliver floats.

    Returns:
        List of at most ``num_results`` sample IDs, in the order FAISS ranks
        them. Fewer are returned when the index holds fewer vectors.

    Raises:
        ValueError: Unknown modality, unknown sample ID, or an invalid
            ``num_results`` (non-numeric or < 1).
    """
    # Each modality maps to the helper that resolves its FAISS index.
    index_getters = {"Image": get_image_index, "DNA": get_dna_index}
    if query_type not in index_getters:
        raise ValueError(f"Invalid query type: {query_type}")
    if key_type not in index_getters:
        raise ValueError(f"Invalid key type: {key_type}")

    # Text boxes often carry stray whitespace from copy/paste.
    sample_id = id.strip() if isinstance(id, str) else id
    try:
        position = id_to_index_dict[sample_id]
    except KeyError:
        raise ValueError(f"Unknown sample ID: {sample_id!r}") from None

    try:
        k = int(num_results)
    except (TypeError, ValueError):
        raise ValueError(f"Invalid number of results: {num_results!r}") from None
    if k < 1:
        raise ValueError(f"Number of results must be at least 1, got {k}")

    query_index = index_getters[query_type](index_type)
    key_index = index_getters[key_type](index_type)

    # FAISS searches take a 2-D float32 batch; wrap the single vector as (1, dim).
    query = query_index.reconstruct(position).astype(np.float32)
    query = np.expand_dims(query, axis=0)

    _, neighbours = key_index.search(query, k)

    # FAISS fills unused result slots with -1 when the index holds fewer than
    # k vectors; those have no sample ID, so drop them.
    return [index_to_id_dict[int(row)] for row in neighbours[0] if row >= 0]


with gr.Blocks() as demo:

    # for hf: change all file paths, indx_to_id_dict as well

    # Load the FAISS indexes. Only the inner-product (FlatIP) variants are
    # loaded; the other index types offered in the UI raise NotImplementedError.
    image_index_IP = faiss.read_index("bioscan_5m_image_IndexFlatIP.index")
    # image_index_L2 = faiss.read_index("big_image_index_FlatL2.index")
    # image_index_HNSW = faiss.read_index("big_image_index_HNSWFlat.index")
    # image_index_IVF = faiss.read_index("big_image_index_IVFFlat.index")
    # image_index_LSH = faiss.read_index("big_image_index_LSH.index")

    dna_index_IP = faiss.read_index("bioscan_5m_dna_IndexFlatIP.index")
    # dna_index_L2 = faiss.read_index("big_dna_index_FlatL2.index")
    # dna_index_HNSW = faiss.read_index("big_dna_index_HNSWFlat.index")
    # dna_index_IVF = faiss.read_index("big_dna_index_IVFFlat.index")
    # dna_index_LSH = faiss.read_index("big_dna_index_LSH.index")

    # Mapping from FAISS row position to sample ID, plus its inverse.
    # NOTE: pickle.load can execute arbitrary code; only load trusted,
    # locally bundled files here.
    with open("big_indx_to_id_dict.pickle", "rb") as f:
        index_to_id_dict = pickle.load(f)
    # If two positions share an ID, the later position wins in the inverse map.
    id_to_index_dict = {v: k for k, v in index_to_id_dict.items()}

    with gr.Column():
        with gr.Row():
            with gr.Column():
                rand_id = gr.Textbox(label="Random ID:")
                rand_id_indx = gr.Textbox(label="Index:")
                id_btn = gr.Button("Get Random ID")
            with gr.Column():
                # searchEmbeddings builds the query vector from `query_type` and
                # searches the `key_type` index, so query_type is the "from"
                # side and key_type is the "to" side.
                query_type = gr.Radio(choices=["Image", "DNA"], label="Search From:", value="Image")
                key_type = gr.Radio(choices=["Image", "DNA"], label="Search To:", value="Image")

        index_type = gr.Radio(
            choices=["FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat", "LSH"], label="Index type:", value="FlatIP(default)"
        )
        num_results = gr.Number(label="Number of Results:", value=10, precision=0)

        process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
        process_id_list = gr.Textbox(label="Closest matches:")
        search_btn = gr.Button("Search")
        id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])

    search_btn.click(
        fn=searchEmbeddings,
        inputs=[process_id, key_type, query_type, index_type, num_results],
        outputs=[process_id_list],
    )


demo.launch()