"""
Tacotron + stepwise monotonic attention
"""

import jax
import jax.numpy as jnp
import pax


def conv_block(in_ft, out_ft, kernel_size, activation_fn, use_dropout):
    """
    Build a 1-D convolution block as a `pax.Sequential`.

    Layers are chained in this order:
    Conv1D (no bias) >> LayerNorm >> [activation_fn] >> [Dropout(0.5)]

    Args:
        in_ft: number of input features of the convolution.
        out_ft: number of output features of the convolution
            (also passed to the layer norm).
        kernel_size: kernel size of the convolution.
        activation_fn: callable appended after the layer norm;
            skipped when it is None.
        use_dropout: when truthy, `pax.Dropout(0.5)` is appended last.

    Returns:
        The assembled sequential module.
    """
    # The convolution is created before the layer norm, in that order.
    conv = pax.Conv1D(in_ft, out_ft, kernel_size, with_bias=False)
    norm = pax.LayerNorm(out_ft, -1, True, True)
    block = pax.Sequential(conv, norm)

    # Optional stages are appended only on request.
    if activation_fn is not None:
        block >>= activation_fn
    if use_dropout:
        block >>= pax.Dropout(0.5)
    return block


class HighwayBlock(pax.Module):
    """
    A single highway layer.

    One linear layer produces a gate and a candidate; the output is a
    gated mix of the untouched input and the ReLU-ed candidate.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim
        # One projection yields both halves: gate and candidate.
        self.fc = pax.Linear(dim, 2 * dim)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Return `x * (1 - gate) + relu(candidate) * gate`."""
        gate, candidate = jnp.split(self.fc(x), 2, axis=-1)
        # Shifting the gate pre-activation by -1 biases the layer
        # toward passing x through unchanged.
        gate = jax.nn.sigmoid(gate - 1.0)
        candidate = jax.nn.relu(candidate)
        return x * (1.0 - gate) + candidate * gate


class BiGRU(pax.Module):
    """
    Bidirectional GRU

    A forward GRU reads the sequence as given; a backward GRU reads it
    flipped along axis 1. The backward state is reset to its initial
    value wherever `reset_masks` is set, so that (flipped) masked steps
    do not leak into the steps processed after them.
    """

    def __init__(self, dim):
        super().__init__()

        self.rnn_fwd = pax.GRU(dim, dim)
        self.rnn_bwd = pax.GRU(dim, dim)

    def __call__(self, x, reset_masks):
        """
        Run both directions and concatenate their outputs.

        Args:
            x: batch-major input sequence (axis 0 is the batch, axis 1 is
                flipped for the backward pass and scanned over).
            reset_masks: per-step mask aligned with `x` on axes 0 and 1;
                wherever it is nonzero/True the backward state is replaced
                by the initial state after that step.

        Returns:
            Forward and backward outputs concatenated on the last axis.
        """
        N = x.shape[0]
        x_fwd = x
        x_bwd = jnp.flip(x, axis=1)
        x_fwd_states = self.rnn_fwd.initial_state(N)
        x_bwd_states = self.rnn_bwd.initial_state(N)
        # Forward direction: plain scan, no state reset.
        x_fwd_states, x_fwd = pax.scan(
            self.rnn_fwd, x_fwd_states, x_fwd, time_major=False
        )

        # The masks must be flipped the same way as the backward inputs.
        reset_masks = jnp.flip(reset_masks, axis=1)
        x_bwd_states0 = x_bwd_states

        def rnn_reset_core(prev, inputs):
            x, reset_mask = inputs

            def reset_state(x0, xt):
                # Select the initial state where the mask is set.
                return jnp.where(reset_mask, x0, xt)

            state, _ = self.rnn_bwd(prev, x)
            # `jax.tree_map` is a deprecated alias that newer JAX releases
            # no longer provide; `jax.tree_util.tree_map` is equivalent.
            state = jax.tree_util.tree_map(reset_state, x_bwd_states0, state)
            # The emitted output is the hidden state *after* the reset.
            return state, state.hidden

        x_bwd_states, x_bwd = pax.scan(
            rnn_reset_core, x_bwd_states, (x_bwd, reset_masks), time_major=False
        )
        # Undo the flip so backward outputs line up with forward outputs.
        x_bwd = jnp.flip(x_bwd, axis=1)
        x = jnp.concatenate((x_fwd, x_bwd), axis=-1)
        return x


class CBHG(pax.Module):
    """
    Convolution bank >> two conv projections (residual) >> highway net >> BiGRU
    """

    def __init__(self, dim):
        super().__init__()
        # Bank of 16 conv blocks with kernel sizes 1..16, all dim -> dim.
        self.convs = [
            conv_block(dim, dim, width, jax.nn.relu, False) for width in range(1, 17)
        ]
        # The first projection consumes the 16 concatenated bank outputs.
        self.conv_projection_1 = conv_block(16 * dim, dim, 3, jax.nn.relu, False)
        self.conv_projection_2 = conv_block(dim, dim, 3, None, False)

        self.highway = pax.Sequential(
            HighwayBlock(dim),
            HighwayBlock(dim),
            HighwayBlock(dim),
            HighwayBlock(dim),
        )
        self.rnn = BiGRU(dim)

    def __call__(self, x, x_mask):
        """
        Apply the CBHG stack to `x`.

        `x_mask` is multiplied into the activations before every
        convolution, before the RNN and on the final output; `1 - x_mask`
        is handed to the BiGRU as its reset mask.
        """
        masked_input = x * x_mask
        bank_outputs = [conv(masked_input) for conv in self.convs]
        stacked = jnp.concatenate(bank_outputs, axis=-1)
        pooled = pax.max_pool(stacked, 2, 1, "SAME", -1)
        projected = self.conv_projection_1(pooled * x_mask)
        projected = self.conv_projection_2(projected * x_mask)
        # Residual connection around the conv bank and projections.
        hidden = self.highway(x + projected)
        hidden = self.rnn(hidden * x_mask, reset_masks=1 - x_mask)
        return hidden * x_mask


class PreNet(pax.Module):
    """
    Two stacked stages of: Linear >> relu >> dropout(0.5)

    Dropout is applied when `always_dropout` is set or the module is in
    training mode; otherwise both dropout stages are skipped.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, always_dropout=True):
        super().__init__()
        self.fc1 = pax.Linear(input_dim, hidden_dim)
        self.fc2 = pax.Linear(hidden_dim, output_dim)
        self.rng_seq = pax.RngSeq()
        self.always_dropout = always_dropout

    def _dropout(self, x, rng_key):
        """Apply 0.5 dropout to `x` if enabled, drawing a key when none is given."""
        if not (self.always_dropout or self.training):
            return x
        if rng_key is None:
            rng_key = self.rng_seq.next_rng_key()
        return pax.dropout(rng_key, 0.5, x)

    def __call__(self, x, k1=None, k2=None):
        """
        Args:
            x: input features.
            k1: optional rng key for the first dropout; taken from the
                module's own `rng_seq` when None.
            k2: optional rng key for the second dropout; taken from the
                module's own `rng_seq` when None.
        """
        hidden = self._dropout(jax.nn.relu(self.fc1(x)), k1)
        return self._dropout(jax.nn.relu(self.fc2(hidden)), k2)


class Tacotron(pax.Module):
    """
    Tacotron TTS model.

    It uses stepwise monotonic attention for robust attention.
    """

    def __init__(
        self,
        mel_dim: int,
        attn_bias,
        rr,
        max_rr,
        mel_min,
        sigmoid_noise,
        pad_token,
        prenet_dim,
        attn_hidden_dim,
        attn_rnn_dim,
        rnn_dim,
        postnet_dim,
        text_dim,
    ):
        """
        New Tacotron model

        Args:
            mel_dim (int): dimension of log mel-spectrogram features.
            attn_bias (float): control how "slow" the attention will
                move forward at initialization.
            rr (int): the reduction factor, i.e. the number of frames
                predicted at each decoder step. Must satisfy `rr <= max_rr`.
            max_rr (int): max value of rr. The output layer is always sized
                for `max_rr` frames; only the first `rr` are used.
            mel_min (float): the minimum value of mel features.
                The <go> frame is filled by `log(mel_min)` values.
            sigmoid_noise (float): the scale (standard deviation) of the
                gaussian noise added to attention scores before the sigmoid.
            pad_token (int): the pad value at the end of text sequences.
            prenet_dim (int): dimension of prenet output.
            attn_hidden_dim (int): dimension of attention hidden vectors.
            attn_rnn_dim (int): number of cells in the attention RNN.
            rnn_dim (int): number of cells in the decoder RNNs.
            postnet_dim (int): number of features in the postnet convolutions.
            text_dim (int): dimension of text embedding vectors.
        """
        super().__init__()
        self.text_dim = text_dim
        assert rr <= max_rr
        self.rr = rr
        self.max_rr = max_rr
        self.mel_dim = mel_dim
        self.mel_min = mel_min
        self.sigmoid_noise = sigmoid_noise
        self.pad_token = pad_token
        self.prenet_dim = prenet_dim

        # encoder submodules
        self.encoder_embed = pax.Embed(256, text_dim)
        self.encoder_pre_net = PreNet(text_dim, 256, prenet_dim, always_dropout=True)
        self.encoder_cbhg = CBHG(prenet_dim)

        # random key generator
        self.rng_seq = pax.RngSeq()

        # pre-net
        self.decoder_pre_net = PreNet(mel_dim, 256, prenet_dim, always_dropout=True)

        # decoder submodules
        # The attention RNN sees the pre-net output concatenated with the
        # previous attention context (2 * prenet_dim wide).
        self.attn_rnn = pax.LSTM(prenet_dim + prenet_dim * 2, attn_rnn_dim)
        self.text_key_fc = pax.Linear(prenet_dim * 2, attn_hidden_dim, with_bias=True)
        self.attn_query_fc = pax.Linear(attn_rnn_dim, attn_hidden_dim, with_bias=False)

        self.attn_V = pax.Linear(attn_hidden_dim, 1, with_bias=False)
        self.attn_V_weight_norm = jnp.array(1.0 / jnp.sqrt(attn_hidden_dim))
        self.attn_V_bias = jnp.array(attn_bias)
        # Placeholder; overwritten by `decode` with the attention
        # probabilities of the first batch item.
        self.attn_log = jnp.zeros((1,))
        self.decoder_input = pax.Linear(attn_rnn_dim + 2 * prenet_dim, rnn_dim)
        self.decoder_rnn1 = pax.LSTM(rnn_dim, rnn_dim)
        self.decoder_rnn2 = pax.LSTM(rnn_dim, rnn_dim)
        # mel + end-of-sequence token
        self.output_fc = pax.Linear(rnn_dim, (mel_dim + 1) * max_rr, with_bias=True)

        # post-net
        self.post_net = pax.Sequential(
            conv_block(mel_dim, postnet_dim, 5, jax.nn.tanh, True),
            conv_block(postnet_dim, postnet_dim, 5, jax.nn.tanh, True),
            conv_block(postnet_dim, postnet_dim, 5, jax.nn.tanh, True),
            conv_block(postnet_dim, postnet_dim, 5, jax.nn.tanh, True),
            conv_block(postnet_dim, mel_dim, 5, None, True),
        )

    # NOTE(review): presumably marks these two plain arrays as trainable
    # parameters for pax -- confirm against the pax documentation.
    parameters = pax.parameters_method("attn_V_weight_norm", "attn_V_bias")

    def encode_text(self, text: jnp.ndarray) -> jnp.ndarray:
        """
        Encode text to a sequence of real vectors

        Positions equal to `pad_token` are masked out inside the CBHG.
        """
        # Unpacking into two names also rejects inputs that are not rank 2.
        N, L = text.shape
        text_mask = (text != self.pad_token)[..., None]
        x = self.encoder_embed(text)
        x = self.encoder_pre_net(x)
        x = self.encoder_cbhg(x, text_mask)
        return x

    def go_frame(self, batch_size: int) -> jnp.ndarray:
        """
        return the go frame

        The result has shape `(batch_size, mel_dim)` and every entry is
        `log(mel_min)`.
        """
        return jnp.ones((batch_size, self.mel_dim)) * jnp.log(self.mel_min)

    def decoder_initial_state(self, N: int, L: int):
        """
        setup decoder initial state

        Args:
            N: batch size.
            L: length of the encoded text sequence.

        Returns:
            A pair `(attn_state, decoder_rnn_states)` where `attn_state` is
            `(attention rnn state, attention context, attention probs)`.
            The context starts at zero and all attention mass starts on
            text position 0.
        """
        attn_context = jnp.zeros((N, self.prenet_dim * 2))
        attn_pr = jax.nn.one_hot(
            jnp.zeros((N,), dtype=jnp.int32), num_classes=L, axis=-1
        )

        attn_state = (self.attn_rnn.initial_state(N), attn_context, attn_pr)
        decoder_rnn_states = (
            self.decoder_rnn1.initial_state(N),
            self.decoder_rnn2.initial_state(N),
        )
        return attn_state, decoder_rnn_states

    def monotonic_attention(self, prev_state, inputs, envs):
        """
        Stepwise monotonic attention

        Args:
            prev_state: `(attn_rnn_state, attn_context, prev_attn_pr)`.
            inputs: `(x, attn_rng_key)`, the decoder input of this step and
                the rng key used for the score noise.
            envs: `(text, text_key)`, the encoded text and its projection
                by `text_key_fc`.

        Returns:
            `(new_state, attn_rnn_output)` with `new_state` laid out like
            `prev_state`.
        """
        attn_rnn_state, attn_context, prev_attn_pr = prev_state
        x, attn_rng_key = inputs
        text, text_key = envs
        attn_rnn_input = jnp.concatenate((x, attn_context), axis=-1)
        attn_rnn_state, attn_rnn_output = self.attn_rnn(attn_rnn_state, attn_rnn_input)
        attn_query_input = attn_rnn_output
        attn_query = self.attn_query_fc(attn_query_input)
        attn_hidden = jnp.tanh(attn_query[:, None, :] + text_key)
        score = self.attn_V(attn_hidden)
        score = jnp.squeeze(score, axis=-1)
        # Weight normalisation: rescale by the learned norm divided by the
        # actual norm of the projection weight, then add the bias.
        weight_norm = jnp.linalg.norm(self.attn_V.weight)
        score = score * (self.attn_V_weight_norm / weight_norm)
        score = score + self.attn_V_bias
        noise = jax.random.normal(attn_rng_key, score.shape) * self.sigmoid_noise
        pr_stay = jax.nn.sigmoid(score + noise)
        pr_move = 1.0 - pr_stay
        # Mass that moves goes exactly one position to the right; whatever
        # would move past the last position is dropped.
        pr_new_location = pr_move * prev_attn_pr
        pr_new_location = jnp.pad(
            pr_new_location[:, :-1], ((0, 0), (1, 0)), constant_values=0
        )
        attn_pr = pr_stay * prev_attn_pr + pr_new_location
        attn_context = jnp.einsum("NL,NLD->ND", attn_pr, text)
        new_state = (attn_rnn_state, attn_context, attn_pr)
        return new_state, attn_rnn_output

    def zoneout_lstm(self, lstm_core, rng_key, zoneout_pr=0.1):
        """
        Return a zoneout lstm core.

        It will zoneout the new hidden states and keep the new cell states unchanged.

        Args:
            lstm_core: callable `(state, x) -> (new_state, output)`.
            rng_key: rng key used to sample the zoneout mask.
            zoneout_pr: probability of keeping the old hidden value.
        """

        def core(state, x):
            new_state, _ = lstm_core(state, x)
            h_old = state.hidden
            h_new = new_state.hidden
            # Where the mask is set the previous hidden value is kept.
            mask = jax.random.bernoulli(rng_key, zoneout_pr, h_old.shape)
            h_new = h_old * mask + h_new * (1.0 - mask)
            return pax.LSTMState(h_new, new_state.cell), h_new

        return core

    def decoder_step(
        self,
        attn_state,
        decoder_rnn_states,
        rng_key,
        mel,
        text,
        text_key,
        call_pre_net=False,
    ):
        """
        One decoder step

        Args:
            attn_state: attention state as built by `decoder_initial_state`.
            decoder_rnn_states: states of the two decoder LSTMs.
            rng_key: rng key for this step; it is split for the pre-net
                dropout (if used), the two zoneout masks, the attention
                noise and the key handed back to the caller.
            mel: decoder input of this step; a raw mel frame when
                `call_pre_net` is True, otherwise already pre-net features.
            text: encoded text.
            text_key: `text` projected by `text_key_fc`.
            call_pre_net: run `decoder_pre_net` on `mel` first.

        Returns:
            `(attn_state, decoder_rnn_states, next_rng_key, x, attn_pr[0])`
            where `x` is the decoder output before `output_fc` and
            `attn_pr[0]` is the attention of the first batch item.
        """
        if call_pre_net:
            k1, k2, zk1, zk2, rng_key, rng_key_next = jax.random.split(rng_key, 6)
            mel = self.decoder_pre_net(mel, k1, k2)
        else:
            zk1, zk2, rng_key, rng_key_next = jax.random.split(rng_key, 4)
        attn_inputs = (mel, rng_key)
        attn_envs = (text, text_key)
        attn_state, attn_rnn_output = self.monotonic_attention(
            attn_state, attn_inputs, attn_envs
        )
        (_, attn_context, attn_pr) = attn_state
        (decoder_rnn_state1, decoder_rnn_state2) = decoder_rnn_states
        decoder_rnn1_input = jnp.concatenate((attn_rnn_output, attn_context), axis=-1)
        decoder_rnn1_input = self.decoder_input(decoder_rnn1_input)
        decoder_rnn1 = self.zoneout_lstm(self.decoder_rnn1, zk1)
        decoder_rnn_state1, decoder_rnn_output1 = decoder_rnn1(
            decoder_rnn_state1, decoder_rnn1_input
        )
        # Residual connections around both decoder LSTMs.
        decoder_rnn2_input = decoder_rnn1_input + decoder_rnn_output1
        decoder_rnn2 = self.zoneout_lstm(self.decoder_rnn2, zk2)
        decoder_rnn_state2, decoder_rnn_output2 = decoder_rnn2(
            decoder_rnn_state2, decoder_rnn2_input
        )
        x = decoder_rnn1_input + decoder_rnn_output1 + decoder_rnn_output2
        decoder_rnn_states = (decoder_rnn_state1, decoder_rnn_state2)
        return attn_state, decoder_rnn_states, rng_key_next, x, attn_pr[0]

    @jax.jit
    def inference_step(
        self, attn_state, decoder_rnn_states, rng_key, mel, text, text_key
    ):
        """
        one inference step

        Runs `decoder_step` (with the pre-net), projects the output to
        `rr` mel frames and samples the end-of-sequence flag.

        Returns:
            `(attn_state, decoder_rnn_states, rng_key, (mel, eos))` where
            `mel` holds the `rr` frames of this step and `eos` is a sampled
            boolean.
        """
        attn_state, decoder_rnn_states, rng_key, x, _ = self.decoder_step(
            attn_state,
            decoder_rnn_states,
            rng_key,
            mel,
            text,
            text_key,
            call_pre_net=True,
        )
        x = self.output_fc(x)
        N, D2 = x.shape
        # The output layer is sized for max_rr frames; keep the first rr.
        x = jnp.reshape(x, (N, self.max_rr, D2 // self.max_rr))
        x = x[:, : self.rr, :]
        x = jnp.reshape(x, (N, self.rr, -1))
        mel = x[..., :-1]
        eos_logit = x[..., -1]
        # Only the last frame of the first batch item decides termination.
        eos_pr = jax.nn.sigmoid(eos_logit[0, -1])
        # Probabilities below 0.1 are clamped to 0 so that low-confidence
        # steps can never end the sequence.
        eos_pr = jnp.where(eos_pr < 0.1, 0.0, eos_pr)
        rng_key, eos_rng_key = jax.random.split(rng_key)
        eos = jax.random.bernoulli(eos_rng_key, p=eos_pr)
        return attn_state, decoder_rnn_states, rng_key, (mel, eos)

    def inference(self, text, seed=42, max_len=1000):
        """
        text to mel

        Args:
            text: batch of token ids; the batch size must be 1.
            seed: seed of the rng key driving the decoder steps.
            max_len: decoding stops once an end-of-sequence flag is sampled
                or more than `max_len` decoder steps have been run.

        Returns:
            The generated mel frames, concatenated along the time axis and
            refined by the post-net residual.
        """
        text = self.encode_text(text)
        text_key = self.text_key_fc(text)
        N, L, _ = text.shape
        assert N == 1
        mel = self.go_frame(N)

        attn_state, decoder_rnn_states = self.decoder_initial_state(N, L)
        rng_key = jax.random.PRNGKey(seed)
        mels = []
        count = 0
        while True:
            count = count + 1
            attn_state, decoder_rnn_states, rng_key, (mel, eos) = self.inference_step(
                attn_state, decoder_rnn_states, rng_key, mel, text, text_key
            )
            mels.append(mel)
            if eos.item() or count > max_len:
                break

            # The last predicted frame is the input of the next step.
            mel = mel[:, -1, :]

        mels = jnp.concatenate(mels, axis=1)
        # Apply the post-net to the whole generated sequence. The previous
        # code ran it on the last step's `mel` and discarded the result, so
        # the returned frames never received the post-net refinement.
        mels = mels + self.post_net(mels)
        return mels

    def decode(self, mel, text):
        """
        Attention mechanism + Decoder

        Args:
            mel: decoder inputs, already processed by the pre-net; the
                decoder is scanned over axis 1.
            text: encoded text.

        Returns:
            `(mel, eos)`: predicted mel frames and end-of-sequence logits,
            with `rr` frames per input step.

        Side effect: stores the attention probabilities of the first batch
        item in `self.attn_log`.
        """
        text_key = self.text_key_fc(text)

        def scan_fn(prev_states, inputs):
            attn_state, decoder_rnn_states = prev_states
            x, rng_key = inputs
            attn_state, decoder_rnn_states, _, output, attn_pr = self.decoder_step(
                attn_state, decoder_rnn_states, rng_key, x, text, text_key
            )
            states = (attn_state, decoder_rnn_states)
            return states, (output, attn_pr)

        N, L, _ = text.shape
        decoder_states = self.decoder_initial_state(N, L)
        # One rng key per decoder step, stacked on axis 1 so that the step
        # axis is axis 1, like `mel`'s, for the batch-major scan below.
        rng_keys = self.rng_seq.next_rng_key(mel.shape[1])
        rng_keys = jnp.stack(rng_keys, axis=1)
        decoder_states, (x, attn_log) = pax.scan(
            scan_fn,
            decoder_states,
            (mel, rng_keys),
            time_major=False,
        )
        self.attn_log = attn_log
        del decoder_states
        x = self.output_fc(x)

        N, T2, D2 = x.shape
        # Keep the first rr of the max_rr predicted frames per step, then
        # unfold them along the time axis.
        x = jnp.reshape(x, (N, T2, self.max_rr, D2 // self.max_rr))
        x = x[:, :, : self.rr, :]
        x = jnp.reshape(x, (N, T2 * self.rr, -1))
        mel = x[..., :-1]
        eos = x[..., -1]
        return mel, eos

    def __call__(self, mel: jnp.ndarray, text: jnp.ndarray):
        """
        Teacher-forced forward pass.

        Returns:
            `(mel, mel + post_net(mel), eos)`: the decoder prediction, the
            post-net refined prediction and the end-of-sequence logits.
        """
        text = self.encode_text(text)
        mel = self.decoder_pre_net(mel)
        mel, eos = self.decode(mel, text)
        return mel, mel + self.post_net(mel), eos