from typing import List, Tuple

import torch
from torch import Tensor, nn
from torch.nn import Conv2d, Module
import torch.nn.functional as F
from torch.nn.utils import spectral_norm, weight_norm

from models.config import HifiGanPretrainingConfig

from .utils import get_padding

# Leaky ReLU slope
LRELU_SLOPE = HifiGanPretrainingConfig.lReLU_slope


class DiscriminatorP(Module):
    def __init__(
        self,
        period: int,
        kernel_size: int = 5,
        stride: int = 3,
        use_spectral_norm: bool = False,
    ):
        r"""Initialize the DiscriminatorP module.

        Args:
            period (int): The period for the discriminator.
            kernel_size (int, optional): The kernel size for the convolutional layers. Defaults to 5.
            stride (int, optional): The stride for the convolutional layers. Defaults to 3.
            use_spectral_norm (bool, optional): Whether to use spectral normalization. Defaults to False.
        """
        super().__init__()
        self.period = period
        norm_f = weight_norm if not use_spectral_norm else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(
                    Conv2d(
                        1,
                        32,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(5, 1), 0),
                    ),
                ),
                norm_f(
                    Conv2d(
                        32,
                        128,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(5, 1), 0),
                    ),
                ),
                norm_f(
                    Conv2d(
                        128,
                        512,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(5, 1), 0),
                    ),
                ),
                norm_f(
                    Conv2d(
                        512,
                        1024,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(5, 1), 0),
                    ),
                ),
                norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
            ],
        )
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]:
        r"""Forward pass of the DiscriminatorP module.

        Args:
            x (Tensor): The input tensor.

        Returns:
            Tuple[Tensor, List[Tensor]]: The output tensor and a list of feature maps.
        """
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for layer in self.convs:
            x = layer(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self):
        r"""Initialize the MultiPeriodDiscriminator module."""
        super().__init__()
        self.discriminators = nn.ModuleList(
            [
                DiscriminatorP(2),
                DiscriminatorP(3),
                DiscriminatorP(5),
                DiscriminatorP(7),
                DiscriminatorP(11),
            ],
        )

    def forward(
        self,
        y: Tensor,
        y_hat: Tensor,
    ) -> Tuple[
        List[torch.Tensor],
        List[torch.Tensor],
        List[torch.Tensor],
        List[torch.Tensor],
    ]:
        r"""Forward pass of the MultiPeriodDiscriminator module.

        Args:
            y (torch.Tensor): The real audio tensor.
            y_hat (torch.Tensor): The generated audio tensor.

        Returns:
            Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
            A tuple containing lists of discriminator outputs and feature maps for real and generated audio.
        """
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []

        for _, discriminator in enumerate(self.discriminators):
            y_d_r, fmap_r = discriminator(y)
            y_d_g, fmap_g = discriminator(y_hat)

            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs