import gradio as gr
import torch
import numpy as np
import h5py
import faiss
from PIL import Image
import io 
import pickle
import random

def get_image(image1, image2, dataset_image_mask, processid_to_index, idx,
              split_index=162834):
    """Decode and return the image stored at global dataset index ``idx``.

    The encoded images are split across two stores: indices below
    ``split_index`` are read from ``image1``; the rest are read from
    ``image2`` after re-basing (``idx - split_index``). Each stored row is a
    padded buffer, and ``dataset_image_mask[idx]`` gives the number of
    leading bytes that belong to the encoded image.

    Args:
        image1: Indexable store of padded encoded images for
            ``idx < split_index``.
        image2: Indexable store of padded encoded images for
            ``idx >= split_index``.
        dataset_image_mask: Indexable giving, per global index, the length
            of the encoded image inside its padded row.
        processid_to_index: Unused; kept so existing callers keep working.
        idx: Global (non-negative) dataset index of the image.
        split_index: First global index stored in ``image2``. Defaults to
            the value that was previously hard-coded.

    Returns:
        The object returned by ``PIL.Image.open`` for the encoded bytes.

    Raises:
        IndexError: If ``idx`` is negative. A negative index would wrap
            around in ``image1`` and in the mask independently and pair
            the wrong image with the wrong length.
    """
    if idx < 0:
        raise IndexError(f"image index must be non-negative, got {idx}")

    if idx < split_index:
        padded_row = image1[idx]
    else:
        padded_row = image2[idx - split_index]

    enc_length = dataset_image_mask[idx]
    # Drop the padding before handing the encoded bytes to Pillow.
    encoded = padded_row.astype(np.uint8)[:enc_length]
    return Image.open(io.BytesIO(encoded.tobytes()))

def searchEmbeddings(id, mod1, mod2):
    """Return the 10 nearest neighbours of a sample, with images and taxonomy.

    The embedding of sample ``id`` in modality ``mod1`` is used to query the
    FAISS index of modality ``mod2``.

    Args:
        id: Sample ID to search for. (The name shadows the builtin; it is
            kept for interface compatibility and never reassigned here.)
        mod1: ``"Image"`` or ``"DNA"``: modality of the query embedding.
        mod2: ``"Image"`` or ``"DNA"``: modality of the index to search.

    Returns:
        A 23-tuple: the list of the 10 matching sample IDs, then 11 images
        (the query sample followed by the matches in rank order), then 11
        taxonomy strings in the same order.

    Raises:
        gr.Error: If either modality is missing/unknown, or ``id`` is not a
            known sample ID / has no embedding in ``mod1``.
    """
    num_neighbors = 10
    index_files = {"Image": "image_index.index", "DNA": "dna_index.index"}
    embeddings_by_modality = {
        "Image": id_to_image_emb_dict,
        "DNA": id_to_dna_emb_dict,
    }

    # Validate up front: previously an unset radio button surfaced as an
    # UnboundLocalError (mod1) or a search over an empty placeholder index
    # (mod2).
    if mod1 not in embeddings_by_modality:
        raise gr.Error('Choose "DNA" or "Image" for "Search From".')
    if mod2 not in index_files:
        raise gr.Error('Choose "DNA" or "Image" for "Search To".')

    query_id = id.strip() if isinstance(id, str) else id
    # NOTE(review): assumes the lookup tables are dicts (KeyError on a miss).
    try:
        original_indx = processid_to_index[query_id]
    except KeyError:
        raise gr.Error(f"Unknown sample ID: {query_id!r}") from None
    try:
        query = embeddings_by_modality[mod1][query_id]
    except KeyError:
        raise gr.Error(
            f"No {mod1} embedding available for sample {query_id!r}"
        ) from None

    # NOTE: the index is re-read from disk on every search; caching it would
    # trade resident memory for latency.
    index = faiss.read_index(index_files[mod2])

    # FAISS expects a 2-D float32 batch; atleast_2d leaves an already 2-D
    # embedding untouched and promotes a 1-D one to a single-row batch.
    query = np.atleast_2d(query).astype(np.float32)
    _, neighbors = index.search(query, num_neighbors)
    neighbor_indices = neighbors[0]

    id_list = [indx_to_id_dict[indx] for indx in neighbor_indices]

    # Query sample first, then the matches in rank order.
    display_indices = (original_indx, *neighbor_indices)
    images = [
        get_image(dataset_image1, dataset_image2, dataset_image_mask,
                  processid_to_index, indx)
        for indx in display_indices
    ]
    taxonomy = [getTax(indx) for indx in display_indices]

    return (id_list, *images, *taxonomy)

def getRandID():
    """Pick a random dataset index and return ``(sample_id, index)``.

    The index is drawn uniformly from ``[0, 325667)`` and mapped to its
    sample ID through ``indx_to_id_dict``.
    """
    # NOTE(review): 325667 is presumably the dataset size; confirm.
    random_index = random.randrange(325667)
    sample_id = indx_to_id_dict[random_index]
    return sample_id, random_index

def getTax(indx):
    """Build the three-line taxonomy label for the sample at ``indx``.

    Reads the species, genus and family entries at ``indx`` from the
    module-level lists and joins them into
    ``"Species: ...\\nGenus: ...\\nFamily: ..."``.
    """
    label_parts = (
        "Species: ", species[indx],
        "\nGenus: ", genus[indx],
        "\nFamily: ", family[indx],
    )
    return "".join(label_parts)

# Build the Gradio app. Everything in this block runs at import time, in
# order: load the data files, lay out the UI, then wire up the events.
# The names loaded here are module-level globals read by the functions above.
with gr.Blocks(title="Bioscan-Clip") as demo:
    # Load the shared dataset files.
    # NOTE(review): pickle.load can execute arbitrary code from the file;
    # this presumes the .pickle files are trusted local artifacts - confirm.
    # dataset_image1 / dataset_image2: the two encoded-image stores passed to
    # get_image (indices below 162834 come from the first, the rest from the
    # second).
    with open("dataset_image1.pickle", "rb") as f:
        dataset_image1 = pickle.load(f)
    with open("dataset_image2.pickle", "rb") as f:
        dataset_image2 = pickle.load(f)
    # Not referenced anywhere else in the visible part of this file.
    with open("dataset_processid_list.pickle", "rb") as f:
        dataset_processid_list = pickle.load(f)
    # Per-index encoded length, used by get_image to strip row padding.
    with open("dataset_image_mask.pickle", "rb") as f: 
        dataset_image_mask = pickle.load(f)
    # Sample ID -> dataset index (used by searchEmbeddings).
    with open("processid_to_index.pickle", "rb") as f: 
        processid_to_index = pickle.load(f)
    # Dataset index -> sample ID (used by searchEmbeddings and getRandID).
    with open("indx_to_id_dict.pickle", "rb") as f: 
        indx_to_id_dict = pickle.load(f)

    # Sample ID -> image embedding, used as the query when searching from "Image".
    with open("id_to_image_emb_dict.pickle", "rb") as f: 
        id_to_image_emb_dict = pickle.load(f)

    # Sample ID -> DNA embedding, used as the query when searching from "DNA".
    with open("id_to_dna_emb_dict.pickle", "rb") as f: 
        id_to_dna_emb_dict = pickle.load(f)

    # Taxonomy labels, indexed by dataset index in getTax. The pickled items
    # are bytes and are decoded to str once here.
    with open("family.pickle", "rb") as f:
        family = [item.decode("utf-8") for item in pickle.load(f)]
    with open("genus.pickle", "rb") as f:
        genus= [item.decode("utf-8") for item in pickle.load(f)]
    with open("species.pickle", "rb") as f:
        species = [item.decode("utf-8") for item in pickle.load(f)]

    # ---- Layout ----
    with gr.Column():
        # Search controls: the sample ID, the modality of the query embedding
        # ("Search From") and the modality of the index searched ("Search To").
        process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
        process_id_list = gr.Textbox(label="Closest 10 matches:" )
        mod1 = gr.Radio(choices=["DNA", "Image"], label="Search From:")
        mod2 = gr.Radio(choices=["DNA", "Image"], label="Search To:")
        search_btn = gr.Button("Search")

        # Query sample on the left, random-ID helper on the right.
        with gr.Row():
            with gr.Column():
                image0 = gr.Image(label="Original", height=550)
                tax0 = gr.Textbox(label="Taxonomy")
            with gr.Column():
                rand_id = gr.Textbox(label="Random ID:")
                rand_id_indx = gr.Textbox(label="Index:")
                id_btn = gr.Button("Get Random ID")

        # Result grid: matches 1-10 in rank order, laid out in rows of
        # 3, 3 and 4; each match has an image and a taxonomy box.
        with gr.Row():
            with gr.Column():
                image1 = gr.Image(label=1)
                tax1 = gr.Textbox(label="Taxonomy")
            with gr.Column():
                image2 = gr.Image(label=2)
                tax2 = gr.Textbox(label="Taxonomy")
            with gr.Column():
                image3 = gr.Image(label=3)
                tax3 = gr.Textbox(label="Taxonomy")

        with gr.Row():
            with gr.Column():
                image4 = gr.Image(label=4)
                tax4 = gr.Textbox(label="Taxonomy")
            with gr.Column():
                image5 = gr.Image(label=5)
                tax5 = gr.Textbox(label="Taxonomy")
            with gr.Column():
                image6 = gr.Image(label=6)
                tax6 = gr.Textbox(label="Taxonomy")

        with gr.Row():   
            with gr.Column():
                image7 = gr.Image(label=7)
                tax7 = gr.Textbox(label="Taxonomy")
            with gr.Column():
                image8 = gr.Image(label=8)
                tax8 = gr.Textbox(label="Taxonomy")
            with gr.Column():
                image9 = gr.Image(label=9)
                tax9 = gr.Textbox(label="Taxonomy")
            with gr.Column():
                image10 = gr.Image(label=10)
                tax10 = gr.Textbox(label="Taxonomy")

    # ---- Event wiring ----
    # getRandID returns (sample ID, index), shown in the two "random" boxes.
    id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])
    # The outputs list must stay in the same order as the tuple returned by
    # searchEmbeddings: ID list, 11 images, then 11 taxonomy strings.
    search_btn.click(fn=searchEmbeddings, inputs=[process_id, mod1, mod2], 
                     outputs=[process_id_list, image0, image1, image2, image3, image4, image5, image6, image7, image8, image9, image10, 
                              tax0, tax1, tax2, tax3, tax4, tax5, tax6, tax7, tax8, tax9, tax10])
    # Example queries for the three search inputs, one per From/To combination.
    examples = gr.Examples(
        examples=[["ABOTH966-22", "DNA", "DNA"],
                  ["CRTOB8472-22", "DNA", "Image"],
                  ["PLOAD050-20", "Image", "DNA"],
                  ["HELAC26711-21", "Image", "Image"]],
        inputs=[process_id, mod1, mod2],)

# Start the app.
# NOTE(review): there is no `if __name__ == "__main__":` guard, so importing
# this module also launches the server.
demo.launch()