#coding=utf-8
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_
from contextlib import suppress
import logging
from einops import rearrange
from peft import LoraConfig, get_peft_model
from bigmodelvis import Visualization

from .clip_encoder_hd import CLIPVisionTowerHD
from .conversation import get_conv_template
from .processors_conv import preprocess_qwen
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
from transformers.generation import GenerationConfig
from transformers import Qwen2Config, Qwen2ForCausalLM


def get_autocast(precision, cache_enabled=True):
    """Return a zero-argument factory producing the autocast context for ``precision``.

    Call the returned object to obtain a fresh context manager, e.g.
    ``with get_autocast('bf16')(): ...``.

    Args:
        precision: ``'amp_bfloat16'``, ``'amp_bf16'`` or ``'bf16'`` (CUDA autocast in
            bfloat16), ``'fp16'`` (CUDA autocast in float16) or ``'fp32'`` (no autocast).
        cache_enabled: forwarded to ``torch.autocast``.

    Returns:
        A callable taking no arguments that returns a context manager.

    Raises:
        ValueError: if ``precision`` is not one of the values above.
    """
    if precision in ("amp_bfloat16", "amp_bf16", "bf16"):
        # amp_bfloat16 is more stable than amp float16 for clip training
        dtype = torch.bfloat16
    elif precision == 'fp16':
        dtype = torch.float16
    elif precision == 'fp32':
        # contextlib.suppress() with no exception types is a no-op context manager.
        return suppress
    else:
        raise ValueError('not supported precision: {}'.format(precision))
    # torch.autocast(device_type="cuda") is the non-deprecated equivalent of
    # torch.cuda.amp.autocast (which is a thin subclass of it).
    return lambda: torch.autocast(device_type="cuda", dtype=dtype, cache_enabled=cache_enabled)


class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` that always normalizes in float32.

    Low-precision inputs (e.g. fp16) are upcast to float32 for the normalization,
    and the result is cast back to the caller's original dtype.
    """

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(input_dtype)


class MLP(nn.Module):
    """Plain feed-forward network (FFN) of ``num_layers`` linear layers.

    Widths run ``input_dim -> hidden_dim -> ... -> hidden_dim -> output_dim``.
    A ReLU follows every layer except the last one.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        hidden_widths = [hidden_dim] * (num_layers - 1)
        in_widths = [input_dim] + hidden_widths
        out_widths = hidden_widths + [output_dim]
        self.layers = nn.ModuleList([nn.Linear(d_in, d_out) for d_in, d_out in zip(in_widths, out_widths)])

    def forward(self, x):
        last_index = self.num_layers - 1
        for idx, linear in enumerate(self.layers):
            x = linear(x)
            if idx < last_index:
                x = F.relu(x)
        return x


class InfMLLM_Unified_HD_Chat(PreTrainedModel):
    """Multimodal chat model: a CLIP HD vision tower feeding a Qwen2 causal LM.

    Image features from ``encoder_img`` are projected by ``adapter_img`` to the LM
    hidden size. They are spliced into the prompt embeddings immediately after the
    ``<|image|>`` placeholder token, and text is then generated from
    ``inputs_embeds``.
    """

    def __init__(self, config, debug=False):
        """Build the tokenizer, LM, vision tower and image adapter from ``config``.

        Args:
            config: model config. This constructor reads ``_name_or_path``,
                ``lm_config``, ``max_txt_len``, ``conv_style``, ``vision_config``,
                ``precision`` and (optionally) ``apply_lemmatizer``.
            debug: accepted for interface compatibility; unused here.
        """
        super().__init__(config)

        ## Initialize LM model
        self.lm_tokenizer = AutoTokenizer.from_pretrained(config._name_or_path, use_fast=False, trust_remote_code=True)
        self.media_token_img = "<|image|>"
        # `.item()` fails unless the placeholder tokenizes to exactly one id.
        self.media_token_id_img = self.lm_tokenizer(self.media_token_img, return_tensors="pt", add_special_tokens=False).input_ids.item()
        self.lm_model = Qwen2ForCausalLM(config.lm_config)

        self.lm_tokenizer.model_max_length = config.max_txt_len

        self.template_name = config.conv_style
        self.preprocess_function = preprocess_qwen

        # Learned separator / newline embeddings handed to the vision tower.
        # NOTE(review): width 4096 is hard-coded; presumably it must equal the
        # tower's concatenated feature width (num_features * 4) — confirm.
        self.separate = nn.Parameter(torch.zeros([1, 1, 4096]))
        self.newline = nn.Parameter(torch.zeros([1, 1, 1, 4096]))

        ## Initialize image encoder
        self.encoder_img = CLIPVisionTowerHD(config.vision_config, vision_select_layer=-2)
        self.encoder_img_ln = lambda x: x  # identity: no post-encoder normalization

        # The input width is 4x the encoder width. Presumably the tower concatenates
        # 4 neighbouring patch features per token — confirm in CLIPVisionTowerHD.
        self.adapter_img = nn.Sequential(
            nn.Linear(self.encoder_img.num_features * 4, self.lm_model.config.hidden_size),
            nn.GELU(),
            nn.Linear(self.lm_model.config.hidden_size, self.lm_model.config.hidden_size),
        )

        ## Others
        self.config = config
        self.precision = config.precision
        self._apply_lemmatizer = getattr(config, 'apply_lemmatizer', False)
        self._lemmatizer = None  # spaCy pipeline, loaded lazily by `lemmatizer`

    def forward_encoder_img(self, image):
        """Encode images and project their features into the LM embedding space.

        Args:
            image: list of image inputs, passed unchanged to ``encoder_img``.

        Returns:
            ``(image_embeds, image_split)`` from the vision tower. ``image_embeds`` has
            been passed through ``encoder_img_ln`` and ``adapter_img``. ``generate``
            uses ``image_split`` as the per-image count of valid embeddings.
        """
        autocast = get_autocast(self.precision, cache_enabled=True)
        with autocast():
            assert isinstance(image, list)
            image_embeds, image_split = self.encoder_img(image, self.separate, self.newline)
            image_embeds = self.encoder_img_ln(image_embeds)  # presumably [bsz, L, D] — confirm
            image_embeds = self.adapter_img(image_embeds)
        return image_embeds, image_split

    def _concat_embeds(self,
                       prompt_embeds, prompt_ids, prompt_masks,
                       labels=None, padding='left'):
        """Pad per-sample sequences to a common length and stack them into a batch.

        Args:
            prompt_embeds: per-sample embedding tensors; ``len()`` is the sequence length.
            prompt_ids: per-sample token-id tensors, same lengths as ``prompt_embeds``.
            prompt_masks: per-sample attention masks, same lengths.
            labels: optional per-sample label tensors, same lengths.
            padding: ``'left'`` or ``'right'``, the side that receives padding.

        Returns:
            ``(embeds, ids, masks)``, or ``(embeds, ids, masks, labels)`` when
            ``labels`` is given. Padded positions hold the embedding of
            ``pad_token_id`` (embeds), ``pad_token_id`` (ids), 0 (masks) and
            -100 (labels).

        Raises:
            ValueError: if ``padding`` is neither ``'left'`` nor ``'right'``.
        """
        if padding not in ('left', 'right'):
            raise ValueError(f"padding must be 'left' or 'right', got {padding!r}")

        def _stack(items):
            # list() also accepts an already-batched tensor (iterates its rows).
            return torch.stack(list(items), dim=0)

        emb_lens = [len(emb) for emb in prompt_embeds]
        if len(set(emb_lens)) == 1:
            # All samples already share one length: plain stacking, no padding.
            stacked = (_stack(prompt_embeds), _stack(prompt_ids), _stack(prompt_masks))
            if labels is not None:
                return stacked + (_stack(labels),)
            return stacked

        batch_size, max_len = len(emb_lens), max(emb_lens)
        pad_id = self.lm_tokenizer.pad_token_id
        ids_ref, mask_ref = prompt_ids[0], prompt_masks[0]

        pad_emb = self.lm_model.get_input_embeddings()(torch.tensor(pad_id, device=prompt_embeds[0].device))
        prompt_embeds_new = pad_emb.expand(batch_size, max_len, -1).clone()
        prompt_ids_new = torch.full((batch_size, max_len), pad_id, dtype=ids_ref.dtype, device=ids_ref.device)
        prompt_masks_new = torch.zeros((batch_size, max_len), dtype=mask_ref.dtype, device=mask_ref.device)
        labels_new = None
        if labels is not None:
            labels_new = torch.full((batch_size, max_len), -100, dtype=ids_ref.dtype, device=ids_ref.device)

        for i, length in enumerate(emb_lens):
            # Explicit bounds: a `-length:` slice would select the whole row when length == 0.
            span = slice(max_len - length, max_len) if padding == 'left' else slice(0, length)
            prompt_embeds_new[i, span] = prompt_embeds[i]
            prompt_ids_new[i, span] = prompt_ids[i]
            prompt_masks_new[i, span] = prompt_masks[i]
            if labels_new is not None:
                labels_new[i, span] = labels[i]

        if labels_new is not None:
            return prompt_embeds_new, prompt_ids_new, prompt_masks_new, labels_new
        return prompt_embeds_new, prompt_ids_new, prompt_masks_new

    def _insert_media_feat(self,
                           prompt_embeds, prompt_ids, prompt_masks,
                           is_languages,
                           embeds_media, media_token_id,
                           index_list=None,
                           labels=None, len_media=None):
        """Splice media embeddings into prompts right after the first ``media_token_id``.

        Args:
            prompt_embeds, prompt_ids, prompt_masks: per-sample embeddings, ids and
                masks (indexable by sample).
            is_languages: per-sample flags. When true, the inserted media positions get
                attention mask 0 (pure-language sample).
            embeds_media: media embeddings, one row per sample that receives media.
            media_token_id: id of the placeholder token after which media is inserted.
            index_list: sample indices that receive media, in the row order of
                ``embeds_media``. ``None`` means every sample ``b`` uses row ``b``.
            labels: optional per-sample labels; media positions get -100.
            len_media: optional per-row valid lengths. Only the first
                ``len_media[row]`` media embeddings are inserted; this assumes
                ``embeds_media`` is a padded tensor indexable as ``[row, :n]``.

        Returns:
            Lists ``(embeds, ids, masks)`` plus ``labels`` when given, one ragged
            entry per sample. Ids at media positions are -100.

        Raises:
            ValueError: (from ``list.index``) if a selected prompt lacks ``media_token_id``.
        """
        prompt_embeds_new, prompt_ids_new, prompt_masks_new, labels_new = [], [], [], []
        device = embeds_media[0].device

        media_row = None
        if index_list is not None:
            assert len(index_list) == len(embeds_media)
            assert len(embeds_media) <= len(prompt_embeds)
            # sample index -> row in embeds_media (first occurrence, like list.index)
            media_row = {}
            for row, sample_idx in enumerate(index_list):
                media_row.setdefault(sample_idx, row)

        for b in range(len(prompt_embeds)):
            if media_row is not None and b not in media_row:
                # No media for this sample: pass it through unchanged.
                prompt_embeds_new.append(prompt_embeds[b])
                prompt_ids_new.append(prompt_ids[b])
                prompt_masks_new.append(prompt_masks[b])
                if labels is not None:
                    labels_new.append(labels[b])
                continue

            _idx = prompt_ids[b].tolist().index(media_token_id)
            b_media = media_row[b] if media_row is not None else b
            if len_media is not None:
                cur_embeds_media = embeds_media[b_media, :len_media[b_media]]
            else:
                cur_embeds_media = embeds_media[b_media]
            n_media = len(cur_embeds_media)

            head, tail = slice(None, _idx + 1), slice(_idx + 1, None)
            ignore = torch.full((n_media,), -100, dtype=torch.long, device=device)

            prompt_embeds_new.append(torch.cat([prompt_embeds[b][head], cur_embeds_media, prompt_embeds[b][tail]], dim=0))
            prompt_ids_new.append(torch.cat([prompt_ids[b][head], ignore, prompt_ids[b][tail]], dim=0))
            if labels is not None:
                labels_new.append(torch.cat([labels[b][head], ignore, labels[b][tail]], dim=0))

            # if is pure-language sample, mask out image-embeddings
            if is_languages[b]:
                media_mask = torch.zeros(n_media, dtype=torch.long, device=device)
            else:
                media_mask = torch.ones(n_media, dtype=torch.long, device=device)
            prompt_masks_new.append(torch.cat([prompt_masks[b][head], media_mask, prompt_masks[b][tail]], dim=0))

        if labels is not None:
            return prompt_embeds_new, prompt_ids_new, prompt_masks_new, labels_new
        return prompt_embeds_new, prompt_ids_new, prompt_masks_new

    def _build_prompts(self, conversations, special_prefix):
        """Render each conversation into a prompt string using the conv template.

        ``special_prefix[i]`` is prepended to the first kept turn of conversation ``i``.
        Turn texts are copied, never written back, so ``conversations`` is not modified.
        """
        conv = get_conv_template(self.template_name)
        roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

        prompts = []
        for i, source in enumerate(conversations):
            if roles[source[0]['from']] != conv.roles[0]:
                # Skip the first one if it is not from human
                source = source[1:]

            conv.messages = []
            for j, sentence in enumerate(source):
                role = roles[sentence['from']]
                assert role == conv.roles[j % 2], f'{i}'
                # llava-1.5 adds <image> to the beginning of the question; remove it here.
                value = sentence['value'].replace("<image>", "").strip()
                if j == 0:
                    value = special_prefix[i] + value
                conv.append_message(role, value)
            prompts.append(conv.get_prompt())
        return prompts

    @torch.no_grad()
    def generate(
        self,
        samples,
        num_beams=5,
        max_length=128,
        min_length=1,
        top_p=0.9,
        temperature=0.,
        return_prompts=False
    ):
        """Generate one reply per conversation in ``samples``.

        Args:
            samples: dict with ``'conversations'``: a list of turn lists, each turn a
                dict with ``'from'`` (``'human'`` or ``'gpt'``) and ``'value'``.
                Optional ``'images'`` holds one entry per conversation, paired by
                index. Optional ``'apply_lemmatizer'`` is a flag. ``samples`` is not
                modified.
            num_beams: beam count forwarded to ``lm_model.generate``.
            max_length: maximum number of NEW tokens (forwarded as ``max_new_tokens``).
            min_length: forwarded to ``lm_model.generate``.
            top_p, temperature: sampling parameters. Sampling is enabled only when
                ``temperature > 0``.
            return_prompts: if true, also return the rendered prompt strings.

        Returns:
            A list of stripped output strings, or ``(outputs, prompts)`` when
            ``return_prompts`` is true.
        """
        autocast = get_autocast(self.precision, cache_enabled=True)
        with autocast():
            conversations = samples['conversations']
            is_languages = [False] * len(conversations)

            image_img = samples.get('images', None)
            index_img = list(range(len(image_img))) if image_img is not None else []

            # Must be bound on every path: text-only batches skip the image branch.
            embeds_img, len_img = None, None
            special_prefix = ["" for _ in range(len(conversations))]

            if self.config.encoder_img is not None and index_img:
                for i in index_img:
                    special_prefix[i] = self.media_token_img + special_prefix[i]
                embeds_img, len_img = self.forward_encoder_img([image_img[i] for i in index_img])

            # Put tokens where the embeddings live. Without images, fall back to the
            # LM embedding table's device instead of leaving tokens on the CPU.
            if embeds_img is not None:
                device = embeds_img.device
            else:
                device = self.lm_model.get_input_embeddings().weight.device

            prompts = self._build_prompts(conversations, special_prefix)

            self.lm_tokenizer.padding_side = "left"
            bos = self.lm_tokenizer.bos_token
            prompt_text = [bos + t for t in prompts] if bos is not None else prompts

            prompt_tokens = self.lm_tokenizer(
                prompt_text,
                return_tensors="pt",
                padding="longest",
                truncation=False,
                add_special_tokens=False
            ).to(device)

            prompt_ids = prompt_tokens.input_ids
            prompt_masks = prompt_tokens.attention_mask  # [bsz, n2]
            prompt_embeds = self.lm_model.get_input_embeddings()(prompt_ids)
            assert torch.all(prompt_ids[:, -1] != self.lm_tokenizer.pad_token_id), "make sure padding left"

            if embeds_img is not None:
                prompt_embeds, prompt_ids, prompt_masks = self._insert_media_feat(prompt_embeds=prompt_embeds,
                                                                                  prompt_ids=prompt_ids,
                                                                                  prompt_masks=prompt_masks,
                                                                                  is_languages=is_languages,
                                                                                  embeds_media=embeds_img,
                                                                                  media_token_id=self.media_token_id_img,
                                                                                  index_list=index_img,
                                                                                  len_media=len_img)
                # Insertion yields ragged per-sample lists: re-pad on the left and stack.
                prompt_embeds, prompt_ids, prompt_masks = self._concat_embeds(prompt_embeds, prompt_ids, prompt_masks, padding="left")
                assert torch.all(prompt_ids[:, -1] != self.lm_tokenizer.pad_token_id), "make sure padding left"
            # Otherwise the tokenizer output is already a left-padded batch.

            outputs = self.lm_model.generate(
                inputs_embeds=prompt_embeds,
                attention_mask=prompt_masks,
                do_sample=temperature > 0,
                temperature=temperature,
                top_p=top_p,
                num_beams=num_beams,
                eos_token_id=self.lm_tokenizer.eos_token_id,
                min_length=min_length,
                max_new_tokens=max_length,
            )
            output_text = [text.strip() for text in self.lm_tokenizer.batch_decode(outputs, skip_special_tokens=True)]

        if self._apply_lemmatizer or samples.get("apply_lemmatizer", False):
            output_text = self._lemmatize(output_text)

        if return_prompts:
            return output_text, prompts
        return output_text

    def _lemmatize(self, answers):
        """Replace NOUN and VERB tokens with their spaCy lemma in each answer.

        All other tokens keep their original text. Tokens are re-joined with single spaces.
        """
        def apply(answer):
            doc = self.lemmatizer(answer)
            return " ".join(
                token.lemma_ if token.pos_ in ("NOUN", "VERB") else token.text
                for token in doc
            )

        return [apply(answer) for answer in answers]

    @property
    def lemmatizer(self):
        """Lazily load and cache the spaCy ``en_core_web_sm`` pipeline.

        Raises:
            ImportError: if spaCy is not installed.
            OSError: if the ``en_core_web_sm`` model package cannot be loaded.
        """
        if self._lemmatizer is None:
            try:
                import spacy
                self._lemmatizer = spacy.load("en_core_web_sm")
            except (ImportError, OSError):
                logging.error(
                    """
                    Please install spacy and en_core_web_sm model to apply lemmatization.
                    python -m spacy download en_core_web_sm
                    OR
                    import spacy.cli
                    spacy.cli.download("en_core_web_sm")
                    """
                )
                # Re-raise instead of exit(1): library code must not kill the host process.
                raise

        return self._lemmatizer