# Importing all the necessary libraries
import os

import gradio as gr
import torch
from PIL import Image
from tqdm import tqdm
from trclip.trclip import Trclip
from trclip.visualizer import image_retrieval_visualize, text_retrieval_visualize

print(f'gr version : {gr.__version__}')
import pickle
import random

import numpy as np

# %%
# Name of the Hugging Face model repository; it doubles as the local checkout directory.
model_name = 'trclip-vitl14-e10'
if not os.path.exists(model_name):
    # First run only: clone the model repository next to this script.
    os.system(f'git clone  https://huggingface.co/yusufani/{model_name}  --progress')
# %%
if not os.path.exists('TrCaption-trclip-vitl14-e10'):
    # First run only: clone the embeddings dataset, then fetch its LFS-tracked files
    # from inside the checkout before returning to the parent directory.
    os.system('git clone  https://huggingface.co/datasets/yusufani/TrCaption-trclip-vitl14-e10/ --progress')
    os.chdir('TrCaption-trclip-vitl14-e10')
    for _lfs_command in ('git lfs install', ' git lfs fetch', '  git lfs pull'):
        os.system(_lfs_command)
    os.chdir('..')


# %%

def load_image_embeddings(load_batch=True):
    """Load the pre-computed TrCaption image embeddings from disk.

    The embeddings are stored as pickle shards named ``image_em_<offset>.pkl``
    inside ``TrCaption-trclip-vitl14-e10/image_embeddings``, with offsets
    ``0, 100_000, ...`` up to (but excluding) ``3_100_000``.

    Args:
        load_batch: When true (the default), return a lazy iterator that
            unpickles and yields one shard at a time, so only a single shard
            has to be held in memory by the loader. When false, load every
            shard and return them concatenated along dim 0.

    Returns:
        A generator of unpickled shards when ``load_batch`` is true, otherwise
        the result of ``torch.cat`` over all shards.
    """
    path = os.path.join('TrCaption-trclip-vitl14-e10', 'image_embeddings')
    bs = 100_000

    def _iter_shards():
        # The ``yield`` lives in this nested function on purpose: a ``yield``
        # directly in load_image_embeddings turns *every* call into a generator,
        # which silently discarded the concatenated tensor of the
        # ``load_batch=False`` path.
        for i in tqdm(range(0, 3_100_000, bs), desc='Loading TrCaption Image embeddings'):
            with open(os.path.join(path, f'image_em_{i}.pkl'), 'rb') as f:
                # NOTE: pickle must only be used on trusted files; these shards
                # come from the dataset repository cloned by this script.
                yield pickle.load(f)

    if load_batch:
        return _iter_shards()
    return torch.cat(list(_iter_shards()), dim=0)


def load_text_embeddings(load_batch=True):
    """Load the pre-computed TrCaption text embeddings from disk.

    The embeddings are stored as pickle shards named ``text_em_<offset>.pkl``
    inside ``TrCaption-trclip-vitl14-e10/text_embeddings``, with offsets
    ``0, 100_000, ...`` up to (but excluding) ``3_600_000``.

    Args:
        load_batch: When true (the default), return a lazy iterator that
            unpickles and yields one shard at a time, so only a single shard
            has to be held in memory by the loader. When false, load every
            shard and return them concatenated along dim 0.

    Returns:
        A generator of unpickled shards when ``load_batch`` is true, otherwise
        the result of ``torch.cat`` over all shards.
    """
    path = os.path.join('TrCaption-trclip-vitl14-e10', 'text_embeddings')
    bs = 100_000

    def _iter_shards():
        # The ``yield`` lives in this nested function on purpose: a ``yield``
        # directly in load_text_embeddings turns *every* call into a generator,
        # which silently discarded the concatenated tensor of the
        # ``load_batch=False`` path.
        for i in tqdm(range(0, 3_600_000, bs), desc='Loading TrCaption text embeddings'):
            with open(os.path.join(path, f'text_em_{i}.pkl'), 'rb') as f:
                # NOTE: pickle must only be used on trusted files; these shards
                # come from the dataset repository cloned by this script.
                yield pickle.load(f)

    if load_batch:
        return _iter_shards()
    return torch.cat(list(_iter_shards()), dim=0)


def load_metadata():
    """Read the TrCaption metadata pickle.

    Returns:
        A ``(texts, image_urls)`` tuple taken from the ``'texts'`` and
        ``'image_urls'`` entries of
        ``TrCaption-trclip-vitl14-e10/metadata.pkl``.
    """
    metadata_file = os.path.join('TrCaption-trclip-vitl14-e10', 'metadata.pkl')
    with open(metadata_file, 'rb') as handle:
        payload = pickle.load(handle)
    return payload['texts'], payload['image_urls']


def load_spesific_tensor(index, type, bs=100_000):
    """Fetch a single pre-computed embedding by its global index.

    Args:
        index: Global position of the wanted embedding across all shards.
        type: Embedding kind used to build the directory and file names
            (callers in this file pass ``'text'`` or ``'image'``). The name
            shadows the builtin but is kept for interface compatibility.
        bs: Number of embeddings per shard file.

    Returns:
        The entry at the matching row of the unpickled shard.
    """
    # Split the global index into (shard number, row inside that shard).
    shard_no, row = divmod(index, bs)
    shard_file = os.path.join(
        'TrCaption-trclip-vitl14-e10',
        f'{type}_embeddings',
        f'{type}_em_{shard_no * bs}.pkl',
    )
    with open(shard_file, 'rb') as handle:
        shard = pickle.load(handle)
    return shard[row]


# %%
# Caption texts and image URLs of TrCaption; run_im/run_text index them with the
# same positions they use for the embedding shards.
trcap_texts, trcap_urls = load_metadata()
# %%
print('INFO : Model loading')
# The checkpoint lives inside the cloned model repository.
model_path = os.path.join(model_name, 'pytorch_model.bin')
trclip = Trclip(model_path, device='cpu', clip_model='ViT-L/14')
# %%


import datetime

# %%
def run_im(im1, use_trcap_images, text1, use_trcap_texts):
    """Image retrieval: rank images against (at most two) query texts.

    Args:
        im1: Files uploaded through the UI (objects exposing a ``.name``
            path). Only read when ``use_trcap_images`` is false.
        use_trcap_images: Rank the pre-computed TrCaption image embeddings
            instead of the uploaded files.
        text1: Newline-separated query texts. Only read when
            ``use_trcap_texts`` is false; the first two non-blank lines are
            used.
        use_trcap_texts: Use two randomly sampled TrCaption captions (and
            their stored embeddings) as queries instead of ``text1``.

    Returns:
        The value returned by ``image_retrieval_visualize`` for the ranked
        results.

    Raises:
        ValueError: If own images/texts are requested but none were supplied.
    """
    print(f'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")} INFO :  Image retrieval starting')

    # Validate user-supplied inputs up front so a missing upload or empty text
    # box produces a clear message instead of a TypeError/AttributeError
    # somewhere in the middle of the pipeline.
    if not use_trcap_images and not im1:
        raise ValueError('No images were uploaded: upload at least one image or enable "Use TRCaption Images".')
    texts = None
    if not use_trcap_texts:
        stripped_lines = [line.strip() for line in (text1 or '').split('\n')]
        # Drop blank lines *before* applying the 2-text UI limit so that empty
        # lines do not use up the quota.
        texts = [line for line in stripped_lines if line][:2]
        if not texts:
            raise ValueError('No text was entered: type at least one line or enable "Use TrCaption Captions".')

    f_texts_embeddings = None
    ims = None
    if use_trcap_images:
        print('INFO : TRCaption images used')
        im_paths = trcap_urls
    else:
        print('INFO : Own images used')
        # Images taken from user
        im_paths = [i.name for i in im1]
        ims = [Image.open(p) for p in im_paths]

    if use_trcap_texts:
        print('INFO : TRCaption texts used')
        # At most 2 texts are allowed in image retrieval (UI limit).
        random_indexes = random.sample(range(len(trcap_texts)), 2)
        f_texts_embeddings = torch.stack([load_spesific_tensor(i, 'text') for i in random_indexes])
        texts = [trcap_texts[i] for i in random_indexes]
    else:
        print('INFO : Own texts used')

    if use_trcap_images:
        # Iterate over the image embeddings shard by shard: the Hugging Face
        # Space has a 16 GB memory limit, so they are never loaded all at once.
        if not use_trcap_texts:
            f_texts_embeddings = trclip.get_text_features(texts)
        per_mode_probs = []
        for f_image_embeddings in tqdm(load_image_embeddings(load_batch=True), desc='Running image retrieval'):
            batch_probs = trclip.get_results(
                text_features=f_texts_embeddings, image_features=f_image_embeddings, mode='per_text', return_probs=True)
            per_mode_probs.append(batch_probs)
        # Shard results are joined along dim 1, i.e. one row per query text is
        # assumed — TODO confirm against Trclip.get_results.
        per_mode_probs = torch.cat(per_mode_probs, dim=1)
        # NOTE(review): softmax is applied to values requested with
        # return_probs=True; confirm they are unnormalised scores.
        per_mode_probs = per_mode_probs.softmax(dim=-1).cpu().detach().numpy()
        # Best match first for every query.
        per_mode_indices = [np.argsort(prob)[::-1] for prob in per_mode_probs]
    else:
        per_mode_indices, per_mode_probs = trclip.get_results(texts=texts, images=ims, text_features=f_texts_embeddings, mode='per_text')

    print(f'per_mode_indices = {per_mode_indices}\n,per_mode_probs = {per_mode_probs}  ')
    # Log only a preview: with TrCaption images im_paths is the full URL list
    # of the dataset, far too large to dump on every request.
    print(f'im_paths    = {im_paths[:5]} ({len(im_paths)} total)')
    return image_retrieval_visualize(per_mode_indices, per_mode_probs, texts, im_paths,
                                     n_figure_in_column=2,
                                     n_images_in_figure=4, n_figure_in_row=1, save_fig=False,
                                     show=False,
                                     break_on_index=-1)


def run_text(im1, use_trcap_images, text1, use_trcap_texts):
    """Text retrieval: rank texts against (at most two) query images.

    Args:
        im1: Files uploaded through the UI (objects exposing a ``.name``
            path). Only read when ``use_trcap_images`` is false; the first two
            files are used.
        use_trcap_images: Use two randomly sampled TrCaption images (via their
            stored embeddings) as queries instead of the uploaded files.
        text1: Newline-separated candidate texts. Only read when
            ``use_trcap_texts`` is false.
        use_trcap_texts: Rank the pre-computed TrCaption text embeddings
            instead of ``text1``.

    Returns:
        The value returned by ``text_retrieval_visualize`` for the ranked
        results.

    Raises:
        ValueError: If own images/texts are requested but none were supplied.
    """
    print(f'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")} INFO :  Text retrieval starting')

    # Validate user-supplied inputs up front so a missing upload or empty text
    # box produces a clear message instead of a TypeError/AttributeError
    # somewhere in the middle of the pipeline.
    if not use_trcap_images and not im1:
        raise ValueError('No images were uploaded: upload at least one image or enable "Use TRCaption Images".')
    if use_trcap_texts:
        texts = trcap_texts
    else:
        stripped_lines = [line.strip() for line in (text1 or '').split('\n')]
        # Drop blank lines before applying the limit so they do not use up the quota.
        # NOTE(review): the UI label only announces the 2-text limit for image
        # retrieval; confirm whether it is intended for text retrieval as well.
        texts = [line for line in stripped_lines if line][:2]
        if not texts:
            raise ValueError('No text was entered: type at least one line or enable "Use TrCaption Captions".')

    f_image_embeddings = None
    ims = None
    if use_trcap_images:
        print('INFO : TRCaption images used')
        # Two random TrCaption images act as the queries in text retrieval.
        random_indexes = random.sample(range(len(trcap_urls)), 2)
        f_image_embeddings = torch.stack([load_spesific_tensor(i, 'image') for i in random_indexes])
        print(f'f_image_embeddings = {f_image_embeddings}')
        # Images taken from TRCAPTION
        im_paths = [trcap_urls[i] for i in random_indexes]
        print(f'im_paths = {im_paths}')
    else:
        print('INFO : Own images used')
        # Images taken from user (only the first two are used)
        im_paths = [i.name for i in im1[:2]]
        ims = [Image.open(p) for p in im_paths]

    if use_trcap_texts:
        # The text embeddings are processed shard by shard instead of being
        # loaded all at once.
        if not use_trcap_images:
            f_image_embeddings = trclip.get_image_features(ims)
        per_mode_probs = []
        for f_texts_embeddings in tqdm(load_text_embeddings(load_batch=True), desc='Running text retrieval'):
            batch_probs = trclip.get_results(
                text_features=f_texts_embeddings, image_features=f_image_embeddings, mode='per_image', return_probs=True)
            per_mode_probs.append(batch_probs)
        # Shard results are joined along dim 1, i.e. one row per query image is
        # assumed — TODO confirm against Trclip.get_results.
        per_mode_probs = torch.cat(per_mode_probs, dim=1)
        # NOTE(review): softmax is applied to values requested with
        # return_probs=True; confirm they are unnormalised scores.
        per_mode_probs = per_mode_probs.softmax(dim=-1).cpu().detach().numpy()
        # Best match first for every query.
        per_mode_indices = [np.argsort(prob)[::-1] for prob in per_mode_probs]
    else:
        per_mode_indices, per_mode_probs = trclip.get_results(texts=texts, images=ims, image_features=f_image_embeddings, mode='per_image')

    print(per_mode_indices)
    print(per_mode_probs)
    return text_retrieval_visualize(per_mode_indices, per_mode_probs, im_paths, texts,
                                    n_figure_in_column=4,
                                    # Never ask for more texts per figure than exist.
                                    n_texts_in_figure=min(4, len(texts)),
                                    n_figure_in_row=2,
                                    save_fig=False,
                                    show=False,
                                    break_on_index=-1,
                                    )


def change_textbox(choice):
    """Build a ``gr.Image`` update toggling visibility from a choice value.

    The component is made visible only when ``choice`` equals
    ``"Use Own Images"``; any other value hides it.
    """
    show_image = choice == "Use Own Images"
    return gr.Image.update(visible=show_image)


# Gradio UI: a static HTML header, the image/text inputs with their
# "use TrCaption" switches, one button per retrieval direction and a single
# image output shared by both buttons.
with gr.Blocks() as demo:
    # Header. The title link was previously malformed (an unterminated "</a"
    # with the link text outside the element, which also swallowed the closing
    # </h1>); the anchor now wraps its text and is closed properly.
    gr.HTML("""
            <div style="text-align: center; max-width: 650px; margin: 0 auto;">
              <div
                style="
                  display: inline-flex;
                  align-items: center;
                  gap: 0.8rem;
                  font-size: 1.75rem;
                "
              >
                <svg
                  width="0.65em"
                  height="0.65em"
                  viewBox="0 0 115 115"
                  fill="none"
                  xmlns="http://www.w3.org/2000/svg"
                >
                  <rect width="23" height="23" fill="white"></rect>
                  <rect y="69" width="23" height="23" fill="white"></rect>
                  <rect x="23" width="23" height="23" fill="#AEAEAE"></rect>
                  <rect x="23" y="69" width="23" height="23" fill="#AEAEAE"></rect>
                  <rect x="46" width="23" height="23" fill="white"></rect>
                  <rect x="46" y="69" width="23" height="23" fill="white"></rect>
                  <rect x="69" width="23" height="23" fill="black"></rect>
                  <rect x="69" y="69" width="23" height="23" fill="black"></rect>
                  <rect x="92" width="23" height="23" fill="#D9D9D9"></rect>
                  <rect x="92" y="69" width="23" height="23" fill="#AEAEAE"></rect>
                  <rect x="115" y="46" width="23" height="23" fill="white"></rect>
                  <rect x="115" y="115" width="23" height="23" fill="white"></rect>
                  <rect x="115" y="69" width="23" height="23" fill="#D9D9D9"></rect>
                  <rect x="92" y="46" width="23" height="23" fill="#AEAEAE"></rect>
                  <rect x="92" y="115" width="23" height="23" fill="#AEAEAE"></rect>
                  <rect x="92" y="69" width="23" height="23" fill="white"></rect>
                  <rect x="69" y="46" width="23" height="23" fill="white"></rect>
                  <rect x="69" y="115" width="23" height="23" fill="white"></rect>
                  <rect x="69" y="69" width="23" height="23" fill="#D9D9D9"></rect>
                  <rect x="46" y="46" width="23" height="23" fill="black"></rect>
                  <rect x="46" y="115" width="23" height="23" fill="black"></rect>
                  <rect x="46" y="69" width="23" height="23" fill="black"></rect>
                  <rect x="23" y="46" width="23" height="23" fill="#D9D9D9"></rect>
                  <rect x="23" y="115" width="23" height="23" fill="#AEAEAE"></rect>
                  <rect x="23" y="69" width="23" height="23" fill="black"></rect>
                </svg>
                <h1 style="font-weight: 1500; margin-bottom: 7px;">
                  Trclip Demo
                  <a
                  href="https://github.com/yusufani/TrCLIP"
                  style="text-decoration: underline;"
                  target="_blank"
                  >Github Trclip</a>
                </h1>
              </div>
              <p style="margin-bottom: 10px; font-size: 94%">
                Trclip is Turkish port of real clip. In this space you can try your images or/and texts. 
                <br>Also you can use pre calculated TrCaption embeddings. 
                <br>Number of texts  = 3533312
                <br>Number of images =  3070976
                <br>
                Some images are not available in the internet because I downloaded and calculated TrCaption embeddings long time ago. Don't be suprise if you encounter with Image not found :D
                
                <div style="text-align: center;font-size: 100%">
                <p><strong><span style="background-color: #000000; color: #ffffff;"><a style="background-color: #000000; color: #ffffff;" href="https://github.com/yusufani/TrCLIP">A GitHub Repository</a> </span>--- <span style="background-color: #000000;"><span style="color: #ffffff;">Paper( Not available yet )&nbsp;</span></span></strong></p>
                </div>                
              </p>

            </div>
            <div style="text-align: center; margin: 0 auto;">
                 <p style="margin-bottom: 10px; font-size: 75%" ><em>Huggingface Space containers has 16 gb ram. TrCaption embeddings are totaly 20 gb. </em><em>I did a lot of writing and reading to files to make this space workable. That's why<span style="background-color: #ff6600; color: #ffffff;"> <strong>it's running much slower if you're using TrCaption Embeddig</strong>s</span>.</em></p>
                <div class="sc-jSFjdj sc-iCoGMd jcTaHb kMthTr">
                    <div class="sc-iqAclL xfxEN">
                        <div class="sc-bdnxRM fJdnBK sc-crzoAE DykGo">
                          <div class="sc-gtsrHT gfuSqG">&nbsp;</div>
                        </div>
                    </div>
                </div>
            </div>
        """)

    # Image side: optional multi-file upload plus the switch that selects the
    # TrCaption images instead.
    with gr.Tabs():
        with gr.TabItem("Upload a Images"):
            im_input = gr.components.File(label="Image input", optional=True, file_count='multiple')
    is_trcap_ims = gr.Checkbox(label="Use TRCaption Images\n[Note: Random 2 sample selected in text retrieval mode]", default=True)

    # Text side: newline-separated texts plus the switch that selects the
    # TrCaption captions instead.
    with gr.Tabs():
        with gr.TabItem("Input a text (Seperated by new line Max 2 for Image retrieval)"):
            text_input = gr.components.Textbox(label="Text input", optional=True, placeholder="kedi\nköpek\nGemi\nKahvesini içmekte olan bir adam\n Kahvesini içmekte olan bir kadın\nAraba")
    is_trcap_texts = gr.Checkbox(label="Use TrCaption Captions \n[Note: Random 2 sample selected in image retrieval mode]", default=True)

    im_ret_but = gr.Button("Image Retrieval")
    text_ret_but = gr.Button("Text Retrieval")

    # Both retrieval modes render their result into the same image component.
    im_out = gr.components.Image()

    # Both callbacks receive the same four inputs in the same order.
    im_ret_but.click(run_im, inputs=[im_input, is_trcap_ims, text_input, is_trcap_texts], outputs=im_out)
    text_ret_but.click(run_text, inputs=[im_input, is_trcap_ims, text_input, is_trcap_texts], outputs=im_out)

demo.launch()

# %%