import random
import time
from copy import deepcopy

import torch
from torch.nn import CrossEntropyLoss
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

from llava_utils import prompt_wrapper, generator


class Attacker:

    def __init__(self, args, model, tokenizer, targets, device='cuda:0'):

        self.args = args
        self.model = model
        self.device = device
        self.tokenizer = tokenizer
        self.tokenizer.padding_side = "right"

        self.targets = targets  # target strings whose likelihood we want to promote
        self.loss_buffer = []
        self.num_targets = len(self.targets)

        # freeze the model and put it in eval mode
        self.model.eval()
        self.model.requires_grad_(False)

    def get_vocabulary(self):
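        # Build the set of vocabulary entries that re-tokenize to exactly one token, together
        # with their embeddings, so that hotflip_attack can score every candidate replacement
        # with a single matrix multiplication.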

        vocab_dicts = self.tokenizer.get_vocab()
        vocabs = vocab_dicts.keys()

        single_token_vocabs = []
        single_token_vocabs_embedding = []
        single_token_id_to_vocab = dict()
        single_token_vocab_to_id = dict()

        cnt = 0

        for item in vocabs:
            tokens = self.tokenizer(item, return_tensors="pt", add_special_tokens=False).input_ids.to(self.device)
            if tokens.shape[1] == 1:

                single_token_vocabs.append(item)
                emb = self.model.model.embed_tokens(tokens)
                single_token_vocabs_embedding.append(emb)

                single_token_id_to_vocab[cnt] = item
                single_token_vocab_to_id[item] = cnt

                cnt+=1

        single_token_vocabs_embedding = torch.cat(single_token_vocabs_embedding, dim=1).squeeze()
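        # shape: [num_single_token_vocabs, hidden_dim]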

        self.vocabs = single_token_vocabs
        self.embedding_matrix = single_token_vocabs_embedding.to(self.device)
        self.id_to_vocab = single_token_id_to_vocab
        self.vocab_to_id = single_token_vocab_to_id

    def hotflip_attack(self, grad, token,
                       increase_loss=False, num_candidates=1):

        token_id = self.vocab_to_id[token]
        token_emb = self.embedding_matrix[token_id] # embedding of current token
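        # First-order (HotFlip-style) approximation: replacing the current embedding e_old
        # with a candidate e changes the loss by roughly (e - e_old) · grad, so one matmul
        # against the single-token embedding matrix scores every candidate at once.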

        scores = ((self.embedding_matrix - token_emb) @ grad.T).squeeze(1)

        if not increase_loss:
            scores *= -1  # flip the sign so that top-k selects candidates that decrease the loss

        _, best_k_ids = torch.topk(scores, num_candidates)
        return best_k_ids.detach().cpu().numpy()

    def wrap_prompt_simple(self, text_prompt_template, adv_prompt, batch_size):

        text_prompts = text_prompt_template + ' ' + adv_prompt  # append the adversarial prompt to the template

        prompt = prompt_wrapper.Prompt(model=self.model, tokenizer = self.tokenizer, text_prompts=[text_prompts])

        prompt.context_embs[0] = prompt.context_embs[0].detach().requires_grad_(True)
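        # The list replication below reuses the same leaf tensor for every batch entry, so a
        # single backward() accumulates the gradients from all batch elements onto
        # prompt.context_embs[0].grad.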
        prompt.context_embs = prompt.context_embs * batch_size

        return prompt

    def update_adv_prompt(self, adv_prompt_tokens, idx, new_token):
        next_adv_prompt_tokens = deepcopy(adv_prompt_tokens)
        next_adv_prompt_tokens[idx] = new_token
        next_adv_prompt = ' '.join(next_adv_prompt_tokens)
        return next_adv_prompt_tokens, next_adv_prompt



    def attack(self, text_prompt_template, offset, batch_size = 8, num_iter=2000):
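        # Greedy coordinate-descent attack over a fixed-length trigger: for each trigger
        # position, back-propagate the objective on a random batch of targets to obtain the
        # gradient w.r.t. the context embeddings, let hotflip_attack propose candidate
        # replacement tokens, evaluate every candidate exactly, and keep the best one.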

        print('>>> batch_size: ', batch_size)
        
        my_generator = generator.Generator(model=self.model, tokenizer=self.tokenizer)

        self.get_vocabulary()
        vocabs, embedding_matrix = self.vocabs, self.embedding_matrix

        trigger_token_length = 16  # number of tokens in the adversarial trigger
        adv_prompt_tokens = random.sample(vocabs, trigger_token_length)
        adv_prompt = ' '.join(adv_prompt_tokens)
        print('vocab size: %d | initial adv_prompt: %s' % (len(vocabs), adv_prompt))

        st = time.time()

        for t in tqdm(range(num_iter+1)):

            for token_to_flip in range(0, trigger_token_length): # for each token in the trigger

                batch_targets = random.sample(self.targets, batch_size)
                prompt = self.wrap_prompt_simple(text_prompt_template, adv_prompt, batch_size)
                if t==0 and token_to_flip==0:
                    print(prompt.text_prompts)

                target_loss = -self.attack_loss(prompt, batch_targets)
                loss = target_loss # to minimize
                loss.backward()

                print('[adv_prompt]', adv_prompt)
                print("target_loss: %f" % (target_loss.item()))
                self.loss_buffer.append(target_loss.item())

                tokens_grad = prompt.context_embs[0].grad[:, token_to_flip+offset, :]  # gradient w.r.t. the embedding of the token being flipped
                candidates = self.hotflip_attack(tokens_grad, adv_prompt_tokens[token_to_flip],
                                            increase_loss=False, num_candidates=self.args.n_candidates)

                self.model.zero_grad()

                # try all the candidates and pick the best
                # comparing candidates does not require gradient computation
                with torch.no_grad():
                    curr_best_loss = float('inf')
                    curr_best_trigger_tokens = None
                    curr_best_trigger = None

                    for cand in candidates:
                        next_adv_prompt_tokens, next_adv_prompt = self.update_adv_prompt(adv_prompt_tokens,
                                                                    token_to_flip, self.id_to_vocab[cand])
                        prompt = self.wrap_prompt_simple(text_prompt_template, next_adv_prompt, batch_size)

                        next_target_loss = -self.attack_loss(prompt, batch_targets)
                        curr_loss = next_target_loss  # to minimize

                        if curr_loss < curr_best_loss:
                            curr_best_loss = curr_loss
                            curr_best_trigger_tokens = next_adv_prompt_tokens
                            curr_best_trigger = next_adv_prompt

                    # Update overall best if the best current candidate is better
                    if curr_best_loss < loss:
                        adv_prompt_tokens = curr_best_trigger_tokens
                        adv_prompt = curr_best_trigger
                print('(update: %f minutes)' % ((time.time() - st) / 60))

            self.plot_loss()

            print('######### Output - Iter = %d ##########' % t)
                
        return adv_prompt

    def plot_loss(self):

        sns.set_theme()

        num_iters = len(self.loss_buffer)

        num_iters = min(num_iters, 5000)

        x_ticks = list(range(0, num_iters))

        # Plot and label the training and validation loss values
        plt.plot(x_ticks, self.loss_buffer[:num_iters], label='Target Loss')

        # Add in a title and axes labels
        plt.title('Loss Plot')
        plt.xlabel('Iters')
        plt.ylabel('Loss')

        # Display the plot
        plt.legend(loc='best')
        plt.savefig('%s/loss_curve.png' % (self.args.save_dir))
        plt.clf()

        torch.save(self.loss_buffer, '%s/loss' % (self.args.save_dir))


    def attack_loss(self, prompts, targets):
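        # Assemble [bos][context][target][padding] sequences from the wrapped prompts and the
        # target strings, then return the cross-entropy computed over the target tokens only
        # (all other positions are labelled -100 and ignored by the loss).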

        context_embs = prompts.context_embs
        assert len(context_embs) == len(targets), "Unmatched batch size of prompts and targets: len(context_embs) is %d, len(targets) is %d" % (len(context_embs), len(targets))

        batch_size = len(targets)

        to_regress_tokens = self.tokenizer(
            targets,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=160,
            add_special_tokens=False
        ).to(self.device)
        to_regress_embs = self.model.model.embed_tokens(to_regress_tokens.input_ids)

        bos = torch.ones([1, 1],
                         dtype=to_regress_tokens.input_ids.dtype,
                         device=to_regress_tokens.input_ids.device) * self.tokenizer.bos_token_id
        bos_embs = self.model.model.embed_tokens(bos)

        pad = torch.ones([1, 1],
                         dtype=to_regress_tokens.input_ids.dtype,
                         device=to_regress_tokens.input_ids.device) * self.tokenizer.pad_token_id
        pad_embs = self.model.model.embed_tokens(pad)


        T = to_regress_tokens.input_ids.masked_fill(
            to_regress_tokens.input_ids == self.tokenizer.pad_token_id, -100
        )

        pos_padding = torch.argmin(T, dim=1)  # index of the first -100 (padding) position in each row
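        # e.g. T = [[312, 7, -100, -100]] -> torch.argmin(T, dim=1) == tensor([2]);
        # rows without any padding are handled by the T[i][pos] check below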

        input_embs = []
        targets_mask = []

        target_tokens_length = []
        context_tokens_length = []
        seq_tokens_length = []

        for i in range(batch_size):

            pos = int(pos_padding[i])
            if T[i][pos] == -100:
                target_length = pos
            else:
                target_length = T.shape[1]

            targets_mask.append(T[i:i+1, :target_length])
            input_embs.append(to_regress_embs[i:i+1, :target_length]) # omit the padding tokens

            context_length = context_embs[i].shape[1]
            seq_length = target_length + context_length

            target_tokens_length.append(target_length)
            context_tokens_length.append(context_length)
            seq_tokens_length.append(seq_length)

        max_length = max(seq_tokens_length)

        attention_mask = []
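        # Each assembled row has the layout [bos][context][target][padding]; the label rows
        # use -100 over the bos/context/padding positions so that only the target tokens
        # contribute to the cross-entropy loss.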

        for i in range(batch_size):

            # mask out the bos and context positions so they do not contribute to the loss
            context_mask = (
                torch.ones([1, context_tokens_length[i] + 1],
                       dtype=torch.long).to(self.device).fill_(-100)  # plus one for the bos token
            )

            # pad every sequence to the common max_length
            num_to_pad = max_length - seq_tokens_length[i]
            padding_mask = (
                torch.ones([1, num_to_pad],
                       dtype=torch.long).to(self.device).fill_(-100)
            )

            targets_mask[i] = torch.cat( [context_mask, targets_mask[i], padding_mask], dim=1 )
            input_embs[i] = torch.cat( [bos_embs, context_embs[i], input_embs[i],
                                        pad_embs.repeat(1, num_to_pad, 1)], dim=1 )
            attention_mask.append( torch.LongTensor( [[1]* (1+seq_tokens_length[i]) + [0]*num_to_pad ] ) )

        targets = torch.cat( targets_mask, dim=0 ).to(self.device)
        inputs_embs = torch.cat( input_embs, dim=0 ).to(self.device)
        attention_mask = torch.cat(attention_mask, dim=0).to(self.device)


        outputs = self.model.model(
                inputs_embeds=inputs_embs,
                attention_mask=attention_mask,
        )

        hidden_states = outputs[0]
        logits = self.model.lm_head(hidden_states)
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = targets[..., 1:].contiguous()
        # Flatten the tokens; positions labelled -100 are ignored by CrossEntropyLoss
        loss_fct = CrossEntropyLoss()
        shift_logits = shift_logits.view(-1, self.model.vocab_size)
        shift_labels = shift_labels.view(-1)
        shift_labels = shift_labels.to(self.device)  # keep labels on the same device as the logits
        loss = loss_fct(shift_logits, shift_labels)

        return loss
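

# ---------------------------------------------------------------------------
# Usage sketch (illustration only): how the Attacker class above might be driven
# end to end. `load_model`, `your_prompt_template`, and `trigger_offset` are
# hypothetical placeholders; only `args.n_candidates` and `args.save_dir` are
# actually read by the class.
#
#   from types import SimpleNamespace
#
#   args = SimpleNamespace(n_candidates=100, save_dir='./results')
#   model, tokenizer = load_model()                        # hypothetical loader
#   targets = open('targets.txt').read().splitlines()      # one target string per line
#
#   attacker = Attacker(args, model, tokenizer, targets, device='cuda:0')
#   adv_prompt = attacker.attack(
#       text_prompt_template=your_prompt_template,  # conversation template preceding the trigger
#       offset=trigger_offset,                      # index of the first trigger token in context_embs
#       batch_size=8,
#       num_iter=500,
#   )
#   print('final adversarial prompt:', adv_prompt)
# ---------------------------------------------------------------------------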