from typing import Callable, Sequence, Type, Union

import numpy as np
import torch
import torch.nn as nn

ModuleFactory = Union[Type[nn.Module], Callable[[], nn.Module]]


class FeedForwardModule(nn.Module):
    """Base class whose forward pass is delegated to ``self.net``.

    ``self.net`` is initialised to ``None``. Subclasses are expected to
    assign the concrete network in their own constructor; ``forward`` then
    simply applies that network to its input.
    """

    def __init__(self) -> None:
        super().__init__()
        # Placeholder only: concrete subclasses overwrite this attribute.
        self.net = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pure delegation; all computation lives in the wrapped network.
        return self.net(x)


class Residual(nn.Module):
    """Wrap ``module`` with an additive skip connection.

    The output is ``module(x) + x``, so ``module`` must return a tensor that
    can be added to its input (same or broadcast-compatible shape).
    """

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        # Registered as a submodule under the name ``module``.
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        branch = self.module(x)
        return branch + x


class DilatedConvolutionalUnit(FeedForwardModule):
    """Length-preserving dilated convolution block.

    The network is ``activation -> normalization(dilated Conv1d) ->
    activation -> 1x1 Conv1d``, every layer using ``hidden_dim`` channels.
    The output has the same temporal length as the input, so the unit can
    be wrapped in a residual connection.

    Args:
        hidden_dim: Number of input and output channels.
        dilation: Dilation of the first (dilated) convolution.
        kernel_size: Kernel size of the first convolution.
        activation: Zero-argument factory; it is called once for each of the
            two activation layers.
        normalization: Callable wrapping the dilated convolution (identity
            by default).
    """

    def __init__(
            self,
            hidden_dim: int,
            dilation: int,
            kernel_size: int,
            activation: ModuleFactory,
            normalization: Callable[[nn.Module],
                                    nn.Module] = lambda x: x) -> None:
        super().__init__()
        # Total padding needed so a stride-1 dilated conv keeps the length.
        total_padding = (kernel_size - 1) * dilation
        # Use symmetric integer padding whenever it is exact (the usual
        # odd-kernel case). Otherwise (even kernel with odd dilation) the
        # floor division would lose one sample and break the residual sum,
        # so fall back to PyTorch's asymmetric 'same' padding.
        padding = (total_padding // 2
                   if total_padding % 2 == 0 else "same")
        self.net = nn.Sequential(
            activation(),
            normalization(
                nn.Conv1d(
                    in_channels=hidden_dim,
                    out_channels=hidden_dim,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    padding=padding,
                )),
            activation(),
            nn.Conv1d(in_channels=hidden_dim,
                      out_channels=hidden_dim,
                      kernel_size=1),
        )


class UpsamplingUnit(FeedForwardModule):
    """Upsample a sequence by an integer factor with a transposed conv.

    The network is ``activation -> normalization(ConvTranspose1d)`` with
    ``kernel_size = 2 * stride``. Padding and output padding are chosen so
    that an input of length ``L`` produces exactly ``L * stride`` samples,
    for both even and odd strides.

    Args:
        input_dim: Number of input channels.
        output_dim: Number of output channels.
        stride: Upsampling factor; must be at least 2.
        activation: Zero-argument factory for the leading activation.
        normalization: Callable wrapping the transposed convolution
            (identity by default).

    Raises:
        ValueError: if ``stride < 2``. For ``stride == 1`` the required
            ``output_padding`` of 1 is not smaller than the stride, which
            PyTorch rejects when the layer is run.
    """

    def __init__(
            self,
            input_dim: int,
            output_dim: int,
            stride: int,
            activation: ModuleFactory,
            normalization: Callable[[nn.Module],
                                    nn.Module] = lambda x: x) -> None:
        super().__init__()
        if stride < 2:
            raise ValueError(
                f"UpsamplingUnit requires stride >= 2, got {stride}")
        # 1 for odd strides, 0 for even strides.
        odd = stride % 2
        self.net = nn.Sequential(
            activation(),
            normalization(
                nn.ConvTranspose1d(
                    in_channels=input_dim,
                    out_channels=output_dim,
                    kernel_size=2 * stride,
                    stride=stride,
                    # L_out = (L + 1) * s - 2 * padding + output_padding.
                    # Even s: padding = s / 2 gives L * s. Odd s: padding is
                    # (s + 1) / 2, and output_padding = 1 restores L * s.
                    padding=stride // 2 + odd,
                    output_padding=odd,
                )))


class DownsamplingUnit(FeedForwardModule):
    """Downsample a sequence by an integer factor with a strided conv.

    The network is ``activation -> normalization(Conv1d)`` with
    ``kernel_size = 2 * stride``. For an input length ``L`` that is a
    multiple of ``stride``, the output has ``L // stride`` samples.

    Args:
        input_dim: Number of input channels.
        output_dim: Number of output channels.
        stride: Downsampling factor.
        activation: Zero-argument factory for the leading activation.
        normalization: Callable wrapping the convolution (identity by
            default).
    """

    def __init__(
            self,
            input_dim: int,
            output_dim: int,
            stride: int,
            activation: ModuleFactory,
            normalization: Callable[[nn.Module],
                                    nn.Module] = lambda x: x) -> None:
        super().__init__()
        if stride == 1:
            # kernel_size=2 at stride 1 needs exactly one sample of padding
            # in total. Symmetric padding of 1 per side would yield L + 1
            # outputs, so use PyTorch's asymmetric 'same' padding, which is
            # only valid for stride 1.
            padding = "same"
        else:
            # Even s: s / 2 per side; odd s: (s + 1) / 2 per side. Both give
            # L // s outputs when L is a multiple of s.
            padding = stride // 2 + stride % 2
        self.net = nn.Sequential(
            activation(),
            normalization(
                nn.Conv1d(
                    in_channels=input_dim,
                    out_channels=output_dim,
                    kernel_size=2 * stride,
                    stride=stride,
                    padding=padding,
                )))


class DilatedResidualEncoder(FeedForwardModule):
    """Encoder built from dilated residual units and downsampling stages.

    Channel widths double at each stage: ``capacity * 2**i`` for
    ``i = 0 .. len(ratios)``. The network is::

        normalization(pre_network_conv(out_channels=capacity))
        for each stage i:
            Residual(dilated_unit(width_i, d)) for each dilation d
            downsampling_unit(width_i, width_{i+1}, ratios[i])
        post_network_conv(in_channels=width_last)

    Args:
        capacity: Channel count produced by the pre-network convolution.
        dilated_unit: Factory called as ``dilated_unit(dim, dilation)``.
        downsampling_unit: Factory called as
            ``downsampling_unit(input_dim, output_dim, ratio)``.
        ratios: Downsampling factor of each stage.
        dilations: Either one sequence of dilations reused for every stage,
            or one sequence per stage (at least ``len(ratios)`` of them;
            extra entries are ignored).
        pre_network_conv: Factory called with ``out_channels=`` only.
        post_network_conv: Factory called with ``in_channels=`` only.
        normalization: Callable wrapping the pre-network convolution.

    Raises:
        ValueError: if fewer per-stage dilation sequences than ratios are
            given; such stages would otherwise be dropped silently.
    """

    def __init__(
            self,
            capacity: int,
            dilated_unit: Type[DilatedConvolutionalUnit],
            downsampling_unit: Type[DownsamplingUnit],
            ratios: Sequence[int],
            dilations: Union[Sequence[int], Sequence[Sequence[int]]],
            pre_network_conv: Type[nn.Conv1d],
            post_network_conv: Type[nn.Conv1d],
            normalization: Callable[[nn.Module],
                                    nn.Module] = lambda x: x) -> None:
        super().__init__()
        # Plain Python ints (not numpy int64) for every channel width.
        channels = [capacity * 2**i for i in range(len(ratios) + 1)]

        dilations_list = self.normalize_dilations(dilations, ratios)
        if len(dilations_list) < len(ratios):
            raise ValueError(
                f"expected {len(ratios)} per-stage dilation sequences, "
                f"got {len(dilations_list)}")

        net = [normalization(pre_network_conv(out_channels=channels[0]))]

        for ratio, stage_dilations, input_dim, output_dim in zip(
                ratios, dilations_list, channels[:-1], channels[1:]):
            for dilation in stage_dilations:
                net.append(Residual(dilated_unit(input_dim, dilation)))
            net.append(downsampling_unit(input_dim, output_dim, ratio))

        # channels[-1] is the width after the last stage. Unlike the loop
        # variable, it is also defined when `ratios` is empty.
        net.append(post_network_conv(in_channels=channels[-1]))

        self.net = nn.Sequential(*net)

    @staticmethod
    def normalize_dilations(
            dilations: Union[Sequence[int], Sequence[Sequence[int]]],
            ratios: Sequence[int]) -> Sequence[Sequence[int]]:
        """Return one dilation sequence per stage.

        A flat sequence of integers (Python or numpy), or an empty sequence,
        is repeated once per entry of ``ratios``. Anything else is assumed
        to already be per-stage and is returned unchanged.
        """
        if len(dilations) == 0 or isinstance(dilations[0],
                                             (int, np.integer)):
            dilations = [dilations for _ in ratios]
        return dilations


class DilatedResidualDecoder(FeedForwardModule):
    """Decoder built from upsampling stages and dilated residual units.

    Channel widths halve at each stage, starting from
    ``capacity * 2**len(ratios)`` and ending at ``capacity``. The network
    is::

        pre_network_conv(out_channels=width_first)
        for each stage i:
            upsampling_unit(width_i, width_{i+1}, ratios[i])
            Residual(dilated_unit(width_{i+1}, d)) for each dilation d
        normalization(post_network_conv(in_channels=capacity))

    Args:
        capacity: Channel count fed to the post-network convolution.
        dilated_unit: Factory called as ``dilated_unit(dim, dilation)``.
        upsampling_unit: Factory called as
            ``upsampling_unit(input_dim, output_dim, ratio)``.
        ratios: Upsampling factor of each stage, used in the given order.
        dilations: Either one sequence of dilations reused for every stage,
            or one sequence per stage (at least ``len(ratios)``). Per-stage
            sequences are consumed in reverse order.
        pre_network_conv: Factory called with ``out_channels=`` only.
        post_network_conv: Factory called with ``in_channels=`` only.
        normalization: Callable wrapping the post-network convolution.

    Raises:
        ValueError: if fewer per-stage dilation sequences than ratios are
            given; such stages would otherwise be dropped silently.
    """

    def __init__(
            self,
            capacity: int,
            dilated_unit: Type[DilatedConvolutionalUnit],
            upsampling_unit: Type[UpsamplingUnit],
            ratios: Sequence[int],
            dilations: Union[Sequence[int], Sequence[Sequence[int]]],
            pre_network_conv: Type[nn.Conv1d],
            post_network_conv: Type[nn.Conv1d],
            normalization: Callable[[nn.Module],
                                    nn.Module] = lambda x: x) -> None:
        super().__init__()
        # Widest first: capacity * 2**len(ratios), ..., capacity (Python ints).
        channels = [capacity * 2**i for i in range(len(ratios) + 1)][::-1]

        dilations_list = self.normalize_dilations(dilations, ratios)
        if len(dilations_list) < len(ratios):
            raise ValueError(
                f"expected {len(ratios)} per-stage dilation sequences, "
                f"got {len(dilations_list)}")
        # NOTE(review): channels and dilations are reversed here, but
        # `ratios` is used as given. Confirm whether callers already pass
        # the ratios in decoder (reversed) order.
        dilations_list = dilations_list[::-1]

        net = [pre_network_conv(out_channels=channels[0])]

        for ratio, stage_dilations, input_dim, output_dim in zip(
                ratios, dilations_list, channels[:-1], channels[1:]):
            net.append(upsampling_unit(input_dim, output_dim, ratio))
            for dilation in stage_dilations:
                net.append(Residual(dilated_unit(output_dim, dilation)))

        # channels[-1] == capacity. Unlike the loop variable, it is also
        # defined when `ratios` is empty.
        net.append(normalization(post_network_conv(in_channels=channels[-1])))

        self.net = nn.Sequential(*net)

    @staticmethod
    def normalize_dilations(
            dilations: Union[Sequence[int], Sequence[Sequence[int]]],
            ratios: Sequence[int]) -> Sequence[Sequence[int]]:
        """Return one dilation sequence per stage.

        A flat sequence of integers (Python or numpy), or an empty sequence,
        is repeated once per entry of ``ratios``. Anything else is assumed
        to already be per-stage and is returned unchanged.
        """
        if len(dilations) == 0 or isinstance(dilations[0],
                                             (int, np.integer)):
            dilations = [dilations for _ in ratios]
        return dilations