import math
import torch
from torch import nn
import torch.nn.functional as F


class StyleAdaptiveLayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        self.in_dim = normalized_shape
        # Plain LayerNorm without learnable affine parameters; the scale and
        # shift come from the style condition instead.
        self.norm = nn.LayerNorm(self.in_dim, eps=eps, elementwise_affine=False)
        # Projects the pooled condition to (gamma, beta); the bias is
        # initialized so the layer starts as identity (gamma = 1, beta = 0).
        self.style = nn.Linear(self.in_dim, self.in_dim * 2)
        self.style.bias.data[: self.in_dim] = 1
        self.style.bias.data[self.in_dim :] = 0

    def forward(self, x, condition):
        # x: (B, T, d); condition: (B, T', d), where T' may differ from T

        # Mean-pool the condition over time, then predict per-channel scale/shift.
        style = self.style(torch.mean(condition, dim=1, keepdim=True))

        gamma, beta = style.chunk(2, -1)

        out = self.norm(x)

        out = gamma * out + beta
        return out
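

# Usage sketch for StyleAdaptiveLayerNorm (illustrative only; the hidden size
# and sequence lengths below are arbitrary assumptions):
#
#     saln = StyleAdaptiveLayerNorm(256)
#     x = torch.randn(2, 100, 256)     # content features (B, T, d)
#     cond = torch.randn(2, 37, 256)   # style features (B, T', d)
#     out = saln(x, cond)              # (B, T, d); T' need not equal T
#
# Because the condition is mean-pooled over time before the affine projection,
# a single (gamma, beta) pair is applied to every frame of x.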


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super().__init__()

        self.dropout = dropout
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        # Batch-first table (1, max_len, d_model), matching the (B, T, d)
        # layout used throughout this module.
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):
        # x: (B, T, d); slice the table along the time axis, not the batch axis
        x = x + self.pe[:, : x.size(1)]
        return F.dropout(x, self.dropout, training=self.training)
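

# Usage sketch for PositionalEncoding (batch-first, matching the (B, T, d)
# convention used throughout this file; sizes are arbitrary assumptions):
#
#     pos_emb = PositionalEncoding(d_model=256, dropout=0.1)
#     x = torch.randn(2, 100, 256)  # (B, T, d)
#     x = pos_emb(x)                # same shape, with positions added over T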


class TransformerFFNLayer(nn.Module):
    def __init__(
        self, encoder_hidden, conv_filter_size, conv_kernel_size, encoder_dropout
    ):
        super().__init__()

        self.encoder_hidden = encoder_hidden
        self.conv_filter_size = conv_filter_size
        self.conv_kernel_size = conv_kernel_size
        self.encoder_dropout = encoder_dropout

        # Same-length 1D convolution over time; assumes an odd
        # conv_kernel_size so that kernel_size // 2 padding preserves T.
        self.ffn_1 = nn.Conv1d(
            self.encoder_hidden,
            self.conv_filter_size,
            self.conv_kernel_size,
            padding=self.conv_kernel_size // 2,
        )
        self.ffn_1.weight.data.normal_(0.0, 0.02)
        self.ffn_2 = nn.Linear(self.conv_filter_size, self.encoder_hidden)
        self.ffn_2.weight.data.normal_(0.0, 0.02)

    def forward(self, x):
        # x: (B, T, d)
        x = self.ffn_1(x.permute(0, 2, 1)).permute(
            0, 2, 1
        )  # (B, T, d) -> (B, d, T) -> (B, T, d)
        x = F.relu(x)
        x = F.dropout(x, self.encoder_dropout, training=self.training)
        x = self.ffn_2(x)
        return x


class TransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        encoder_hidden,
        encoder_head,
        conv_filter_size,
        conv_kernel_size,
        encoder_dropout,
        use_cln,
    ):
        super().__init__()
        self.encoder_hidden = encoder_hidden
        self.encoder_head = encoder_head
        self.conv_filter_size = conv_filter_size
        self.conv_kernel_size = conv_kernel_size
        self.encoder_dropout = encoder_dropout
        self.use_cln = use_cln

        if not self.use_cln:
            self.ln_1 = nn.LayerNorm(self.encoder_hidden)
            self.ln_2 = nn.LayerNorm(self.encoder_hidden)
        else:
            self.ln_1 = StyleAdaptiveLayerNorm(self.encoder_hidden)
            self.ln_2 = StyleAdaptiveLayerNorm(self.encoder_hidden)

        self.self_attn = nn.MultiheadAttention(
            self.encoder_hidden, self.encoder_head, batch_first=True
        )

        self.ffn = TransformerFFNLayer(
            self.encoder_hidden,
            self.conv_filter_size,
            self.conv_kernel_size,
            self.encoder_dropout,
        )

    def forward(self, x, key_padding_mask, condition=None):
        # x: (B, T, d); key_padding_mask: (B, T), 0 marks padding; condition: (B, T', d)

        # self attention
        residual = x
        if self.use_cln:
            x = self.ln_1(x, condition)
        else:
            x = self.ln_1(x)

        # nn.MultiheadAttention expects True at padded positions, so the
        # 0-means-padding mask is inverted here.
        if key_padding_mask is not None:
            key_padding_mask_input = ~(key_padding_mask.bool())
        else:
            key_padding_mask_input = None
        x, _ = self.self_attn(
            query=x, key=x, value=x, key_padding_mask=key_padding_mask_input
        )
        x = F.dropout(x, self.encoder_dropout, training=self.training)
        x = residual + x

        # ffn
        residual = x
        if self.use_cln:
            x = self.ln_2(x, condition)
        else:
            x = self.ln_2(x)
        x = self.ffn(x)
        x = residual + x

        return x
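

# Usage sketch for a single TransformerEncoderLayer (hyperparameters are
# arbitrary assumptions):
#
#     layer = TransformerEncoderLayer(256, 4, 1024, 5, 0.1, use_cln=False)
#     x = torch.randn(2, 100, 256)  # (B, T, d)
#     mask = torch.ones(2, 100)     # 1 = valid frame, 0 = padding
#     out = layer(x, mask)          # (B, T, d)
#
# With use_cln=True, pass a style condition as the third argument; it is
# routed to the StyleAdaptiveLayerNorm instances.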


class TransformerEncoder(nn.Module):
    def __init__(
        self,
        enc_emb_tokens=None,
        encoder_layer=None,
        encoder_hidden=None,
        encoder_head=None,
        conv_filter_size=None,
        conv_kernel_size=None,
        encoder_dropout=None,
        use_cln=None,
        cfg=None,
    ):
        super().__init__()

        self.encoder_layer = (
            encoder_layer if encoder_layer is not None else cfg.encoder_layer
        )
        self.encoder_hidden = (
            encoder_hidden if encoder_hidden is not None else cfg.encoder_hidden
        )
        self.encoder_head = (
            encoder_head if encoder_head is not None else cfg.encoder_head
        )
        self.conv_filter_size = (
            conv_filter_size if conv_filter_size is not None else cfg.conv_filter_size
        )
        self.conv_kernel_size = (
            conv_kernel_size if conv_kernel_size is not None else cfg.conv_kernel_size
        )
        self.encoder_dropout = (
            encoder_dropout if encoder_dropout is not None else cfg.encoder_dropout
        )
        self.use_cln = use_cln if use_cln is not None else cfg.use_cln

        if enc_emb_tokens is not None:
            self.use_enc_emb = True
            self.enc_emb_tokens = enc_emb_tokens
        else:
            self.use_enc_emb = False

        self.position_emb = PositionalEncoding(
            self.encoder_hidden, self.encoder_dropout
        )

        self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                TransformerEncoderLayer(
                    self.encoder_hidden,
                    self.encoder_head,
                    self.conv_filter_size,
                    self.conv_kernel_size,
                    self.encoder_dropout,
                    self.use_cln,
                )
                for _ in range(self.encoder_layer)
            ]
        )

        if self.use_cln:
            self.last_ln = StyleAdaptiveLayerNorm(self.encoder_hidden)
        else:
            self.last_ln = nn.LayerNorm(self.encoder_hidden)

    def forward(self, x, key_padding_mask, condition=None):
        # x: (B, T) token ids or (B, T, d) features; key_padding_mask: (B, T), 0 marks padding
        if len(x.shape) == 2 and self.use_enc_emb:
            x = self.enc_emb_tokens(x)
        x = self.position_emb(x)  # (B, T, d)

        for layer in self.layers:
            x = layer(x, key_padding_mask, condition)

        if self.use_cln:
            x = self.last_ln(x, condition)
        else:
            x = self.last_ln(x)

        return x
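

# Usage sketch for TransformerEncoder (the hyperparameters below are arbitrary
# assumptions, not values from any released config):
#
#     encoder = TransformerEncoder(
#         encoder_layer=4, encoder_hidden=256, encoder_head=4,
#         conv_filter_size=1024, conv_kernel_size=5,
#         encoder_dropout=0.1, use_cln=False,
#     )
#     x = torch.randn(2, 100, 256)           # (B, T, d) frame features
#     key_padding_mask = torch.ones(2, 100)  # 1 = valid, 0 = padding
#     out = encoder(x, key_padding_mask)     # (B, T, d)
#
# Passing enc_emb_tokens (an nn.Embedding) lets the encoder accept (B, T)
# token ids directly.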


class DurationPredictor(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.input_size = cfg.input_size
        self.filter_size = cfg.filter_size
        self.kernel_size = cfg.kernel_size
        self.conv_layers = cfg.conv_layers
        self.cross_attn_per_layer = cfg.cross_attn_per_layer
        self.attn_head = cfg.attn_head
        self.drop_out = cfg.drop_out

        self.conv = nn.ModuleList()
        self.cattn = nn.ModuleList()

        for idx in range(self.conv_layers):
            in_dim = self.input_size if idx == 0 else self.filter_size
            self.conv += [
                nn.Sequential(
                    nn.Conv1d(
                        in_dim,
                        self.filter_size,
                        self.kernel_size,
                        padding=self.kernel_size // 2,
                    ),
                    nn.ReLU(),
                    nn.LayerNorm(self.filter_size),
                    nn.Dropout(self.drop_out),
                )
            ]
            if idx % self.cross_attn_per_layer == 0:
                # Sequential is used here only as a parameter container; the
                # three submodules are unpacked and called individually in
                # forward(), since Sequential cannot route the (query, key,
                # value) arguments through MultiheadAttention.
                self.cattn.append(
                    torch.nn.Sequential(
                        nn.MultiheadAttention(
                            self.filter_size,
                            self.attn_head,
                            batch_first=True,
                            kdim=self.filter_size,
                            vdim=self.filter_size,
                        ),
                        nn.LayerNorm(self.filter_size),
                        nn.Dropout(0.2),
                    )
                )

        self.linear = nn.Linear(self.filter_size, 1)
        self.linear.weight.data.normal_(0.0, 0.02)

    def forward(self, x, mask, ref_emb, ref_mask):
        """
        input:
        x: (B, N, d)
        mask: (B, N), mask is 0
        ref_emb: (B, d, T')
        ref_mask: (B, T'), mask is 0

        output: a dict with three views of the prediction, each (B, N):
        dur_pred_log: raw (log-scale) network output
        dur_pred: exp(dur_pred_log) - 1, duration in frames
        dur_pred_round: dur_pred rounded and clamped to non-negative integers
        """

        input_ref_mask = ~(ref_mask.bool())  # (B, T'); True marks padding

        x = x.transpose(1, -1)  # (B, N, d) -> (B, d, N)

        for idx, (conv, act, ln, dropout) in enumerate(self.conv):
            res = x
            if idx % self.cross_attn_per_layer == 0:
                attn_idx = idx // self.cross_attn_per_layer
                attn, attn_ln, attn_drop = self.cattn[attn_idx]

                attn_res = y_ = x.transpose(1, 2)  # (B, d, N) -> (B, N, d)

                y_ = attn_ln(y_)
                y_, _ = attn(
                    y_,
                    ref_emb.transpose(1, 2),
                    ref_emb.transpose(1, 2),
                    key_padding_mask=input_ref_mask,
                )
                y_ = attn_drop(y_)
                y_ = (y_ + attn_res) / math.sqrt(2.0)

                x = y_.transpose(1, 2)

            x = conv(x)
            x = act(x)
            x = ln(x.transpose(1, 2))
            x = x.transpose(1, 2)

            x = dropout(x)

            if idx != 0:
                x += res

            if mask is not None:
                x = x * mask.to(x.dtype)[:, None, :]

        x = self.linear(x.transpose(1, 2))
        x = torch.squeeze(x, -1)

        dur_pred = x.exp() - 1
        dur_pred_round = torch.clamp(torch.round(dur_pred), min=0).long()

        return {
            "dur_pred_log": x,
            "dur_pred": dur_pred,
            "dur_pred_round": dur_pred_round,
        }
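

# Usage sketch for DurationPredictor (cfg fields mirror the attributes read in
# __init__; the concrete values are arbitrary assumptions). Note that because
# cross attention is always applied at layer 0, input_size must equal
# filter_size:
#
#     from types import SimpleNamespace
#     cfg = SimpleNamespace(
#         input_size=256, filter_size=256, kernel_size=3, conv_layers=12,
#         cross_attn_per_layer=3, attn_head=4, drop_out=0.1,
#     )
#     dp = DurationPredictor(cfg)
#     x = torch.randn(2, 50, 256)       # phoneme features (B, N, d)
#     mask = torch.ones(2, 50)          # 1 = valid, 0 = padding
#     ref = torch.randn(2, 256, 80)     # reference embedding (B, d, T')
#     ref_mask = torch.ones(2, 80)
#     out = dp(x, mask, ref, ref_mask)  # dict with dur_pred* keys, each (B, N)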


class PitchPredictor(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.input_size = cfg.input_size
        self.filter_size = cfg.filter_size
        self.kernel_size = cfg.kernel_size
        self.conv_layers = cfg.conv_layers
        self.cross_attn_per_layer = cfg.cross_attn_per_layer
        self.attn_head = cfg.attn_head
        self.drop_out = cfg.drop_out

        self.conv = nn.ModuleList()
        self.cattn = nn.ModuleList()

        for idx in range(self.conv_layers):
            in_dim = self.input_size if idx == 0 else self.filter_size
            self.conv += [
                nn.Sequential(
                    nn.Conv1d(
                        in_dim,
                        self.filter_size,
                        self.kernel_size,
                        padding=self.kernel_size // 2,
                    ),
                    nn.ReLU(),
                    nn.LayerNorm(self.filter_size),
                    nn.Dropout(self.drop_out),
                )
            ]
            if idx % self.cross_attn_per_layer == 0:
                # As in DurationPredictor, Sequential is only a parameter
                # container; the modules are unpacked in forward().
                self.cattn.append(
                    torch.nn.Sequential(
                        nn.MultiheadAttention(
                            self.filter_size,
                            self.attn_head,
                            batch_first=True,
                            kdim=self.filter_size,
                            vdim=self.filter_size,
                        ),
                        nn.LayerNorm(self.filter_size),
                        nn.Dropout(0.2),
                    )
                )

        self.linear = nn.Linear(self.filter_size, 1)
        self.linear.weight.data.normal_(0.0, 0.02)

    def forward(self, x, mask, ref_emb, ref_mask):
        """
        input:
        x: (B, N, d)
        mask: (B, N), mask is 0
        ref_emb: (B, d, T')
        ref_mask: (B, T'), mask is 0

        output:
        pitch_pred: (B, N)
        """

        input_ref_mask = ~(ref_mask.bool())  # (B, T'); True marks padding

        x = x.transpose(1, -1)  # (B, N, d) -> (B, d, N)

        for idx, (conv, act, ln, dropout) in enumerate(self.conv):
            res = x
            if idx % self.cross_attn_per_layer == 0:
                attn_idx = idx // self.cross_attn_per_layer
                attn, attn_ln, attn_drop = self.cattn[attn_idx]

                attn_res = y_ = x.transpose(1, 2)  # (B, d, N) -> (B, N, d)

                y_ = attn_ln(y_)
                y_, _ = attn(
                    y_,
                    ref_emb.transpose(1, 2),
                    ref_emb.transpose(1, 2),
                    key_padding_mask=input_ref_mask,
                )
                y_ = attn_drop(y_)
                y_ = (y_ + attn_res) / math.sqrt(2.0)

                x = y_.transpose(1, 2)

            x = conv(x)
            x = act(x)
            x = ln(x.transpose(1, 2))
            x = x.transpose(1, 2)

            x = dropout(x)

            if idx != 0:
                x += res

        x = self.linear(x.transpose(1, 2))
        x = torch.squeeze(x, -1)

        return x
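

# PitchPredictor mirrors DurationPredictor's constructor and forward interface
# (see the sketch above); it returns raw per-token pitch values of shape (B, N)
# instead of a duration dict.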


def pad(input_ele, mel_max_length=None):
    if mel_max_length:
        max_len = mel_max_length
    else:
        max_len = max(ele.size(0) for ele in input_ele)

    out_list = list()
    for batch in input_ele:
        if len(batch.shape) == 1:
            one_batch_padded = F.pad(
                batch, (0, max_len - batch.size(0)), "constant", 0.0
            )
        elif len(batch.shape) == 2:
            one_batch_padded = F.pad(
                batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0
            )
        else:
            raise ValueError(
                f"pad expects 1D or 2D tensors, got shape {tuple(batch.shape)}"
            )
        out_list.append(one_batch_padded)
    out_padded = torch.stack(out_list)
    return out_padded
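

# Usage sketch for pad (sizes are arbitrary assumptions):
#
#     seqs = [torch.randn(3, 8), torch.randn(5, 8)]
#     padded = pad(seqs)  # (2, 5, 8), zero-padded along the first dim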


class LengthRegulator(nn.Module):
    """Length Regulator"""

    def __init__(self):
        super().__init__()

    def LR(self, x, duration, max_len):
        # x: (B, N, d); duration: (B, N) integer frame counts.
        # Returns frame-level features (B, T, d) and per-item lengths.
        device = x.device
        output = list()
        mel_len = list()
        for batch, expand_target in zip(x, duration):
            expanded = self.expand(batch, expand_target)
            output.append(expanded)
            mel_len.append(expanded.shape[0])

        if max_len is not None:
            output = pad(output, max_len)
        else:
            output = pad(output)

        return output, torch.LongTensor(mel_len).to(device)

    def expand(self, batch, predicted):
        out = list()

        for i, vec in enumerate(batch):
            expand_size = predicted[i].item()
            # Repeat each token vector expand_size times (0 drops the token).
            out.append(vec.expand(max(int(expand_size), 0), -1))
        out = torch.cat(out, 0)

        return out

    def forward(self, x, duration, max_len):
        output, mel_len = self.LR(x, duration, max_len)
        return output, mel_len
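

if __name__ == "__main__":
    # Minimal smoke test for LengthRegulator; the shapes, durations, and
    # feature values below are arbitrary assumptions for illustration.
    regulator = LengthRegulator()
    x = torch.randn(2, 4, 8)  # (B, N, d) token-level features
    duration = torch.tensor([[1, 2, 0, 3], [2, 2, 2, 2]])  # frames per token
    out, mel_len = regulator(x, duration, max_len=None)
    print(out.shape, mel_len)  # torch.Size([2, 8, 8]) tensor([6, 8])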