import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
from fastai.vision import *

from modules_matrn.attention import *
from modules_matrn.model import Model, _default_tfmer_cfg
from modules_matrn.transformer import (PositionalEncoding,
                                 TransformerEncoder,
                                 TransformerEncoderLayer)


class BaseSemanticVisual_backbone_feature(Model):
    """Semantic-visual alignment module operating on backbone features.

    Visual tokens (the flattened backbone feature map) and semantic tokens
    (language features plus attention-weighted positional encodings) are
    concatenated and processed jointly by a transformer encoder. The visual
    part of the output is decoded with ``PositionAttention``; the semantic
    part is used directly. Each part is classified on its own, and the two
    feature streams are fused through a learned sigmoid gate into the final
    logits.
    """

    def __init__(self, config):
        super().__init__(config)
        d_model = ifnone(config.model_alignment_d_model, _default_tfmer_cfg['d_model'])
        nhead = ifnone(config.model_alignment_nhead, _default_tfmer_cfg['nhead'])
        d_inner = ifnone(config.model_alignment_d_inner, _default_tfmer_cfg['d_inner'])
        # Same ``model_alignment_*`` prefix as the other keys (the key was
        # previously misspelled, so a configured dropout never took effect).
        dropout = ifnone(config.model_alignment_dropout, _default_tfmer_cfg['dropout'])
        activation = ifnone(config.model_alignment_activation, _default_tfmer_cfg['activation'])
        num_layers = ifnone(config.model_alignment_num_layers, 2)

        # Visual-token masking hyper-parameters (used only when training).
        self.mask_example_prob = ifnone(config.model_alignment_mask_example_prob, 0.9)
        self.mask_candidate_prob = ifnone(config.model_alignment_mask_candidate_prob, 0.9)
        self.num_vis_mask = ifnone(config.model_alignment_num_vis_mask, 10)
        self.nhead = nhead

        self.d_model = d_model
        self.use_self_attn = ifnone(config.model_alignment_use_self_attn, False)
        self.loss_weight = ifnone(config.model_alignment_loss_weight, 1.0)
        self.max_length = config.dataset_max_length + 1  # additional stop token
        self.debug = ifnone(config.global_debug, False)

        encoder_layer = TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                                dim_feedforward=d_inner, dropout=dropout, activation=activation)
        self.model1 = TransformerEncoder(encoder_layer, num_layers)
        # NOTE(review): max_len=8*32 presumably matches the backbone's H*W
        # grid; an attention map with more positions than this would exceed
        # the encoding table — confirm against the backbone output size.
        self.pos_encoder_tfm = PositionalEncoding(d_model, dropout=0, max_len=8*32)

        mode = ifnone(config.model_alignment_attention_mode, 'nearest')
        self.model2_vis = PositionAttention(
            max_length=config.dataset_max_length + 1,  # additional stop token
            mode=mode
        )
        self.cls_vis = nn.Linear(d_model, self.charset.num_classes)
        self.cls_sem = nn.Linear(d_model, self.charset.num_classes)
        # Gate computed from the concatenated visual/semantic features.
        self.w_att = nn.Linear(2 * d_model, d_model)

        # Learned replacement vector for masked visual tokens.
        v_token = torch.empty((1, d_model))
        self.v_token = nn.Parameter(v_token)
        torch.nn.init.uniform_(self.v_token, -0.001, 0.001)

        self.cls = nn.Linear(d_model, self.charset.num_classes)

    def _mask_visual_tokens(self, v_feature, v_attn, lengths_l):
        """Return a copy of ``v_feature`` with some highly-attended tokens masked.

        For each example, with probability ``mask_example_prob``, one character
        position ``l_idx`` in ``[0, length)`` is drawn at random. The
        ``num_vis_mask`` visual positions with the highest attention for that
        character become candidates, and each candidate is replaced by
        ``v_token`` with probability ``mask_candidate_prob``.

        Args:
            v_feature: (H*W, N, E) visual tokens. This tensor is not modified.
            v_attn: (N, T, H*W) attention maps.
            lengths_l: (N,) sequence lengths (tensor or sequence of numbers).

        Returns:
            (H*W, N, E) masked copy of ``v_feature``.
        """
        # Clone so the caller's tensor (and its autograd version counter) is
        # left untouched; gradients still flow through the unmasked positions.
        v_feature = v_feature.clone()
        for idx, length in enumerate(lengths_l):
            length = int(length)
            # The example-level draw happens first so RNG consumption matches
            # the previous behaviour; zero-length examples are skipped because
            # there is no character position to sample.
            if np.random.random() > self.mask_example_prob or length <= 0:
                continue
            l_idx = np.random.randint(length)
            candidates = v_attn[idx, l_idx].argsort(descending=True).cpu().numpy()[:self.num_vis_mask]
            keep = np.random.random(candidates.shape) <= self.mask_candidate_prob
            v_feature[candidates[keep], idx] = self.v_token
        return v_feature

    def forward(self, l_feature, v_feature, lengths_l=None, v_attn=None, l_logits=None, texts=None, training=True):
        """Jointly encode semantic and visual tokens and predict characters.

        Args:
            l_feature: (N, T, E) semantic features, where N is the batch size,
                T the sequence length and E the model dimension.
            v_feature: (N, E, H, W) visual backbone features.
            lengths_l: (N,) sequence lengths; required when ``training`` is
                True (they bound the character position sampled for masking).
            v_attn: (N, T, H, W) or (N, T, H*W) visual attention maps; required.
            l_logits: unused; kept for interface compatibility.
            texts: unused; kept for interface compatibility.
            training: when True, randomly mask visual tokens with ``v_token``.

        Returns:
            dict with fused ``logits`` (N, T, C) and ``pt_lengths``, visual-branch
            ``v_logits``/``pt_v_lengths``, semantic-branch ``s_logits``/
            ``pt_s_lengths``, ``loss_weight`` and ``name`` ('alignment').
        """
        N, E, H, W = v_feature.size()
        l_feature = l_feature.permute(1, 0, 2)  # (T, N, E)
        v_feature = v_feature.flatten(2).permute(2, 0, 1)  # (H*W, N, E)
        # Accept attention maps either spatial (N, T, H, W) or already flat.
        v_attn = v_attn.flatten(2)  # (N, T, H*W)

        if training:
            v_feature = self._mask_visual_tokens(v_feature, v_attn, lengths_l)

        # Positional encodings over the attention grid: encoding zeros
        # presumably yields the raw encoding table — TODO confirm against
        # PositionalEncoding.
        zeros = v_feature.new_zeros((v_attn.size(-1), v_attn.size(0), E))  # (H*W, N, E)
        base_pos = self.pos_encoder_tfm(zeros).permute(1, 0, 2)  # (N, H*W, E)
        # Attention-weighted sum of positions for each character slot.
        base_pos = torch.bmm(v_attn, base_pos).permute(1, 0, 2)  # (T, N, E)
        l_feature = l_feature + base_pos

        sv_feature = torch.cat((v_feature, l_feature), dim=0)  # (H*W+T, N, E)
        sv_feature = self.model1(sv_feature)  # (H*W+T, N, E)

        sv_to_v_feature = sv_feature[:H*W]  # (H*W, N, E)
        sv_to_s_feature = sv_feature[H*W:]  # (T, N, E)

        # Visual branch: restore the spatial map and decode per character.
        sv_to_v_feature = sv_to_v_feature.permute(1, 2, 0).view(N, E, H, W)
        sv_to_v_feature, _ = self.model2_vis(sv_to_v_feature)  # (N, T, E)
        sv_to_v_logits = self.cls_vis(sv_to_v_feature)  # (N, T, C)
        pt_v_lengths = self._get_length(sv_to_v_logits)  # (N,)

        # Semantic branch: the character tokens are read off directly.
        sv_to_s_feature = sv_to_s_feature.permute(1, 0, 2)  # (N, T, E)
        sv_to_s_logits = self.cls_sem(sv_to_s_feature)  # (N, T, C)
        pt_s_lengths = self._get_length(sv_to_s_logits)  # (N,)

        # Gated fusion: f_att weights the visual stream, 1 - f_att the semantic one.
        f = torch.cat((sv_to_v_feature, sv_to_s_feature), dim=2)
        f_att = torch.sigmoid(self.w_att(f))
        output = f_att * sv_to_v_feature + (1 - f_att) * sv_to_s_feature

        logits = self.cls(output)  # (N, T, C)
        pt_lengths = self._get_length(logits)

        # NOTE(review): weight scaled by 3, presumably one share per head
        # (fused / visual / semantic) — confirm with the loss code.
        return {'logits': logits, 'pt_lengths': pt_lengths, 'loss_weight':self.loss_weight*3,
                'v_logits': sv_to_v_logits, 'pt_v_lengths': pt_v_lengths,
                's_logits': sv_to_s_logits, 'pt_s_lengths': pt_s_lengths,
                'name': 'alignment'}