from torch import nn

from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock
from TTS.tts.layers.generic.transformer import FFTransformerBlock
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer


class RelativePositionTransformerEncoder(nn.Module):
    """Speedy Speech encoder: a residual conv prenet followed by a relative-position Transformer.

    TODO: Integrate speaker conditioning vector.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        hidden_channels (int): number of hidden channels.
        params (dict): keyword arguments forwarded to ``RelativePositionTransformer``.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, params):
        super().__init__()
        # The prenet hyper-parameters are fixed here; ``params`` only configures the transformer.
        prenet_kwargs = {
            "kernel_size": 5,
            "num_res_blocks": 3,
            "num_conv_blocks": 1,
            "dilations": [1, 1, 1],
        }
        self.prenet = ResidualConv1dBNBlock(in_channels, hidden_channels, hidden_channels, **prenet_kwargs)
        self.rel_pos_transformer = RelativePositionTransformer(
            hidden_channels, out_channels, hidden_channels, **params
        )

    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
        """Run the prenet, mask its output, then apply the relative-position transformer.

        Args:
            x: encoder input.
            x_mask: optional mask multiplied onto the prenet output; ``None`` is treated as ``1``.
            g: accepted for interface compatibility and ignored.

        Returns:
            Whatever ``rel_pos_transformer`` returns for the masked prenet output.
        """
        # NOTE(review): with x_mask=None the scalar 1 is forwarded to the transformer as its mask;
        # whether RelativePositionTransformer accepts a scalar mask is not visible here — confirm.
        mask = x_mask if x_mask is not None else 1
        hidden = self.prenet(x) * mask
        return self.rel_pos_transformer(hidden, mask)


class ResidualConv1dBNEncoder(nn.Module):
    """Residual convolutional encoder as in the original Speedy Speech paper.

    Pipeline: 1x1 conv + ReLU prenet -> ``ResidualConv1dBNBlock`` -> residual add of the raw input ->
    1x1 conv / ReLU / BatchNorm / 1x1 conv postnet.

    TODO: Integrate speaker conditioning vector.

    Args:
        in_channels (int): number of input channels. Because the raw input is added to the hidden
            features before the postnet, this must equal ``hidden_channels`` (or broadcast against it).
        out_channels (int): number of output channels.
        hidden_channels (int): number of hidden channels.
        params (dict): keyword arguments forwarded to ``ResidualConv1dBNBlock``.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, params):
        super().__init__()
        # Creation order and attribute names are kept stable: they fix parameter-init order and state_dict keys.
        self.prenet = nn.Sequential(
            nn.Conv1d(in_channels, hidden_channels, 1),
            nn.ReLU(),
        )
        self.res_conv_block = ResidualConv1dBNBlock(hidden_channels, hidden_channels, hidden_channels, **params)
        post_layers = [
            nn.Conv1d(hidden_channels, hidden_channels, 1),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_channels),
            nn.Conv1d(hidden_channels, out_channels, 1),
        ]
        self.postnet = nn.Sequential(*post_layers)

    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
        """Encode ``x``; ``x_mask`` of ``None`` means no masking, and ``g`` is ignored."""
        mask = 1 if x_mask is None else x_mask
        hidden = self.prenet(x) * mask
        hidden = self.res_conv_block(hidden, mask)
        # Residual connection with the un-projected input, then mask the postnet output.
        out = self.postnet(hidden + x) * mask
        # The output is masked a second time, exactly as before.
        return out * mask


class Encoder(nn.Module):
    """Factory class for the Speedy Speech encoder; builds one of several encoder types internally.

    Args:
        in_hidden_channels (int): input and hidden channels. The model keeps the input channel count for the
            intermediate layers.
        out_channels (int): number of output channels.
        encoder_type (str): encoder layer type, matched case-insensitively. One of
            'relative_position_transformer', 'residual_conv_bn' or 'fftransformer'. Default 'residual_conv_bn'.
        encoder_params (dict, optional): keyword parameters for the selected encoder type. If ``None``
            (default), ``{"kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2,
            "num_res_blocks": 13}`` is used regardless of ``encoder_type``. These are the 'residual_conv_bn'
            defaults, so pass explicit params for the other encoder types.
        c_in_channels (int): number of channels for conditional input. Stored as an attribute; not used by
            the encoders built here.

    Raises:
        NotImplementedError: if ``encoder_type`` is not one of the supported types.
        AssertionError: if ``encoder_type`` is 'fftransformer' and ``in_hidden_channels != out_channels``.

    Note:
        Example ``encoder_params`` for each encoder type (e.g. as set in config.json):

        ```python
        # for 'relative_position_transformer'
        encoder_params = {
            "hidden_channels_ffn": 128,
            "num_heads": 2,
            "kernel_size": 3,
            "dropout_p": 0.1,
            "num_layers": 6,
            "rel_attn_window_size": 4,
            "input_length": None
        }

        # for 'residual_conv_bn'
        encoder_params = {
            "kernel_size": 4,
            "dilations": 4 * [1, 2, 4] + [1],
            "num_conv_blocks": 2,
            "num_res_blocks": 13
        }

        # for 'fftransformer'
        encoder_params = {
            "hidden_channels_ffn": 1024,
            "num_heads": 2,
            "num_layers": 6,
            "dropout_p": 0.1
        }
        ```
    """

    def __init__(
        self,
        in_hidden_channels,
        out_channels,
        encoder_type="residual_conv_bn",
        encoder_params=None,
        c_in_channels=0,
    ):
        super().__init__()
        if encoder_params is None:
            # Build a fresh dict per call instead of sharing one mutable default object across instances.
            encoder_params = {
                "kernel_size": 4,
                "dilations": 4 * [1, 2, 4] + [1],
                "num_conv_blocks": 2,
                "num_res_blocks": 13,
            }
        self.out_channels = out_channels
        self.in_channels = in_hidden_channels
        self.hidden_channels = in_hidden_channels
        self.encoder_type = encoder_type
        self.c_in_channels = c_in_channels

        # init encoder (type name is matched case-insensitively)
        kind = encoder_type.lower()
        if kind == "relative_position_transformer":
            # text encoder
            # pylint: disable=unexpected-keyword-arg
            self.encoder = RelativePositionTransformerEncoder(
                in_hidden_channels, out_channels, in_hidden_channels, encoder_params
            )
        elif kind == "residual_conv_bn":
            self.encoder = ResidualConv1dBNEncoder(in_hidden_channels, out_channels, in_hidden_channels, encoder_params)
        elif kind == "fftransformer":
            # FFTransformerBlock is built with a single channel count, so input and output widths must match.
            assert (
                in_hidden_channels == out_channels
            ), "[!] must be `in_channels` == `out_channels` when encoder type is 'fftransformer'"
            # pylint: disable=unexpected-keyword-arg
            self.encoder = FFTransformerBlock(in_hidden_channels, **encoder_params)
        else:
            raise NotImplementedError(" [!] unknown encoder type.")

    def forward(self, x, x_mask, g=None):  # pylint: disable=unused-argument
        """Run the selected encoder and mask its output.

        Shapes:
            x: [B, C, T]
            x_mask: [B, 1, T]
            g: [B, C, 1] (accepted for interface compatibility; ignored)
        """
        o = self.encoder(x, x_mask)
        return o * x_mask