import argparse
import json
import math
import os
from collections import OrderedDict, defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

import clip
import ffmpeg
from transformers import DebertaV2Tokenizer

from model.deberta_moe import DebertaV2ForMaskedLM
from VideoLoader import VideoLoader

def get_mask(lengths, max_length):
    """Computes a batch of padding masks given batched lengths.

    `lengths` is a tensor of shape (batch,); the returned mask has shape
    (batch, max_length) with 1 at valid positions and 0 at padded ones.
    """
    mask = 1 * (
        torch.arange(max_length).unsqueeze(1) < lengths
    ).transpose(0, 1)
    return mask
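
# Example (illustrative): get_mask(torch.tensor([3]), max_length=5)
# -> tensor([[1, 1, 1, 0, 0]])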

class Infer:
    def __init__(self, device):
        pretrained_ckpt = torch.load("ckpts/model.pth", map_location="cpu")
        args = pretrained_ckpt["args"]
        args.n_ans = 2  # binary answer vocabulary: "Yes" / "No"
        args.max_tokens = 256
        self.args = args
        self.clip_model = clip.load("ViT-L/14", device=device)[0]
        self.tokenizer = DebertaV2Tokenizer.from_pretrained(
            "ckpts/deberta-v2-xlarge", local_files_only=True
        )
        
        self.model = DebertaV2ForMaskedLM.from_pretrained(
            features_dim=args.features_dim if args.use_video else 0,
            max_feats=args.max_feats,
            freeze_lm=args.freeze_lm,
            freeze_mlm=args.freeze_mlm,
            ft_ln=args.ft_ln,
            ds_factor_attn=args.ds_factor_attn,
            ds_factor_ff=args.ds_factor_ff,
            dropout=args.dropout,
            n_ans=args.n_ans,
            freeze_last=args.freeze_last,
            pretrained_model_name_or_path="ckpts/deberta-v2-xlarge",
            local_files_only=False,
            add_video_feat=args.add_video_feat,
            freeze_ad=args.freeze_ad,
        )
        # strip the DataParallel/DDP "module." prefix from checkpoint keys
        new_state_dict = OrderedDict()
        for k, v in pretrained_ckpt["model"].items():
            new_state_dict[k.replace("module.", "")] = v
        self.model.load_state_dict(new_state_dict, strict=False)
        self.model.eval()
        self.model.to(device)
        self.device = device

        self.video_loader = VideoLoader()
        self.set_answer()

    @torch.no_grad()
    def _get_clip_feature(self, video):
        # Encode the sampled frames with CLIP; no gradients are needed at inference.
        feat = self.clip_model.encode_image(video.to(self.device))
        # feat = F.normalize(feat, dim=1)
        return feat

    def set_answer(self):
        # Tokenize the two answer words and register them as the model's answer embeddings.
        tok_yes = torch.tensor(
                    self.tokenizer(
                        "Yes",
                        add_special_tokens=False,
                        max_length=1,
                        truncation=True,
                        padding="max_length",
                    )["input_ids"],
                    dtype=torch.long,
                )
        tok_no = torch.tensor(
            self.tokenizer(
                "No",
                add_special_tokens=False,
                max_length=1,
                truncation=True,
                padding="max_length",
            )["input_ids"],
            dtype=torch.long,
        )     

        a2tok = torch.stack([tok_yes, tok_no])
        self.model.set_answer_embeddings(
            a2tok.to(self.model.device), freeze_last=self.args.freeze_last
        )

    @torch.no_grad()
    def generate(self, text, candidates, video_path):
        video, video_len = self.video_loader(video_path)
        video = self._get_clip_feature(video).unsqueeze(0).float()
        # padding mask over the 10 sampled frame features
        video_mask = get_mask(video_len, 10)
        # prepend one extra slot for the aggregated video token (add_video_feat)
        video_mask = torch.cat(
            [torch.ones((1, 1), dtype=video_mask.dtype), video_mask], dim=1
        )
        logits_list = []

        question = text.strip().capitalize()
        if not question.endswith("?"):
            question = question + "?"

        for candidate in candidates:
            prompt = (
                f"Question: {question} Is it '{candidate}'? "
                f"{self.tokenizer.mask_token}. Subtitles: "
            ).strip()
            encoded = self.tokenizer(
                prompt,
                add_special_tokens=True,
                max_length=self.args.max_tokens,
                padding="longest",
                truncation=True,
                return_tensors="pt",
            )
            # forward pass through the video-conditioned masked LM
            output = self.model(
                video=video.to(self.device),
                video_mask=video_mask.to(self.device),
                input_ids=encoded["input_ids"].to(self.device),
                attention_mask=encoded["attention_mask"].to(self.device),
            )
            logits = output["logits"]
            # Text logits start after the video positions: 10 frame features
            # plus 1 aggregated video token, hence an offset of 11.
            delay = 11
            # keep only the logits at the mask-token position
            logits = logits[:, delay : encoded["input_ids"].size(1) + delay][
                encoded["input_ids"] == self.tokenizer.mask_token_id
            ]
            # probability assigned to "Yes" (answer index 0) for this candidate
            logits_list.append(logits.softmax(-1)[:, 0])
        
        # (1, num_candidates) matrix of "Yes" probabilities
        logits = torch.stack(logits_list, 1)
        if logits.shape[1] == 1:
            # single candidate: threshold its yes-probability at 0.5
            preds = logits.round().long().squeeze(1)
        else:
            # several candidates: pick the one with the highest yes-probability
            preds = logits.max(1).indices

        return candidates[preds.item()]
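

if __name__ == "__main__":
    # Minimal usage sketch: the device choice, question, candidate answers,
    # and video path below are illustrative assumptions, not values from the
    # original script or checkpoint.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    infer = Infer(device)
    answer = infer.generate(
        text="what is the person doing",
        candidates=["cooking", "reading a book", "playing guitar"],
        video_path="example.mp4",
    )
    print(answer)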