from typing import List

import torch
import torch.distributions as tdist
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint

from TTS.tts.layers.overflow.common_layers import Outputnet, OverflowUtils
from TTS.tts.layers.tacotron.common_layers import Prenet
from TTS.tts.utils.helpers import sequence_mask

class NeuralHMM(nn.Module):
    """Autoregressive left to right HMM model primarily used in "Neural HMMs are all you need (for high-quality attention-free TTS)"

    Paper::
        https://arxiv.org/abs/2108.13320

    Paper abstract::
        Neural sequence-to-sequence TTS has achieved significantly better output quality than statistical speech synthesis using
        HMMs. However, neural TTS is generally not probabilistic and uses non-monotonic attention. Attention failures increase
        training time and can make synthesis babble incoherently. This paper describes how the old and new paradigms can be
        combined to obtain the advantages of both worlds, by replacing attention in neural TTS with an autoregressive left-right
        no-skip hidden Markov model defined by a neural network. Based on this proposal, we modify Tacotron 2 to obtain an
        HMM-based neural TTS model with monotonic alignment, trained to maximise the full sequence likelihood without
        approximation. We also describe how to combine ideas from classical and contemporary TTS for best results. The resulting
        example system is smaller and simpler than Tacotron 2, and learns to speak with fewer iterations and less data, whilst
        achieving comparable naturalness prior to the post-net. Our approach also allows easy control over speaking rate.

    Args:
        frame_channels (int): Output dimension to generate.
        ar_order (int): Autoregressive order of the model. In ablations of Neural HMM it was found that more autoregression while giving more variation hurts naturalness of the synthesised audio.
        deterministic_transition (bool): deterministic duration generation based on duration quantiles as defined in "S. Ronanki, O. Watts, S. King, and G. E. Henter, “Median-based generation of synthetic speech durations using a nonparametric approach,” in Proc. SLT, 2016.". Defaults to True.
        encoder_dim (int): Channels of encoder input and character embedding tensors. Defaults to 512.
        prenet_type (str): `original` or `bn`. `original` sets the default Prenet and `bn` uses Batch Normalization version of the Prenet.
        prenet_dim (int): Dimension of the Prenet.
        prenet_n_layers (int): Number of layers in the Prenet.
        prenet_dropout (float): Dropout probability of the Prenet.
        prenet_dropout_at_inference (bool): If True, dropout is applied at inference time.
        memory_rnn_dim (int): Size of the memory RNN to process output of prenet.
        outputnet_size (List[int]): Size of the output network inside the neural HMM.
        flat_start_params (dict): Parameters for the flat start initialization of the neural HMM.
        std_floor (float): Floor value for the standard deviation of the neural HMM. Prevents model cheating by putting point mass and getting infinite likelihood at any datapoint.
        use_grad_checkpointing (bool, optional): Use gradient checkpointing to save memory. Defaults to True.
    """

    def __init__(
        self,
        frame_channels: int,
        ar_order: int,
        deterministic_transition: bool,
        encoder_dim: int,
        prenet_type: str,
        prenet_dim: int,
        prenet_n_layers: int,
        prenet_dropout: float,
        prenet_dropout_at_inference: bool,
        memory_rnn_dim: int,
        outputnet_size: List[int],
        flat_start_params: dict,
        std_floor: float,
        use_grad_checkpointing: bool = True,
    ):
        # nn.Module bookkeeping must exist before any sub-module or buffer is registered.
        super().__init__()

        assert ar_order > 0, f"AR order must be greater than 0 provided {ar_order}"

        self.frame_channels = frame_channels
        self.ar_order = ar_order
        self.deterministic_transition = deterministic_transition
        self.prenet_dim = prenet_dim
        self.memory_rnn_dim = memory_rnn_dim
        self.use_grad_checkpointing = use_grad_checkpointing

        self.transition_model = TransitionModel()
        self.emission_model = EmissionModel()

        # The prenet consumes the last `ar_order` frames flattened into one vector.
        # NOTE(review): the prenet keyword arguments below wire through the constructor's
        # prenet_* options; confirm the names against the Prenet layer's signature.
        self.prenet = Prenet(
            in_features=frame_channels * ar_order,
            prenet_type=prenet_type,
            prenet_dropout=prenet_dropout,
            dropout_at_inference=prenet_dropout_at_inference,
            out_features=[self.prenet_dim for _ in range(prenet_n_layers)],
            bias=False,
        )
        self.memory_rnn = nn.LSTMCell(input_size=prenet_dim, hidden_size=memory_rnn_dim)
        self.output_net = Outputnet(
            encoder_dim, memory_rnn_dim, frame_channels, outputnet_size, flat_start_params, std_floor
        )
        # One all-zero "go" frame per autoregressive step; expanded to frame_channels on use.
        self.register_buffer("go_tokens", torch.zeros(ar_order, 1))

    def forward(self, inputs, inputs_len, mels, mel_lens):
        r"""HMM forward algorithm for training uses logarithmic version of Rabiner (1989) forward algorithm.

        Args:
            inputs (torch.FloatTensor): Encoder outputs
            inputs_len (torch.LongTensor): Encoder output lengths
            mels (torch.FloatTensor): Mel inputs
            mel_lens (torch.LongTensor): Length of mel inputs

        Shapes:
            - inputs: (B, T, D_out_enc)
            - inputs_len: (B)
            - mels: (B, D_mel, T_mel)
            - mel_lens: (B)

        Returns:
            log_prob (torch.FloatTensor): Log probability of the sequence, shape (B)
            log_alpha_scaled (torch.FloatTensor): Scaled forward variables, shape (B, T_mel, N)
            transition_matrix (torch.FloatTensor): Transition vectors per timestep, shape (B, T_mel, N)
            means (List[torch.FloatTensor]): Detached emission means of every processed timestep (for plotting)
        """
        # Get dimensions of inputs
        batch_size, N, _ = inputs.shape
        T_max = torch.max(mel_lens)
        mels = mels.permute(0, 2, 1)  # (B, T_mel, D_mel): time-major frames for per-step indexing

        # Initialize forward algorithm
        log_state_priors = self._initialize_log_state_priors(inputs)
        log_c, log_alpha_scaled, transition_matrix, means = self._initialize_forward_algorithm_variables(mels, N)

        # Initialize autoregression elements
        ar_inputs = self._add_go_token(mels)
        h_memory, c_memory = self._init_lstm_states(batch_size, self.memory_rnn_dim, mels)

        for t in range(T_max):
            # Process Autoregression
            h_memory, c_memory = self._process_ar_timestep(t, ar_inputs, h_memory, c_memory)
            # Get mean, std and transition vector from decoder for this timestep
            # Note: Gradient checkpointing currently doesn't work with multiple gpus inside a loop
            if self.use_grad_checkpointing and self.training:
                mean, std, transition_vector = checkpoint(self.output_net, h_memory, inputs)
            else:
                mean, std, transition_vector = self.output_net(h_memory, inputs)

            if t == 0:
                # First frame: only the state prior (all mass on state 0) feeds the recursion.
                log_alpha_temp = log_state_priors + self.emission_model(mels[:, 0], mean, std, inputs_len)
            else:
                log_alpha_temp = self.emission_model(mels[:, t], mean, std, inputs_len) + self.transition_model(
                    log_alpha_scaled[:, t - 1, :], transition_vector, inputs_len
                )
            # Rescale at every step so the forward variables stay numerically stable.
            log_c[:, t] = torch.logsumexp(log_alpha_temp, dim=1)
            log_alpha_scaled[:, t, :] = log_alpha_temp - log_c[:, t].unsqueeze(1)
            transition_matrix[:, t] = transition_vector  # needed for absorption state calculation

            # Save for plotting
            means.append(mean.detach())

        log_c, log_alpha_scaled = self._mask_lengths(mel_lens, log_c, log_alpha_scaled)

        sum_final_log_c = self.get_absorption_state_scaling_factor(
            mel_lens, log_alpha_scaled, inputs_len, transition_matrix
        )

        log_probs = torch.sum(log_c, dim=1) + sum_final_log_c

        return log_probs, log_alpha_scaled, transition_matrix, means

    @staticmethod
    def _mask_lengths(mel_lens, log_c, log_alpha_scaled):
        """
        Mask the lengths of the forward variables so that the variable lengths
        do not contribute in the loss calculation

        Args:
            mel_lens (torch.IntTensor): (batch)
            log_c (torch.FloatTensor): scaling constants (batch, T)
            log_alpha_scaled (torch.FloatTensor): forward probabilities (batch, T, N)

        Returns:
            log_c (torch.FloatTensor): scaled probabilities (batch, T)
            log_alpha_scaled (torch.FloatTensor): forward probabilities (batch, T, N)
        """
        # NOTE(review): assumes sequence_mask(mel_lens) yields a (batch, T) mask with T == max(mel_lens).
        mask_log_c = sequence_mask(mel_lens)
        log_c = log_c * mask_log_c
        mask_log_alpha_scaled = mask_log_c.unsqueeze(2)
        log_alpha_scaled = log_alpha_scaled * mask_log_alpha_scaled
        return log_c, log_alpha_scaled

    def _process_ar_timestep(
        self,
        t,
        ar_inputs,
        h_memory,
        c_memory,
    ):
        """
        Process autoregression in timestep
        1. At a specific t timestep
        2. Perform data dropout if applied (we did not use it)
        3. Run the autoregressive frame through the prenet (has dropout)
        4. Run the prenet output through the post prenet rnn

        Args:
            t (int): mel-spec timestep
            ar_inputs (torch.FloatTensor): go-token prepended mel-spectrograms
                - shape: (b, T_out, D_out)
            h_memory (torch.FloatTensor): previous timestep rnn hidden state
                - shape: (b, memory_rnn_dim)
            c_memory (torch.FloatTensor): previous timestep rnn cell state
                - shape: (b, memory_rnn_dim)

        Returns:
            h_memory (torch.FloatTensor): rnn hidden state of the current timestep
            c_memory (torch.FloatTensor): rnn cell state of the current timestep
        """
        # Window of the `ar_order` frames preceding frame t, flattened to (b, ar_order * D_out).
        prenet_input = ar_inputs[:, t : t + self.ar_order].flatten(1)
        memory_inputs = self.prenet(prenet_input)
        h_memory, c_memory = self.memory_rnn(memory_inputs, (h_memory, c_memory))
        return h_memory, c_memory

    def _add_go_token(self, mel_inputs):
        """Prepend the go tokens to create the autoregressive input

        Args:
            mel_inputs (torch.FloatTensor): (batch_size, T, n_mel_channel)

        Returns:
            ar_inputs (torch.FloatTensor): (batch_size, T + ar_order - 1, n_mel_channel),
                i.e. (batch_size, T, n_mel_channel) when ar_order == 1
        """
        batch_size, T, _ = mel_inputs.shape
        go_tokens = self.go_tokens.unsqueeze(0).expand(batch_size, self.ar_order, self.frame_channels)
        # Keep T + ar_order - 1 frames so that every window [t, t + ar_order) for t < T is complete.
        # Truncating to T frames is only sufficient for ar_order == 1; with a higher order the last
        # windows would be too short for the prenet.
        ar_inputs = torch.cat((go_tokens, mel_inputs), dim=1)[:, : T + self.ar_order - 1]
        return ar_inputs

    @staticmethod
    def _initialize_forward_algorithm_variables(mel_inputs, N):
        r"""Initialize placeholders for forward algorithm variables, to use a stable
                version we will use log_alpha_scaled and the scaling constant

        Args:
            mel_inputs (torch.FloatTensor): (b, T_max, frame_channels)
            N (int): number of states

        Returns:
            log_c (torch.FloatTensor): Scaling constant (b, T_max)
            log_alpha_scaled (torch.FloatTensor): Scaled forward variables (b, T_max, N)
            transition_matrix (torch.FloatTensor): Transition vectors (b, T_max, N)
            means (list): empty list, filled with per-timestep means by the caller
        """
        b, T_max, _ = mel_inputs.shape
        # new_zeros keeps the dtype and device of the mel tensor.
        log_alpha_scaled = mel_inputs.new_zeros((b, T_max, N))
        log_c = mel_inputs.new_zeros(b, T_max)
        transition_matrix = mel_inputs.new_zeros((b, T_max, N))

        # Saving for plotting later, will not have gradient tapes
        means = []
        return log_c, log_alpha_scaled, transition_matrix, means

    @staticmethod
    def _init_lstm_states(batch_size, hidden_state_dim, device_tensor):
        r"""
        Initialize Hidden and Cell states for LSTM Cell

        Args:
            batch_size (Int): batch size
            hidden_state_dim (Int): dimensions of the h and c
            device_tensor (torch.FloatTensor): useful for the device and type

        Returns:
            (torch.FloatTensor): shape (batch_size, hidden_state_dim)
                can be hidden state for LSTM
            (torch.FloatTensor): shape (batch_size, hidden_state_dim)
                can be the cell state for LSTM
        """
        return (
            device_tensor.new_zeros(batch_size, hidden_state_dim),
            device_tensor.new_zeros(batch_size, hidden_state_dim),
        )

    def get_absorption_state_scaling_factor(self, mels_len, log_alpha_scaled, inputs_len, transition_vector):
        """Returns the final scaling factor of absorption state

        Args:
            mels_len (torch.IntTensor): Input size of mels to
                    get the last timestep of log_alpha_scaled
            log_alpha_scaled (torch.FloatTensor): State probabilities
            inputs_len (torch.IntTensor): length of the states to
                    mask the values of states lengths
                    Useful when the batch has very different lengths,
                    when the length of an observation is less than
                    the number of max states, then the log alpha after
                    the state value is filled with -infs. So we mask
                    those values so that it only consider the states
                    which are needed for that length
            transition_vector (torch.FloatTensor): transition vector for each state per timestep

        Shapes:
            - mels_len: (batch_size)
            - log_alpha_scaled: (batch_size, T, N)
            - inputs_len: (batch_size)
            - transition_vector: (batch_size, T, N)

        Returns:
            sum_final_log_c (torch.FloatTensor): (batch_size)
        """
        N = torch.max(inputs_len)
        max_inputs_len = log_alpha_scaled.shape[2]
        state_lengths_mask = sequence_mask(inputs_len, max_len=max_inputs_len)

        # Index of the last valid frame of every utterance, repeated for each state so it can
        # be used with gather along the time axis.
        last_log_alpha_scaled_index = (
            (mels_len - 1).unsqueeze(-1).expand(-1, N).unsqueeze(1)
        )  # Batch X Hidden State Size
        last_log_alpha_scaled = torch.gather(log_alpha_scaled, 1, last_log_alpha_scaled_index).squeeze(1)
        last_log_alpha_scaled = last_log_alpha_scaled.masked_fill(~state_lengths_mask, -float("inf"))

        last_transition_vector = torch.gather(transition_vector, 1, last_log_alpha_scaled_index).squeeze(1)
        last_transition_probability = torch.sigmoid(last_transition_vector)
        log_probability_of_transitioning = OverflowUtils.log_clamped(last_transition_probability)

        # Only leaving the last state of each utterance counts as reaching the absorbing state.
        last_transition_probability_index = self.get_mask_for_last_item(inputs_len, inputs_len.device)
        log_probability_of_transitioning = log_probability_of_transitioning.masked_fill(
            ~last_transition_probability_index, -float("inf")
        )
        final_log_c = last_log_alpha_scaled + log_probability_of_transitioning

        # If the length of the mel is less than the number of states it will select the -inf values
        # leading to nan gradients. Ideally the dataset should be cleaned; clamping to the smallest
        # finite value of the dtype is a workaround that keeps the gradients finite.
        final_log_c = final_log_c.clamp(min=torch.finfo(final_log_c.dtype).min)

        sum_final_log_c = torch.logsumexp(final_log_c, dim=1)
        return sum_final_log_c

    @staticmethod
    def get_mask_for_last_item(lengths, device, out_tensor=None):
        """Returns n-1 mask for the last item in the sequence.

        Args:
            lengths (torch.IntTensor): lengths in a batch
            device (str): device on which the index range is created
                (ignored when `out_tensor` is given).
            out_tensor (torch.Tensor, optional): uses the memory of a specific tensor.
                Defaults to None.

        Returns:
            - Shape: :math:`(b, max_len)`
        """
        max_len = torch.max(lengths).item()
        ids = (
            torch.arange(0, max_len, device=device) if out_tensor is None else torch.arange(0, max_len, out=out_tensor)
        )
        # True only at position length - 1 of every row.
        mask = ids == lengths.unsqueeze(1) - 1
        return mask

    @torch.inference_mode()
    def inference(
        self,
        inputs: torch.FloatTensor,
        input_lens: torch.LongTensor,
        sampling_temp: float,
        max_sampling_time: int,
        duration_threshold: float,
    ):
        """Inference from autoregressive neural HMM

        Args:
            inputs (torch.FloatTensor): input states
                - shape: :math:`(b, T, d)`
            input_lens (torch.LongTensor): input state lengths
                - shape: :math:`(b)`
            sampling_temp (float): sampling temperature
            max_sampling_time (int): maximum number of frames to sample per utterance
            duration_threshold (float): duration threshold to switch to next state
                - Use this to change the speaking rate of the synthesised audio

        Returns:
            outputs (dict): with keys
                - "hmm_outputs": padded sampled frames, batch first
                - "hmm_outputs_len": number of sampled frames per utterance
                - "alignments": per-utterance one-hot states travelled
                - "input_parameters": per-utterance decoder inputs saved for plotting
                - "output_parameters": per-utterance decoder outputs saved for plotting
        """
        b = inputs.shape[0]
        outputs = {
            "hmm_outputs": [],
            "hmm_outputs_len": [],
            "alignments": [],
            "input_parameters": [],
            "output_parameters": [],
        }
        # Utterances are sampled one at a time because each one stops at its own length.
        for i in range(b):
            neural_hmm_outputs, states_travelled, input_parameters, output_parameters = self.sample(
                inputs[i : i + 1], input_lens[i], sampling_temp, max_sampling_time, duration_threshold
            )

            outputs["hmm_outputs"].append(neural_hmm_outputs)
            outputs["hmm_outputs_len"].append(neural_hmm_outputs.shape[0])
            outputs["alignments"].append(states_travelled)
            outputs["input_parameters"].append(input_parameters)
            outputs["output_parameters"].append(output_parameters)

        outputs["hmm_outputs"] = nn.utils.rnn.pad_sequence(outputs["hmm_outputs"], batch_first=True)
        outputs["hmm_outputs_len"] = torch.tensor(
            outputs["hmm_outputs_len"], dtype=input_lens.dtype, device=input_lens.device
        )
        return outputs

    @torch.inference_mode()
    def sample(self, inputs, input_lens, sampling_temp, max_sampling_time, duration_threshold):
        """Samples an output from the parameter models

        Args:
            inputs (torch.FloatTensor): input states
                - shape: :math:`(1, T, d)`
            input_lens (torch.LongTensor): number of input states of this utterance (scalar tensor)
            sampling_temp (float): sampling temperature
            max_sampling_time (int): max sampling time; a falsy value disables the limit
            duration_threshold (float): duration threshold to switch to next state

        Returns:
            outputs (torch.FloatTensor): Output Observations
                - Shape: :math:`(T, output_dim)`
            states_travelled (torch.LongTensor): One-hot encoding of the hidden states travelled;
                holds the initial state plus the state reached after every generated frame
            input_parameters (list[torch.FloatTensor]): Input parameters
            output_parameters (list[torch.FloatTensor]): Output parameters
        """
        states_travelled, outputs, t = [], [], 0

        # Sample initial state
        current_state = 0
        states_travelled.append(current_state)

        # Prepare autoregression
        prenet_input = self.go_tokens.unsqueeze(0).expand(1, self.ar_order, self.frame_channels)
        h_memory, c_memory = self._init_lstm_states(1, self.memory_rnn_dim, prenet_input)

        input_parameter_values = []
        output_parameter_values = []
        quantile = 1
        while True:
            memory_input = self.prenet(prenet_input.flatten(1).unsqueeze(0))
            # will be 1 while sampling
            h_memory, c_memory = self.memory_rnn(memory_input.squeeze(0), (h_memory, c_memory))

            z_t = inputs[:, current_state].unsqueeze(0)  # Add fake time dimension
            mean, std, transition_vector = self.output_net(h_memory, z_t)

            transition_probability = torch.sigmoid(transition_vector.flatten())
            staying_probability = torch.sigmoid(-transition_vector.flatten())

            # Save for plotting
            input_parameter_values.append([prenet_input, current_state])
            output_parameter_values.append([mean, std, transition_probability])

            x_t = self.emission_model.sample(mean, std, sampling_temp=sampling_temp)

            # Prepare autoregressive input for next iteration: slide the window by one frame
            prenet_input = torch.cat((prenet_input, x_t), dim=1)[:, 1:]

            outputs.append(x_t.flatten())

            transition_matrix = torch.cat((staying_probability, transition_probability))
            # Probability of still being in the current state after the frames spent in it so far.
            quantile *= staying_probability
            if not self.deterministic_transition:
                # Index 0 means "stay", index 1 means "move to the next state".
                switch = transition_matrix.multinomial(1)[0].item()
            else:
                switch = quantile < duration_threshold

            if switch:
                current_state += 1
                quantile = 1

            states_travelled.append(current_state)

            # Stop once the last state has been left, or when the frame budget is exhausted.
            if (current_state == input_lens) or (max_sampling_time and t == max_sampling_time - 1):
                break

            t += 1

        return (
            torch.stack(outputs, dim=0),
            F.one_hot(input_lens.new_tensor(states_travelled)),
            input_parameter_values,
            output_parameter_values,
        )

    @staticmethod
    def _initialize_log_state_priors(text_embeddings):
        """Creates the log pi in forward algorithm.

        Args:
            text_embeddings (torch.FloatTensor): used to create the log pi
                    on current device

        Shapes:
            - text_embeddings: (B, T, D_out_enc)

        Returns:
            log_state_priors (torch.FloatTensor): (T), 0 for the first state and -inf elsewhere
        """
        N = text_embeddings.shape[1]
        # Left-to-right HMM: the chain always starts in the first state.
        log_state_priors = text_embeddings.new_full([N], -float("inf"))
        log_state_priors[0] = 0.0
        return log_state_priors

class TransitionModel(nn.Module):
    """Transition Model of the HMM, it represents the probability of transitioning
    from current state to all other states"""

    def forward(self, log_alpha_scaled, transition_vector, inputs_len):  # pylint: disable=no-self-use
        r"""
        product of the past state with transitional probabilities in log space

        Args:
            log_alpha_scaled (torch.Tensor): Multiply previous timestep's alphas by
                        transition matrix (in log domain)
                - shape: (batch size, N)
            transition_vector (torch.tensor): transition logits for each state; must be
                broadcastable against log_alpha_scaled
                - shape: (batch size, N)
            inputs_len (int tensor): Lengths of states in a batch
                - shape: (batch)

        Returns:
            out (torch.FloatTensor): log probability of transitioning to each state
                - shape: (batch size, N)
        """
        # sigmoid(-x) == 1 - sigmoid(x): the two probabilities are complementary.
        transition_p = torch.sigmoid(transition_vector)
        staying_p = torch.sigmoid(-transition_vector)

        log_staying_probability = OverflowUtils.log_clamped(staying_p)
        log_transition_probability = OverflowUtils.log_clamped(transition_p)

        staying = log_alpha_scaled + log_staying_probability
        leaving = log_alpha_scaled + log_transition_probability
        # Left-to-right, no-skip topology: mass leaving state i arrives in state i + 1.
        leaving = leaving.roll(1, dims=1)
        # roll wraps the last state around to index 0; nothing can enter the first state.
        leaving[:, 0] = -float("inf")
        inputs_len_mask = sequence_mask(inputs_len)
        out = OverflowUtils.logsumexp(torch.stack((staying, leaving), dim=2), dim=2)
        out = out.masked_fill(~inputs_len_mask, -float("inf"))  # There are no states to contribute to the loss
        return out

class EmissionModel(nn.Module):
    """Emission Model of the HMM, it represents the probability of
    emitting an observation based on the current state"""

    def __init__(self) -> None:
        # Without this call the nn.Module internals are never created and the
        # module cannot be called or registered as a sub-module.
        super().__init__()
        # Distribution class (not an instance) used to build the per-state emission densities.
        self.distribution_function: tdist.Distribution = tdist.normal.Normal

    def sample(self, means, stds, sampling_temp):
        """Draw an observation from the emission distribution.

        Args:
            means (float tensor): means of the emission distributions
            stds (float tensor): standard deviations of the emission distributions
            sampling_temp (float): factor applied to the standard deviations;
                values <= 0 disable sampling

        Returns:
            (float tensor): a sample with the shape of `means`, or `means` itself
                when `sampling_temp` is not positive
        """
        return self.distribution_function(means, stds * sampling_temp).sample() if sampling_temp > 0 else means

    def forward(self, x_t, means, stds, state_lengths):
        r"""Calculates the log probability of the given data (x_t)
            being observed from states with given means and stds

        Args:
            x_t (float tensor) : observation at current time step
                - shape: (batch, feature_dim)
            means (float tensor): means of the distributions of hidden states
                - shape: (batch, hidden_state, feature_dim)
            stds (float tensor): standard deviations of the distributions of the hidden states
                - shape: (batch, hidden_state, feature_dim)
            state_lengths (int tensor): Lengths of states in a batch
                - shape: (batch)

        Returns:
            out (float tensor): observation log likelihoods,
                                    expressing the probability of an observation
                being generated from a state i
                shape: (batch, hidden_state)
        """
        emission_dists = self.distribution_function(means, stds)
        # Broadcast the observation over the state axis.
        out = emission_dists.log_prob(x_t.unsqueeze(1))
        state_lengths_mask = sequence_mask(state_lengths).unsqueeze(2)
        # Zero the padded states, then sum the per-feature log densities (independent dimensions).
        out = torch.sum(out * state_lengths_mask, dim=2)
        return out