from torch import nn


class ZeroTemporalPad(nn.Module):
    """Zero-pad the temporal axis so sequences keep an equal length.

    The total padding is ``dilation * (kernel_size - 1)`` frames, split between the
    start and the end of the axis; when the total is odd, the extra frame goes to the end.

    Note:
        ``nn.ZeroPad2d`` with ``(0, 0, begin, end)`` pads the second-to-last dimension
        and leaves the last one untouched. The temporal axis is therefore assumed to be
        dim -2 (e.g. ``(B, T, D)``) — TODO confirm against callers.

    Args:
        kernel_size (int): kernel size used to derive the padding amount.
        dilation (int): dilation used to derive the padding amount.
    """

    def __init__(self, kernel_size, dilation):
        super().__init__()
        amount = dilation * (kernel_size - 1)
        head = amount // 2
        # ceil(amount / 2): an odd leftover frame is appended at the end.
        tail = (amount + 1) // 2
        self.pad_layer = nn.ZeroPad2d((0, 0, head, tail))

    def forward(self, x):
        """Return ``x`` zero-padded along dim -2."""
        padded = self.pad_layer(x)
        return padded


class Conv1dBN(nn.Module):
    """Dilated 1d convolution followed by ReLU and batch normalization.

    Layer order: conv1d -> zero pad -> relu -> BN.

    Note:
        Batch normalization is applied after ReLU, following the original implementation.
        The convolution itself is unpadded, so it shortens the time axis by
        ``dilation * (kernel_size - 1)`` frames. The output is then zero-padded back to the
        input length, with the extra frame on the right when that amount is odd.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        kernel_size (int): kernel size for convolutional filters.
        dilation (int): dilation for convolution layers.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation):
        super().__init__()
        # Frames lost by an unpadded ("valid") dilated convolution with stride 1.
        shrink = dilation * (kernel_size - 1)
        left, odd = divmod(shrink, 2)
        right = left + odd  # uneven split: the right side gets the odd frame
        # Submodule names and registration order are kept so state_dict keys are stable.
        self.conv1d = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation)
        self.pad = nn.ZeroPad2d((left, right, 0, 0))
        self.norm = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        """
        Shapes:
            x: (B, in_channels, T)
            output: (B, out_channels, T)
        """
        restored = self.pad(self.conv1d(x))
        return self.norm(nn.functional.relu(restored))


class Conv1dBNBlock(nn.Module):
    """Stack of ``Conv1dBN`` layers, each doing conv1d -> relu -> BN.

    The first layer reads ``in_channels``, the last layer writes ``out_channels``, and
    every other input/output uses ``hidden_channels``. A single-layer stack maps
    ``in_channels`` directly to ``out_channels``.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        hidden_channels (int): number of inner convolution channels.
        kernel_size (int): kernel size for convolutional filters.
        dilation (int): dilation for convolution layers.
        num_conv_blocks (int, optional): number of convolutional blocks. Defaults to 2.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation, num_conv_blocks=2):
        super().__init__()
        last = num_conv_blocks - 1
        layers = [
            Conv1dBN(
                hidden_channels if i > 0 else in_channels,
                out_channels if i == last else hidden_channels,
                kernel_size,
                dilation,
            )
            for i in range(num_conv_blocks)
        ]
        self.conv_bn_blocks = nn.Sequential(*layers)

    def forward(self, x):
        """
        Shapes:
            x: (B, D, T)
        """
        return self.conv_bn_blocks(x)


class ResidualConv1dBNBlock(nn.Module):
    """Residual convolutional blocks with batch norm.

    Each residual block contains ``num_conv_blocks`` conv layers, and ``num_res_blocks``
    such blocks are chained, each wrapped in a residual (skip) connection.

    conv_block = (conv1d -> relu -> bn) x 'num_conv_blocks'
    residual_conv_block =  (x -> conv_block ->  + ->) x 'num_res_blocks'
                            ' - - - - - - - - - ^

    Note:
        Each residual connection adds a block's input to its output. Their channel counts
        must therefore match or be broadcastable; in practice ``in_channels``,
        ``hidden_channels`` and ``out_channels`` should agree. Otherwise the addition in
        ``forward`` raises a shape-mismatch error.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        hidden_channels (int): number of inner convolution channels.
        kernel_size (int): kernel size for convolutional filters.
        dilations (list): dilation of each residual block; its length must equal ``num_res_blocks``.
        num_res_blocks (int, optional): number of residual blocks. Defaults to 13.
        num_conv_blocks (int, optional): number of convolutional blocks in each residual block. Defaults to 2.
    """

    def __init__(
        self, in_channels, out_channels, hidden_channels, kernel_size, dilations, num_res_blocks=13, num_conv_blocks=2
    ):
        super().__init__()
        assert (
            len(dilations) == num_res_blocks
        ), f"len(dilations)={len(dilations)} must equal num_res_blocks={num_res_blocks}"
        self.res_blocks = nn.ModuleList()
        for idx, dilation in enumerate(dilations):
            block = Conv1dBNBlock(
                in_channels if idx == 0 else hidden_channels,
                out_channels if (idx + 1) == len(dilations) else hidden_channels,
                hidden_channels,
                kernel_size,
                dilation,
                num_conv_blocks,
            )
            self.res_blocks.append(block)

    def forward(self, x, x_mask=None):
        """Run the residual stack, re-applying ``x_mask`` after every block.

        Shapes:
            x: (B, D, T) — assumed, matching ``Conv1dBNBlock``
            x_mask: presumably (B, 1, T), broadcastable to ``x`` — TODO confirm; ``None`` means no masking.
        """
        # A scalar 1.0 stands in for "no mask", so the masking below is unconditional.
        if x_mask is None:
            x_mask = 1.0
        o = x * x_mask
        for block in self.res_blocks:
            res = o
            o = block(o)
            o = o + res
            # Re-mask so padded frames stay zero after the convolution spreads values.
            o = o * x_mask
        return o