import torch
import torch.nn as nn
import math


class ConvGRU(nn.Module):
    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(ConvGRU, self).__init__()
        self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        hx = torch.cat([h, x], dim=1)

        z = torch.sigmoid(self.convz(hx))
        r = torch.sigmoid(self.convr(hx))
        q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))

        h = (1 - z) * h + z * q
        return h
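
    @staticmethod
    def demo():
        # Minimal shape sketch with assumed sizes (not part of the original API):
        # the hidden state keeps its spatial resolution while being gated by x.
        gru = ConvGRU(hidden_dim=128, input_dim=192 + 128)
        h = torch.randn(2, 128, 32, 32)
        x = torch.randn(2, 192 + 128, 32, 32)
        assert gru(h, x).shape == (2, 128, 32, 32)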


class SepConvGRU(nn.Module):
    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(SepConvGRU, self).__init__()
        self.convz1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2))
        self.convr1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2))
        self.convq1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2))

        self.convz2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0))
        self.convr2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0))
        self.convq2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0))

    def forward(self, h, x):
        # horizontal
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz1(hx))
        r = torch.sigmoid(self.convr1(hx))
        q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1)))
        h = (1 - z) * h + z * q

        # vertical
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz2(hx))
        r = torch.sigmoid(self.convr2(hx))
        q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1)))
        h = (1 - z) * h + z * q

        return h


class GRU(nn.Module):
    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(GRU, self).__init__()
        self.convz1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2))
        self.convr1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2))
        self.convq1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2))

        self.convz2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0))
        self.convr2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0))
        self.convq2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0))

    def forward(self, hidden, x, shape):
        # horizontal
        b, l, c = hidden.shape
        h, w = shape
        hidden = hidden.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()
        x = x.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()

        hx = torch.cat([hidden, x], dim=1)
        z = torch.sigmoid(self.convz1(hx))
        r = torch.sigmoid(self.convr1(hx))
        q = torch.tanh(self.convq1(torch.cat([r * hidden, x], dim=1)))
        hidden = (1 - z) * hidden + z * q

        # vertical
        hx = torch.cat([hidden, x], dim=1)
        z = torch.sigmoid(self.convz2(hx))
        r = torch.sigmoid(self.convr2(hx))
        q = torch.tanh(self.convq2(torch.cat([r * hidden, x], dim=1)))
        hidden = (1 - z) * hidden + z * q

        return hidden.flatten(-2).permute(0, 2, 1)
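
    @staticmethod
    def demo():
        # Hedged usage sketch (assumed sizes): tokens are [B, H*W, C] and the
        # spatial shape is passed separately so the separable 1x5 / 5x1 convs
        # can run on the restored [B, C, H, W] layout. Note that the forward
        # pass reshapes x with the hidden channel count, so in practice
        # input_dim == hidden_dim (as used by TransformerLayer below).
        gru = GRU(hidden_dim=64, input_dim=64)
        hidden = torch.randn(1, 4 * 8, 64)
        x = torch.randn(1, 4 * 8, 64)
        assert gru(hidden, x, shape=(4, 8)).shape == (1, 4 * 8, 64)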


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x):
        # x: [B, C, H, W]; an all-ones mask marks every position as valid
        b, c, h, w = x.size()
        mask = torch.ones((b, h, w), device=x.device)  # [B, H, W]
        y_embed = mask.cumsum(1, dtype=torch.float32)
        x_embed = mask.cumsum(2, dtype=torch.float32)

        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        # geometric progression of frequencies over the channel dimension
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
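
    @staticmethod
    def demo():
        # Sketch with assumed sizes: a [B, C, H, W] input yields a
        # [B, 2 * num_pos_feats, H, W] encoding, so num_pos_feats = C // 2
        # makes the encoding channel-compatible with the features.
        pe = PositionEmbeddingSine(num_pos_feats=64)
        pos = pe(torch.zeros(2, 128, 16, 16))
        assert pos.shape == (2, 128, 16, 16)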


def feature_add_position(feature0, feature_channels, scale=1.0):
    # mix a sinusoidal positional encoding into the features, then rescale each
    # sample so its mean |activation| matches the pre-mixing global mean
    temp = torch.mean(torch.abs(feature0))
    pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)
    position = pos_enc(feature0)
    feature0 = feature0 + (temp * position / position.mean()) * scale * torch.pi
    feature0 = feature0 * temp / torch.mean(torch.abs(feature0), dim=(1, 2, 3), keepdim=True)
    return feature0


def feature_add_image_content(feature0, add_fea, scale=0.4):
    # same mixing scheme as feature_add_position, but with an arbitrary feature
    # map (e.g. image content) instead of a positional encoding
    temp = torch.mean(torch.abs(feature0))
    position = add_fea
    feature0 = feature0 + (temp * position / position.mean()) * scale * torch.pi
    feature0 = feature0 * temp / torch.mean(torch.abs(feature0), dim=(1, 2, 3), keepdim=True)
    return feature0
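
# Usage sketch for the two mixing helpers above (sizes are assumptions): both
# keep the [B, C, H, W] shape and rescale each sample so its mean |activation|
# matches the pre-mixing global mean, e.g.
#   feat = torch.randn(2, 128, 32, 32)
#   feat_pos = feature_add_position(feat, feature_channels=128, scale=1.0)
#   feat_img = feature_add_image_content(feat, torch.randn_like(feat), scale=0.4)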


class AttUp(nn.Module):
    def __init__(self,
                 c=512
                 ):
        super(AttUp, self).__init__()
        self.proj = nn.Linear(c, c, bias=False)
        self.norm = nn.LayerNorm(c)
        self.conv = nn.Sequential(nn.Conv2d(2 * c, c, kernel_size=1, stride=1, padding=0),
                                  nn.GELU(),
                                  nn.Conv2d(c, c, kernel_size=3, stride=1, padding=1),
                                  nn.GELU(),
                                  nn.Conv2d(c, c, kernel_size=3, stride=1, padding=1),
                                  nn.GELU()
                                  )
        self.gru = SepConvGRU(c, c)

    def forward(self, att, message, shape):
        # att, message: [B, L, C]; shape = (H, W) with L == H * W
        b, l, c = att.shape
        h, w = shape
        message = self.norm(self.proj(message)).view(b, h, w, c).permute(0, 3, 1, 2).contiguous()
        att = att.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()
        message = self.conv(torch.cat([att, message], dim=1))
        att = self.gru(att, message).flatten(-2).permute(0, 2, 1)
        # [B, H*W, C]
        return att


class TransformerLayer(nn.Module):
    def __init__(self,
                 d_model=256,
                 nhead=1,
                 no_ffn=False,
                 ffn_dim_expansion=4
                 ):
        super(TransformerLayer, self).__init__()

        self.dim = d_model
        self.nhead = nhead
        self.no_ffn = no_ffn
        # multi-head attention
        self.att_proj = nn.Sequential(nn.Linear(d_model, d_model, bias=False), nn.ReLU(inplace=True),
                                      nn.Linear(d_model, d_model, bias=False))
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.merge = nn.Linear(d_model, d_model, bias=False)
        self.gru = GRU(d_model, d_model)
        self.attn_updater = AttUp(d_model)
        self.drop = nn.Dropout(p=0.8)

        self.norm1 = nn.LayerNorm(d_model)

        # no ffn after self-attn, with ffn after cross-attn
        if not self.no_ffn:
            in_channels = d_model * 2
            self.mlp = nn.Sequential(
                nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False),
                nn.GELU(),
                nn.Linear(in_channels * ffn_dim_expansion, in_channels * ffn_dim_expansion, bias=False),
                nn.GELU(),
                nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False),
            )

            self.norm2 = nn.LayerNorm(d_model)

    def forward(self, att, value,
                shape, iteration=0):
        # att, value: [B, L, C]; shape = (H, W) with L == H * W
        max_exp_scale = 3 * torch.pi
        # single-head attention
        B, L, C = value.shape
        if iteration == 0:
            att = feature_add_position(att.transpose(-1, -2).view(
                B, C, shape[0], shape[1]), C).reshape(B, C, -1).transpose(-1, -2)

        val_proj = self.v_proj(value)
        att_proj = self.att_proj(att)  # [B, L, C]
        # cosine-similarity affinities, sharpened by a learned, bounded temperature
        norm_fac = torch.sum(att_proj ** 2, dim=-1, keepdim=True) ** 0.5
        scale = max_exp_scale * torch.sigmoid(torch.mean(att_proj, dim=[-1, -2], keepdim=True)) + 1
        A = torch.exp(scale * torch.matmul(att_proj / norm_fac, att_proj.permute(0, 2, 1) / norm_fac.permute(0, 2, 1)))
        A = A / A.max()
        # symmetric degree normalization: A <- D^{-1/2} A D^{-1/2}
        D = torch.sum(A, dim=-1, keepdim=True)
        D = 1 / (torch.sqrt(D) + 1e-6)  # inverse square root of node degrees
        A = D * A * D.transpose(-1, -2)

        message = torch.matmul(A, val_proj)  # aggregate values: [B, L, C]

        message = self.merge(message)  # [B, L, C]
        message = self.norm1(message)
        if not self.no_ffn:
            message = self.mlp(torch.cat([value, message], dim=-1))
            message = self.norm2(message)

        # if iteration > 2:
        #     message = self.drop(message)

        att = self.attn_updater(att, message, shape)
        value = self.gru(value, message, shape)
        return value, att, A
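
    @staticmethod
    def demo():
        # Hedged sketch with hypothetical sizes, showing the token interface:
        # att and value are [B, H*W, C], shape carries the spatial extent.
        layer = TransformerLayer(d_model=64, nhead=1)
        tokens = torch.randn(1, 4 * 8, 64)
        value, att, A = layer(att=tokens, value=tokens.clone(), shape=[4, 8], iteration=0)
        assert value.shape == (1, 32, 64) and att.shape == (1, 32, 64) and A.shape == (1, 32, 32)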


class FeatureTransformer(nn.Module):
    def __init__(self,
                 num_layers=6,
                 d_model=128
                 ):
        super(FeatureTransformer, self).__init__()
        self.d_model = d_model
        # a single TransformerLayer is shared (weight-tied) across all iterations
        self.layers = TransformerLayer(self.d_model, no_ffn=False, ffn_dim_expansion=2)
        self.re_proj = nn.Sequential(nn.Linear(d_model, d_model), nn.GELU(), nn.Linear(d_model, d_model))
        self.num_layers = num_layers
        self.norm_sigma = nn.Parameter(torch.tensor(1.0))
        self.norm_k = nn.Parameter(torch.tensor(1.8))

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def normalize(self, x):
        # divisive normalization: divide by the mean activation over tokens and
        # channels plus a learned offset, then scale by the learned gain |norm_k|
        sum_activation = torch.mean(x, dim=[1, 2], keepdim=True) + torch.square(self.norm_sigma)
        x = self.norm_k.abs() * x / sum_activation
        return x

    def forward(self, feature0):

        feature_list = []
        attn_list = []
        attn_viz_list = []
        b, c, h, w = feature0.shape
        assert self.d_model == c
        value = feature0.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]
        att = feature0
        att = att.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]
        for i in range(self.num_layers):
            value, att, attn_viz = self.layers(att=att, value=value, shape=[h, w], iteration=i)
            attn_viz = attn_viz.reshape(b, h, w, h, w)
            attn_viz_list.append(attn_viz)
            value_decode = self.normalize(
                torch.square(self.re_proj(value)))  # map to motion energy; normalization here is required
            attn_list.append(att.view(b, h, w, c).permute(0, 3, 1, 2).contiguous())
            feature_list.append(value_decode.view(b, h, w, c).permute(0, 3, 1, 2).contiguous())
        # reshape back
        return feature_list, attn_list, attn_viz_list

    def forward_save_mem(self, feature0, add_position_embedding=True):
        feature_list = []
        attn_list = []
        attn_viz_list = []
        b, c, h, w = feature0.shape
        assert self.d_model == c
        value = feature0.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]
        att = feature0
        att = att.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]
        for i in range(self.num_layers):
            value, att, _ = self.layers(att=att, value=value, shape=[h, w], iteration=i)
            value_decode = self.normalize(
                torch.square(self.re_proj(value)))  # map to motion energy; normalization here is required
            attn_list.append(att.view(b, h, w, c).permute(0, 3, 1, 2).contiguous())
            feature_list.append(value_decode.view(b, h, w, c).permute(0, 3, 1, 2).contiguous())
        # reshape back
        return feature_list, attn_list

    @staticmethod
    def demo():
        import time
        features = torch.randn([4, 256, 64, 64], device="cuda")
        model = FeatureTransformer(6, 256).cuda()
        for i in range(100):
            start = time.time()
            output = model(features)

            # backprop through the last attention map to time a full forward/backward pass
            torch.mean(output[-1][-1]).backward()
            end = time.time()
            print(end - start)
            print("#================================++#")


if __name__ == '__main__':
    FeatureTransformer.demo()