# -*- coding: utf-8 -*-
"""message_bottle.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1I47sLakpuwERGzn-XoNct67mwiDS1mQD
"""

import matplotlib.pyplot as plt
import matplotlib

import argparse
import glob
import logging
import os
import pickle
import random


import torch
import torch.nn.functional as F
import numpy as np

from tqdm import tqdm, trange
from types import SimpleNamespace

import sys
sys.path.append('/home/ryn_mote/Misc/generative_recommender/text_space/Optimus/code/examples/big_ae/')
sys.path.append('/home/ryn_mote/Misc/generative_recommender/text_space/Optimus/code/')
from pytorch_transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, BertConfig
from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2ForLatentConnector
from pytorch_transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
from pytorch_transformers import XLNetLMHeadModel, XLNetTokenizer
from pytorch_transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
from pytorch_transformers import BertForLatentConnector, BertTokenizer

from modules import VAE

import torch
import torch.nn as nn
import torch.nn.functional as F
torch.set_float32_matmul_precision('high')

from tqdm import tqdm

################################################



def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Mask a 1-D logits vector in place using top-k and/or nucleus (top-p) filtering.

    Args:
        logits: 1-D tensor of vocabulary logits. It is modified in place.
        top_k: when > 0, every logit strictly below the k-th largest one is
            overwritten (entries tied with the k-th value survive).
        top_p: when > 0.0, tokens are taken in descending order and kept up to
            and including the first one whose cumulative softmax probability
            exceeds ``top_p``; the rest are overwritten.
        filter_value: value written over every filtered-out entry.

    Returns:
        The same ``logits`` tensor after masking.

    Nucleus filtering is described in Holtzman et al.
    (http://arxiv.org/abs/1904.09751).
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    assert logits.dim() == 1  # one distribution at a time (batch size 1)

    # Never ask topk for more entries than the vocabulary holds.
    k = min(top_k, logits.size(-1))
    if k > 0:
        kth_best = torch.topk(logits, k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value

    if top_p > 0.0:
        ranked_logits, ranked_idx = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(F.softmax(ranked_logits, dim=-1), dim=-1)
        over_threshold = cum_probs > top_p

        # Shift the mask one slot to the right so the first token that crosses
        # the threshold is kept; slot 0 (the best token) is therefore never dropped.
        drop_ranked = torch.zeros_like(over_threshold)
        drop_ranked[..., 1:] = over_threshold[..., :-1]
        drop_ranked[..., 0] = 0

        logits[ranked_idx[drop_ranked]] = filter_value
    return logits

def sample_sequence_conditional(model, length, context, past=None, num_samples=1, temperature=1, top_k=0, top_p=0.0, device='cpu', decoder_tokenizer=None):
    """Sample tokens one at a time until ``<EOS>`` is drawn or ``length`` tokens exist.

    Args:
        model: callable invoked as ``model(input_ids=..., past=...)``; element 0
            of its output is indexed as ``[0, -1, :]`` to get next-token logits.
        length: upper bound on the number of sampled tokens. Without this bound
            the loop would never end if ``<EOS>`` were never drawn.
        context: sequence of token ids that seeds the generation.
        past: forwarded unchanged to ``model`` on every step.
        num_samples: number of times the context row is repeated. Only row 0 of
            the logits is sampled and a single token is concatenated per step,
            so values other than 1 fail at the concatenation.
        temperature: divisor applied to the logits before filtering.
        top_k: forwarded to ``top_k_top_p_filtering``.
        top_p: forwarded to ``top_k_top_p_filtering``.
        device: device on which the context tensor is created.
        decoder_tokenizer: tokenizer whose ``encode('<EOS>')[0]`` is the stop id.

    Returns:
        Tensor holding the context followed by the sampled token ids
        (including ``<EOS>`` when it was drawn).

    Raises:
        ValueError: if ``decoder_tokenizer`` is not provided.
    """
    if decoder_tokenizer is None:
        raise ValueError('decoder_tokenizer is required to detect the <EOS> token')
    # The stop id does not change between steps, so look it up once.
    eos_token_id = decoder_tokenizer.encode('<EOS>')[0]

    context = torch.tensor(context, dtype=torch.long, device=device)
    generated = context.unsqueeze(0).repeat(num_samples, 1)
    with torch.no_grad():
        for _ in range(length):
            # The whole sequence is re-fed each step; no hidden-state cache is used.
            outputs = model(input_ids=generated, past=past)
            next_token_logits = outputs[0][0, -1, :] / temperature
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)

            if next_token.item() == eos_token_id:
                break

    return generated


def latent_code_from_text(text):
    """Encode ``text`` into a latent vector with the VAE encoder.

    Reads the module-level ``tokenizer_encoder`` and ``model_vae`` and runs the
    encoder on 'cuda' without tracking gradients.

    Args:
        text: string to encode.

    Returns:
        ``(latent_z, coded_length)``: the mean half of the encoder's linear
        output, and the number of token ids fed to the encoder (the two wrapper
        ids included).
    """
    # 101 / 102 wrap the sequence -- presumably BERT's [CLS] / [SEP] ids; TODO confirm.
    token_ids = [101] + tokenizer_encoder.encode(text) + [102]
    with torch.no_grad():
        input_ids = torch.tensor([token_ids], dtype=torch.long).to('cuda')
        # Every id greater than 0 is attended to.
        attention = (input_ids > 0).float()
        pooled = model_vae.encoder(input_ids, attention_mask=attention)[1]
        # The linear output is split in two along the last axis; only the first
        # half (the mean) is used, the log-variance half is ignored.
        mean, _logvar = model_vae.encoder.linear(pooled).chunk(2, -1)
        return mean.squeeze(1), len(token_ids)

# args
def text_from_latent_code(latent_z):
    """Decode a latent vector into text with the VAE decoder.

    Reads the module-level ``tokenizer_decoder`` and ``model_vae``; sampling
    runs on 'cuda' with temperature 0.2, top_k 50 and top_p 0.98.

    Args:
        latent_z: latent tensor handed to the decoder as ``past``.

    Returns:
        The decoded string with its first and last whitespace-separated pieces
        removed.
    """
    bos_ids = tokenizer_decoder.encode('<BOS>')
    sampled = sample_sequence_conditional(
        model=model_vae.decoder,
        length=128,  # handed to the sampler as its `length` argument
        context=bos_ids,
        past=latent_z,
        temperature=.2,
        top_k=50,
        top_p=.98,
        device='cuda',
        decoder_tokenizer=tokenizer_decoder,
    )
    decoded = tokenizer_decoder.decode(sampled[0, :].tolist(), clean_up_tokenization_spaces=True)
    # Strip the leading and trailing pieces -- presumably the <BOS> / <EOS> markers.
    return ' '.join(decoded.split()[1:-1])


################################################
# Load model


# Maps a model family name to its (config class, model class, tokenizer class).
MODEL_CLASSES = {
    'gpt2': (GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer),
    'bert': (BertConfig, BertForLatentConnector, BertTokenizer)
}

# Size of the latent vector shared by the encoder and decoder connectors.
latent_size = 768
# Machine-specific checkpoint locations (full VAE state, encoder, decoder).
model_path = '/home/ryn_mote/Misc/generative_recommender/text_space/1.0_checkpoint-31250/checkpoint-31250/checkpoint-full-31250/'
encoder_path = '/home/ryn_mote/Misc/generative_recommender/text_space/1.0_checkpoint-31250/checkpoint-31250/checkpoint-encoder-31250/'
decoder_path = '/home/ryn_mote/Misc/generative_recommender/text_space/1.0_checkpoint-31250/checkpoint-31250/checkpoint-decoder-31250/'
# Clamped below against both tokenizers; not read again in the visible part of this file.
block_size = 100

# Load a trained Encoder model and vocabulary that you have fine-tuned
encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES['bert']
model_encoder = encoder_model_class.from_pretrained(encoder_path, latent_size=latent_size)
# NOTE(review): a cased vocabulary is loaded with do_lower_case=True -- confirm this
# matches how the checkpoint was trained.
tokenizer_encoder = encoder_tokenizer_class.from_pretrained('bert-base-cased', do_lower_case=True)

model_encoder.to('cuda')
# block_size starts at 100, so this fallback is not taken here.
if block_size <= 0:
    block_size = tokenizer_encoder.max_len_single_sentence  # Our input block size will be the max possible for the model
block_size = min(block_size, tokenizer_encoder.max_len_single_sentence)

# Load a trained Decoder model and vocabulary that you have fine-tuned
decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES['gpt2']
model_decoder = decoder_model_class.from_pretrained(decoder_path, latent_size=latent_size)
tokenizer_decoder = decoder_tokenizer_class.from_pretrained('gpt2', do_lower_case=False)
model_decoder.to('cuda')
# Same clamp as above, this time against the decoder tokenizer's limit.
if block_size <= 0:
    block_size = tokenizer_decoder.max_len_single_sentence  # Our input block size will be the max possible for the model
block_size = min(block_size, tokenizer_decoder.max_len_single_sentence)

# Load full model
# output_full_dir is not read again in the visible part of this file.
output_full_dir = '/home/ryn_mote/Misc/generative_recommender/text_space/'
# NOTE(review): torch.load is called without map_location and deserializes the
# file's contents -- only point this at trusted checkpoints.
checkpoint = torch.load(os.path.join(model_path, 'training.bin'))

# Chunyuan: Add Padding token to GPT2
# The decoding helpers above encode '<BOS>' and '<EOS>', so these tokens must be
# registered with the decoder tokenizer before any text is generated.
special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
print('We have added', num_added_toks, 'tokens to GPT2')
model_decoder.resize_token_embeddings(len(tokenizer_decoder))  # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
assert tokenizer_decoder.pad_token == '<PAD>'


# Evaluation
# Wrap encoder and decoder in the VAE, then restore the jointly trained weights
# from the 'model_state_dict' entry of the checkpoint.
model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, SimpleNamespace(**{'latent_size': latent_size, 'device':'cuda'}))
model_vae.load_state_dict(checkpoint['model_state_dict'])
print("Pre-trained Optimus is successfully loaded")
# The result is not rebound, so this relies on .to() updating the module in place.
model_vae.to('cuda').to(torch.bfloat16)

# Smoke test: encode one sentence, decode it back, and print the round trip.
l = latent_code_from_text('A photo of a mountain.')[0]
t = text_from_latent_code(l)
print(t, l, l.shape)
################################################

import gradio as gr
import numpy as np
from sklearn.svm import SVC
from sklearn.inspection import permutation_importance
from sklearn import preprocessing
import pandas as pd
import random
import time

 
# dtype that the SVM direction is cast to in next_one() before decoding.
dtype = torch.bfloat16
# Inference only: switch autograd off for everything that follows.
torch.set_grad_enabled(False)

# Unique string entries from the second column of the prompts CSV; non-string
# cells are skipped.
prompt_list = [
    p
    for p in set(pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())
    if type(p) is str
]

# Wall-clock timestamp taken once setup has finished.
start_time = time.time()

####################### Setup Model

# TODO put back
# @spaces.GPU()
def generate(prompt, in_embs=None,):
    """Decode a text from a prompt and/or a latent embedding.

    Args:
        prompt: text to encode. When non-empty its latent code is used on its
            own, or blended with ``in_embs`` as ``0.9 * in_embs + 0.5 * latent``.
            When ``''``, ``in_embs`` alone is decoded.
        in_embs: optional latent tensor to decode or blend with the prompt.

    Returns:
        ``(text, emb)``: the decoded text and the latent code obtained by
        re-encoding that text, moved to the CPU.

    Raises:
        ValueError: if ``prompt`` is empty and ``in_embs`` is ``None``.

    Side effects:
        Prints progress and overwrites 'real_im_emb_plot.jpg' with a histogram
        of the rescaled embedding on every call.
    """
    if prompt != '':
        print(prompt)
        prompt_latent = latent_code_from_text(prompt)[0]
        # `is not None` avoids relying on tensor-vs-None comparison semantics.
        if in_embs is not None:
            in_embs = .9 * in_embs.to('cuda') + .5 * prompt_latent
        else:
            in_embs = prompt_latent
    else:
        print('From embeds.')
        if in_embs is None:
            raise ValueError('generate() needs a non-empty prompt or in_embs')

    # Rescale so the largest absolute component is 0.6.
    # NOTE(review): an all-zero embedding would divide by zero here.
    in_embs = in_embs / in_embs.abs().max() * .6
    in_embs = in_embs.to('cuda').to(torch.bfloat16)

    # The embedding is cast to float on the CPU before it is handed to NumPy.
    plt.close('all')
    plt.hist(np.array(in_embs.detach().to('cpu').to(torch.float)).flatten(), bins=5)
    plt.savefig('real_im_emb_plot.jpg')

    text = text_from_latent_code(in_embs)
    # Return the embedding of the text actually shown, not the one requested.
    in_embs = latent_code_from_text(text)[0]
    print(text)
    return text, in_embs.to('cpu')


#######################

# TODO add to state instead of shared across all
# Call counter shared by every session (see the TODO above about moving it to state).
glob_idx = 0


def next_one(embs, ys, calibrate_prompts):
    """Produce the next text to show and update the rating history.

    While ``calibrate_prompts`` is non-empty, its first prompt is popped and
    decoded. Afterwards ("roaming") a linear SVM is fit on the rated embeddings
    and its coefficient vector is decoded through ``generate``; every third call
    a random entry of ``prompt_list`` is blended in.

    Args:
        embs: list of embedding tensors, one per text shown so far.
        ys: list of ratings (1 = liked, 0 = disliked) index-aligned with ``embs``.
        calibrate_prompts: list of calibration prompts still to be shown.

    Returns:
        ``(text, embs, ys, calibrate_prompts)``. ``embs`` ends with the
        embedding of the new text, which has no rating in ``ys`` yet.
    """
    global glob_idx
    glob_idx = glob_idx + 1

    with torch.no_grad():
        if len(calibrate_prompts) > 0:
            print('######### Calibrating with sample prompts #########')
            prompt = calibrate_prompts.pop(0)
            text, img_embs = generate(prompt)
            # Extending a list with a tensor appends one entry per leading row.
            embs += img_embs
            print(len(embs))
            return text, embs, ys, calibrate_prompts

        print('######### Roaming #########')

        # Ratings and embeddings are paired by position. If they ever get out
        # of step (the original comments mention a rare multi-threading issue),
        # keep only the paired prefix so the indexing below cannot fail. This
        # check has to run before the lists are rebuilt: afterwards both always
        # have the same length.
        if len(ys) != len(embs):
            print('embs and ys differ in length; truncating both to the shorter one')
            paired = min(len(ys), len(embs))
            del ys[paired:]
            del embs[paired:]

        # handle case where every instance of calibration prompts is 'Neither' or 'Like' or 'Dislike'
        if len(set(ys)) <= 1:
            embs.append(.01*torch.randn(latent_size))
            embs.append(.01*torch.randn(latent_size))
            ys.append(0)
            ys.append(1)
        if len(ys) < 10:
            # Pad small histories with low-magnitude noise rated as disliked
            # (one noise tensor, repeated three times).
            embs += [.01*torch.randn(latent_size)] * 3
            ys += [0] * 3

        pos_indices = [i for i, y in enumerate(ys) if y == 1]
        neg_indices = [i for i, y in enumerate(ys) if y == 0]

        # the embs & ys stay tied by index but we shuffle to drop randomly
        random.shuffle(pos_indices)
        random.shuffle(neg_indices)

        # NOTE(review): both branches test "negatives outnumber positives"; the
        # first one then drops positives. This may be a copy-paste slip for
        # len(pos_indices) - len(neg_indices) -- left unchanged, please confirm.
        if len(neg_indices) - len(pos_indices) > 48/16 and len(pos_indices) > 6:
            pos_indices = pos_indices[5:]
        if len(neg_indices) - len(pos_indices) > 48/16 and len(neg_indices) > 6:
            neg_indices = neg_indices[5:]

        if len(neg_indices) > 25:
            neg_indices = neg_indices[1:]

        print(len(pos_indices), len(neg_indices))
        kept = pos_indices + neg_indices

        embs = [embs[i] for i in kept]
        ys = [ys[i] for i in kept]

        indices = list(range(len(embs)))

        # Also add the last 0 and the last 1 a second time so they weigh double.
        # After the shuffle above, "last" means last in the rebuilt order.
        has_0 = False
        has_1 = False
        for i in reversed(range(len(ys))):
            if ys[i] == 0 and not has_0:
                indices.append(i)
                has_0 = True
            elif ys[i] == 1 and not has_1:
                indices.append(i)
                has_1 = True
            if has_0 and has_1:
                break

        # Cast every embedding to float32 before stacking: the set can mix
        # dtypes, and a bfloat16 result cannot be handed to NumPy.
        feature_embs = torch.stack(
            [embs[i].to(device='cpu', dtype=torch.float32) for i in indices]
        ).numpy()
        scaler = preprocessing.StandardScaler().fit(feature_embs)
        feature_embs = scaler.transform(feature_embs)
        chosen_y = np.array([ys[i] for i in indices])

        print('Gathering coefficients')
        lin_class = SVC(max_iter=50000, kernel='linear', class_weight='balanced', C=.1).fit(feature_embs, chosen_y)
        coef_ = torch.tensor(lin_class.coef_, dtype=torch.double)
        print(coef_.shape, 'COEF')
        print('Gathered')

        # Drawn on every call, used on every third one.
        rng_prompt = random.choice(prompt_list)
        im_emb = coef_.to(dtype=dtype)

        prompt = '' if glob_idx % 3 != 0 else rng_prompt
        text, im_emb = generate(prompt, im_emb)
        embs += im_emb

        return text, embs, ys, calibrate_prompts









def start(_, embs, ys, calibrate_prompts):
    """Handle the Start button: produce the first text and enable the rating buttons.

    Args:
        _: value of the Start button (unused).
        embs: embedding history passed on to ``next_one``.
        ys: rating history passed on to ``next_one``.
        calibrate_prompts: remaining calibration prompts passed on to ``next_one``.

    Returns:
        A list of four button updates (Like, Neither and Dislike enabled, Start
        disabled) followed by the new text, ``embs``, ``ys`` and
        ``calibrate_prompts``.
    """
    first_text, embs, ys, calibrate_prompts = next_one(embs, ys, calibrate_prompts)
    button_specs = (
        ('Like (L)', True),
        ('Neither (Space)', True),
        ('Dislike (A)', True),
        ('Start', False),
    )
    button_updates = [
        gr.Button(value=label, interactive=enabled)
        for label, enabled in button_specs
    ]
    return button_updates + [first_text, embs, ys, calibrate_prompts]


def choose(text, choice, embs, ys, calibrate_prompts):
    """Record the user's rating of the current text and fetch the next one.

    Args:
        text: the text that was just rated.
        choice: label of the clicked button. 'Like (L)' is stored as 1,
            'Neither (Space)' stores nothing, anything else is stored as 0.
        embs: embedding history; its last entry belongs to ``text``.
        ys: rating history; the new rating is appended in place.
        calibrate_prompts: remaining calibration prompts.

    Returns:
        The ``(text, embs, ys, calibrate_prompts)`` tuple from ``next_one``.
    """
    if choice == 'Neither (Space)':
        # No rating is stored, so drop the embedding of the text just shown
        # to keep embs and ys paired.
        embs = embs[:-1]
        return next_one(embs, ys, calibrate_prompts)

    rating = 1 if choice == 'Like (L)' else 0

    # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
    # TODO skip allowing rating
    # NOTE(review): nothing in the visible file sets text to None -- confirm this path is still live.
    if text is None:
        print('NSFW -- choice is disliked')
        rating = 0

    ys.append(rating)
    return next_one(embs, ys, calibrate_prompts)

# Stylesheet passed to gr.Blocks(css=...): caps the container width, centers the
# element with id 'description', and defines the 'fade-in-out' animation class
# that the script in js_head adds to a clicked rating button.
css = '''.gradio-container{max-width: 700px !important}
#description{text-align: center}
#description h1, #description h3{display: block}
#description p{margin-top: 0}
.fade-in-out {animation: fadeInOut 3s forwards}
@keyframes fadeInOut {
    0% {
      background: var(--bg-color);
    }
    100% {
      background: var(--button-secondary-background-fill);
    }
}
'''
# Markup passed to gr.Blocks(head=...). The script maps the A / Space / L keys to
# clicks on the elements with ids 'dislike' / 'neither' / 'like', and flashes a
# colour on whichever of those elements is clicked. The ids must match the
# elem_id values given to the rating buttons in the layout.
js_head = '''
<script>
document.addEventListener('keydown', function(event) {
    if (event.key === 'a' || event.key === 'A') {
        // Trigger click on 'dislike' if 'A' is pressed
        document.getElementById('dislike').click();
    } else if (event.key === ' ' || event.keyCode === 32) {
        // Trigger click on 'neither' if Spacebar is pressed
        document.getElementById('neither').click();
    } else if (event.key === 'l' || event.key === 'L') {
        // Trigger click on 'like' if 'L' is pressed
        document.getElementById('like').click();
    }
});
function fadeInOut(button, color) {
  button.style.setProperty('--bg-color', color);
  button.classList.remove('fade-in-out');
  void button.offsetWidth; // This line forces a repaint by accessing a DOM property

  button.classList.add('fade-in-out');
  button.addEventListener('animationend', () => {
    button.classList.remove('fade-in-out'); // Reset the animation state
  }, {once: true});
}
document.body.addEventListener('click', function(event) {
    const target = event.target;
    if (target.id === 'dislike') {
      fadeInOut(target, '#ff1717');
    } else if (target.id === 'like') {
      fadeInOut(target, '#006500');
    } else if (target.id === 'neither') {
      fadeInOut(target, '#cccccc');
    }
});

</script>
'''

# UI layout: a read-only text box, three rating buttons, a Start button and a
# footer. The rating buttons start disabled; start() enables them.
with gr.Blocks(css=css, head=js_head) as demo:
    gr.Markdown('''# Compass
### Generative Recommenders for Exploration of Text

Explore the latent space without prompting based on your preferences. Learn more on [the write-up](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/).
    ''', elem_id="description")
    # State threaded through every callback: shown embeddings, their ratings,
    # and the calibration prompts that have not been shown yet.
    embs = gr.State([])
    ys = gr.State([])
    calibrate_prompts = gr.State([
    'the moon is melting into my glass of tea',
    'a sea slug -- pair of claws scuttling -- jelly fish glowing',
    'an adorable creature. It may be a goblin or a pig or a slug.',
    'an animation about a gorgeous nebula',
    'a sketch of an impressive mountain by da vinci',
    'a watercolor painting: the octopus writhes',
    ])
    # Not referenced in the visible part of this file; kept so the module-level
    # name `l` stays bound as before.
    def l():
        return None

    with gr.Row(elem_id='output-image'):
        text = gr.Textbox(interactive=False, elem_id="text")
    with gr.Row(equal_height=True):
        # The elem_id values are the ids the keyboard/click script in js_head looks up.
        b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
        b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither")
        b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like")
        # Each rating button passes its own label to choose() as `choice`.
        for rating_button in (b1, b2, b3):
            rating_button.click(
                choose,
                [text, rating_button, embs, ys, calibrate_prompts],
                [text, embs, ys, calibrate_prompts],
            )
    with gr.Row():
        b4 = gr.Button(value='Start')
        b4.click(start,
                 [b4, embs, ys, calibrate_prompts],
                 [b1, b2, b3, b4, text, embs, ys, calibrate_prompts])
    with gr.Row():
        # The closing tags were written as '</ div>', which is not a valid end
        # tag and left every div unclosed.
        html = gr.HTML('''<div style='text-align:center; font-size:20px'>You will calibrate for several prompts and then roam. </div><br><br><br>
<div style='text-align:center; font-size:14px'>Note that while the model is unlikely to produce NSFW text, this may still occur, and users should avoid NSFW content when rating.
</div>
<br><br>
<div style='text-align:center; font-size:14px'>Thanks to @multimodalart for their contributions to the demo, esp. the interface and @maxbittker for feedback.
</div>''')

demo.launch(share=True)