import pytorch_lightning as pl

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Dict, List, Optional, OrderedDict, Tuple


class Discriminator(nn.Module):
    """
    Image discriminator made of spectrally-normalised convolutions.

    The whole network lives in ``self.model`` as one ``nn.Sequential`` that
    ends in ``Flatten`` followed by ``Sigmoid``.
    """

    def __init__(
        self,
        hidden_size: Optional[int] = 64,
        channels: Optional[int] = 3,
        kernel_size: Optional[int] = 4,
        stride: Optional[int] = 2,
        padding: Optional[int] = 1,
        negative_slope: Optional[float] = 0.2,
        bias: Optional[bool] = False,
    ):
        """
        Builds the discriminator network.

        Parameters
        ----------
        hidden_size : int, optional
            Number of feature maps produced by the first convolution; the
            next three stages use 2x, 4x and 8x this width. (default: 64)
        channels : int, optional
            Number of channels of the input images. (default: 3)
        kernel_size : int, optional
            Kernel size of the first four convolutions. (default: 4)
        stride : int, optional
            Stride of the first four convolutions. (default: 2)
        padding : int, optional
            Padding of the first four convolutions. (default: 1)
        negative_slope : float, optional
            Negative slope of every LeakyReLU. (default: 0.2)
        bias : bool, optional
            Whether the convolutions carry a bias term. (default: False)
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.channels = channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.negative_slope = negative_slope
        self.bias = bias

        def sn_conv(in_ch, out_ch):
            # Configurable convolution wrapped in spectral normalisation.
            conv = nn.Conv2d(
                in_ch,
                out_ch,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias,
            )
            return nn.utils.spectral_norm(conv)

        def activation():
            return nn.LeakyReLU(negative_slope, inplace=True)

        # Stem: the first stage has no batch normalisation.
        stack = [sn_conv(channels, hidden_size), activation()]

        # Three widening stages: conv -> batch norm -> LeakyReLU.
        for factor in (1, 2, 4):
            widened = hidden_size * factor * 2
            stack.append(sn_conv(hidden_size * factor, widened))
            stack.append(nn.BatchNorm2d(widened))
            stack.append(activation())

        # Head: fixed 4x4 kernel, stride 1, no padding, one output channel.
        # The original note gives its output size as (1, 1, 1).
        head = nn.Conv2d(
            hidden_size * 8, 1, kernel_size=4, stride=1, padding=0, bias=bias
        )
        stack.append(nn.utils.spectral_norm(head))
        stack.append(nn.Flatten())
        stack.append(nn.Sigmoid())

        # Layers are created and registered in the same order as before, so
        # the Sequential indices (and parameter names) are unchanged.
        self.model = nn.Sequential(*stack)

    def forward(self, input_img: torch.Tensor) -> torch.Tensor:
        """
        Runs the image batch through the network.

        Parameters
        ----------
        input_img : torch.Tensor
            The batch of input images.

        Returns
        -------
        torch.Tensor
            The flattened sigmoid output of ``self.model``.
        """
        return self.model(input_img)


class Generator(nn.Module):
    """
    Transposed-convolution generator that turns latent noise into images.

    The whole network lives in ``self.model`` as one ``nn.Sequential`` that
    ends in ``Tanh``.
    """

    def __init__(
        self,
        hidden_size: Optional[int] = 64,
        latent_size: Optional[int] = 128,
        channels: Optional[int] = 3,
        kernel_size: Optional[int] = 4,
        stride: Optional[int] = 2,
        padding: Optional[int] = 1,
        bias: Optional[bool] = False,
    ):
        """
        Builds the generator network.

        Parameters
        ----------
        hidden_size : int, optional
            Base feature-map width; the stages use 8x, 4x, 2x and 1x this
            value. (default: 64)
        latent_size : int, optional
            Number of channels of the latent input. (default: 128)
        channels : int, optional
            Number of channels of the generated images. (default: 3)
        kernel_size : int, optional
            Kernel size of every transposed convolution. (default: 4)
        stride : int, optional
            Stride of every transposed convolution except the first, which
            always uses stride 1. (default: 2)
        padding : int, optional
            Padding of every transposed convolution except the first, which
            always uses padding 0. (default: 1)
        bias : bool, optional
            Whether the transposed convolutions carry a bias term.
            (default: False)
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.channels = channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.bias = bias

        def upsample(in_ch, out_ch):
            # Transposed convolution using the configured stride and padding.
            return nn.ConvTranspose2d(
                in_ch,
                out_ch,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias,
            )

        widths = [hidden_size * factor for factor in (8, 4, 2, 1)]

        # Projection stage: fixed stride 1 and padding 0.
        stack = [
            nn.ConvTranspose2d(
                latent_size, widths[0], kernel_size=kernel_size, stride=1, padding=0, bias=bias
            ),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(inplace=True),
        ]

        # Three narrowing stages: transposed conv -> batch norm -> ReLU.
        for wide, narrow in zip(widths, widths[1:]):
            stack.append(upsample(wide, narrow))
            stack.append(nn.BatchNorm2d(narrow))
            stack.append(nn.ReLU(inplace=True))

        # Output stage; the original note gives the output size as
        # (channels, 64, 64).
        stack.append(upsample(widths[-1], channels))
        stack.append(nn.Tanh())

        # Layers are created and registered in the same order as before, so
        # the Sequential indices (and parameter names) are unchanged.
        self.model = nn.Sequential(*stack)

    def forward(self, input_noise: torch.Tensor) -> torch.Tensor:
        """
        Generates images from latent noise.

        Parameters
        ----------
        input_noise : torch.Tensor
            The batch of latent noise.

        Returns
        -------
        torch.Tensor
            The generated images (output of the final ``Tanh``).
        """
        return self.model(input_noise)


class DocuGAN(pl.LightningModule):
    """
    LightningModule that trains a ``Generator`` against a ``Discriminator``
    with binary cross-entropy, using one Adam optimizer per network.
    """

    def __init__(
        self,
        hidden_size: Optional[int] = 64,
        latent_size: Optional[int] = 128,
        num_channel: Optional[int] = 3,
        learning_rate: Optional[float] = 0.0002,
        batch_size: Optional[int] = 128,
        bias1: Optional[float] = 0.5,
        bias2: Optional[float] = 0.999,
    ):
        """
        Initializes the DocuGAN module.

        Parameters
        ----------
        hidden_size : int, optional
            The hidden size passed to both networks. (default: 64)
        latent_size : int, optional
            The number of channels of the latent noise. (default: 128)
        num_channel : int, optional
            The number of image channels. (default: 3)
        learning_rate : float, optional
            The learning rate of both Adam optimizers. (default: 0.0002)
        batch_size : int, optional
            The number of samples in the fixed noise batch stored in
            ``self.validation``. (default: 128)
        bias1 : float, optional
            First Adam beta coefficient. (default: 0.5)
        bias2 : float, optional
            Second Adam beta coefficient. (default: 0.999)
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.num_channel = num_channel
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.bias1 = bias1
        self.bias2 = bias2
        self.criterion = nn.BCELoss()
        # Fixed noise batch kept as a plain tensor attribute (not a registered
        # buffer). NOTE(review): not used in the visible code; presumably
        # meant for sampling comparable images during validation - confirm.
        self.validation = torch.randn(self.batch_size, self.latent_size, 1, 1)
        self.save_hyperparameters()

        self.generator = Generator(
            latent_size=self.latent_size, channels=self.num_channel, hidden_size=self.hidden_size
        )
        self.generator.apply(self.weights_init)

        self.discriminator = Discriminator(channels=self.num_channel, hidden_size=self.hidden_size)
        self.discriminator.apply(self.weights_init)

        # self.model = InceptionV3() # For FID metric

    def weights_init(self, m: nn.Module) -> None:
        """
        Initializes the weights of a single module in place.

        Modules whose class name contains "Conv" (this also matches
        transposed convolutions) get weights drawn from N(0, 0.02); modules
        whose class name contains "BatchNorm" get weights drawn from
        N(1, 0.02) and a zero bias. Every other module is left untouched.

        Parameters
        ----------
        m : nn.Module
            The module.
        """
        classname = m.__class__.__name__
        if "Conv" in classname:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif "BatchNorm" in classname:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

    def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List]:
        """
        Configures the optimizers.

        Returns
        -------
        Tuple[List[torch.optim.Optimizer], List]
            The generator optimizer (index 0) and the discriminator optimizer
            (index 1), followed by an empty list of LR schedulers.
        """
        betas = (self.bias1, self.bias2)
        opt_generator = torch.optim.Adam(
            self.generator.parameters(), lr=self.learning_rate, betas=betas
        )
        opt_discriminator = torch.optim.Adam(
            self.discriminator.parameters(), lr=self.learning_rate, betas=betas
        )
        return [opt_generator, opt_discriminator], []

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Forward propagation through the generator.

        Parameters
        ----------
        z : torch.Tensor
            The latent vector.

        Returns
        -------
        torch.Tensor
            The generated images.
        """
        return self.generator(z)

    def _sample_noise(self, reference: torch.Tensor) -> torch.Tensor:
        """
        Draws one latent noise vector per sample of ``reference``.

        Parameters
        ----------
        reference : torch.Tensor
            The tensor whose first dimension sets the number of noise samples
            and whose type the noise is converted to via ``type_as``.

        Returns
        -------
        torch.Tensor
            Noise of shape ``(reference.size(0), latent_size, 1, 1)``.
        """
        noise = torch.randn(reference.size(0), self.latent_size, 1, 1)
        return noise.type_as(reference)

    def training_step(
        self, batch: Dict[str, torch.Tensor], batch_idx: int, optimizer_idx: int
    ) -> Optional[Dict]:
        """
        Training step.

        Parameters
        ----------
        batch : Dict[str, torch.Tensor]
            The batch; the real images are read from its ``"tr_image"`` key.
        batch_idx : int
            The batch index.
        optimizer_idx : int
            The optimizer index: 0 trains the generator, 1 trains the
            discriminator.

        Returns
        -------
        Optional[Dict]
            The training loss under ``"loss"`` plus the ``"progress_bar"`` and
            ``"log"`` entries, or ``None`` for any other optimizer index.
        """
        real_images = batch["tr_image"]

        if optimizer_idx == 0:  # Only train the generator
            fake_images = self(self._sample_noise(real_images))

            # Try to fool the discriminator: fakes are scored against "real".
            preds = self.discriminator(fake_images)
            loss = self.criterion(preds, torch.ones_like(preds))
            self.log("g_loss", loss, on_step=False, on_epoch=True)

            tqdm_dict = {"g_loss": loss}
            return OrderedDict({"loss": loss, "progress_bar": tqdm_dict, "log": tqdm_dict})

        if optimizer_idx == 1:  # Only train the discriminator
            real_preds = self.discriminator(real_images)
            real_loss = self.criterion(real_preds, torch.ones_like(real_preds))

            # Generate as many fakes as there are real images. They are
            # detached so the discriminator loss does not backpropagate into
            # the generator.
            fake_images = self(self._sample_noise(real_images)).detach()

            # Pass fake images through the discriminator
            fake_preds = self.discriminator(fake_images)
            fake_loss = self.criterion(fake_preds, torch.zeros_like(fake_preds))

            # Total discriminator loss
            loss = real_loss + fake_loss
            self.log("d_loss", loss, on_step=False, on_epoch=True)

            tqdm_dict = {"d_loss": loss}
            return OrderedDict({"loss": loss, "progress_bar": tqdm_dict, "log": tqdm_dict})

        return None