import logging as log
import os
from pathlib import Path

import torch
import torchvision.transforms as transforms
import torchvision.models as models
from torch import nn
from torch.nn import functional as F

from enum import Enum


class AdvEnum(Enum):
    """Enum base class with helpers that dump the members as plain lists."""

    @classmethod
    def list(cls):
        """Return the values of all members, in definition order."""
        return [member.value for member in cls]

    @classmethod
    def list_name_value(cls):
        """Return ``(name, value)`` tuples for all members, in definition order."""
        return [(member.name, member.value) for member in cls]


class DecoNetMode(AdvEnum):
    """Parameter-freezing strategies applied by ``DecoNet.set_mode``."""

    # Everything trainable except the DECO module's parameters.
    FREEZE_DECO = 0
    # Everything trainable except the pre-trained model (its last layer included).
    FREEZE_PTMODEL = 1
    # Pre-trained model frozen except the parameters of its last layer.
    FREEZE_PTMODEL_NO_FC = 2
    # Every parameter trainable.
    UNFREEZE_ALL = 3
    # Every parameter frozen.
    FREEZE_ALL = 4
    # Every parameter frozen except the pre-trained model's last layer.
    FREEZE_ALL_NO_FC = 5


class DecoType(AdvEnum):
    """Kinds of DECO front-end that ``get_deco_model`` can build."""

    # No DECO module; DecoNet feeds the input straight to the pre-trained model.
    NO = 0
    # StandardDECO with a transposed-convolution upsampling stage.
    DECONV = 1
    # StandardDECO with nearest-neighbour resize followed by a convolution.
    RESIZE_CONV = 2
    # ColorUDECO module.
    ColorUDECO = 16
    # PixelShuffle module.
    PIXEL_SHUFFLE = 20


def get_deco_model(use_deco, out_deco) -> nn.Module:
    """
    Build the DECO module that corresponds to `use_deco`.

    @:param use_deco : DecoType or raw enum value
      Which DECO implementation to instantiate
    @:param out_deco : int
      Output size forwarded to the DECO constructor

    @:return nn.Module
      The freshly constructed DECO module

    @:raise ValueError
      If `use_deco` is not a valid DecoType value, or is a type for which
      no module exists (e.g. DecoType.NO)
    """
    # Accept raw values too, mirroring what get_pt_model does for its enum.
    if not isinstance(use_deco, DecoType):
        use_deco = DecoType(use_deco)

    # The previous check also referenced DecoType.DECONV_NORM, which is not a
    # member of the enum, so every call failed with AttributeError before any
    # branch could be taken.
    if use_deco is DecoType.DECONV:
        return StandardDECO(out_deco, deconv=True)
    elif use_deco is DecoType.RESIZE_CONV:
        return StandardDECO(out_deco, deconv=False)
    elif use_deco is DecoType.PIXEL_SHUFFLE:
        return PixelShuffle(out_deco, lrelu=False)
    elif use_deco is DecoType.ColorUDECO:
        return ColorUDECO(out_deco)
    else:
        raise ValueError("Module not found")


class PreTrainedModel(AdvEnum):
    """Backbone architectures that ``get_pt_model`` knows how to build."""

    # Each member is mapped by get_pt_model to the torchvision constructor of
    # the same name (densenet121, resnet18, resnet34, resnet50, vgg11, vgg11_bn).
    DENSENET_121 = 0
    RESNET_18 = 1
    RESNET_34 = 2
    RESNET_50 = 3
    VGG11 = 4
    VGG11_BN = 5


def get_pt_model(model, output, pretrained=True):
    """
    Build a torchvision backbone whose final linear layer has `output` units.

    @:param model : PreTrainedModel or raw enum value
      Architecture to build
    @:param output : int
      Number of output features of the replaced last layer
    @:param pretrained : bool
      Forwarded to the torchvision constructor

    @:return tuple
      (network, input_size); the network carries a `last_layer_name`
      attribute naming the replaced layer, input_size is always 224
    """
    input_size = 224
    kind = model if isinstance(model, PreTrainedModel) else PreTrainedModel(model)

    # Architectures that share the same "replace the head" recipe are grouped.
    resnet_builders = {
        PreTrainedModel.RESNET_18: models.resnet18,
        PreTrainedModel.RESNET_34: models.resnet34,
        PreTrainedModel.RESNET_50: models.resnet50,
    }
    vgg_builders = {
        PreTrainedModel.VGG11: models.vgg11,
        PreTrainedModel.VGG11_BN: models.vgg11_bn,
    }

    if kind == PreTrainedModel.DENSENET_121:
        net = models.densenet121(pretrained=pretrained)
        net.classifier = nn.Linear(net.classifier.in_features, output)
        net.last_layer_name = "classifier"
    elif kind in resnet_builders:
        net = resnet_builders[kind](pretrained=pretrained)
        net.fc = nn.Linear(net.fc.in_features, output)
        net.last_layer_name = "fc"
    elif kind in vgg_builders:
        net = vgg_builders[kind](pretrained=pretrained)
        # The VGG head is a Sequential; only its last entry is swapped.
        net.classifier[6] = nn.Linear(net.classifier[6].in_features, output)
        net.last_layer_name = "classifier.6"
    else:
        raise ValueError("Model not found")

    return net, input_size


class DecoNet(nn.Module):
    """
    Colorization module (optional) + pre-trained model.

    The optional DECO module is applied to the input first; its output is
    then fed to the backbone returned by `get_pt_model`.
    """

    def __init__(self, output=14,
                 deco_type=DecoType.ColorUDECO,
                 pt_model=PreTrainedModel.RESNET_18,
                 pre_trained=True,
                 training_mode=DecoNetMode.FREEZE_PTMODEL_NO_FC,
                 use_aap=False):
        """
        @:param output : int
          Number of outputs of the backbone's last layer
        @:param deco_type : DecoType or raw enum value
          DECO module to prepend (DecoType.NO for none)
        @:param pt_model : PreTrainedModel or raw enum value
          Backbone architecture
        @:param pre_trained : bool
          Forwarded to `get_pt_model`
        @:param training_mode : DecoNetMode or raw enum value
          Freezing strategy applied right after construction
        @:param use_aap : bool
          If True, adaptive-average-pool the DECO output to the backbone
          input size before the backbone is applied
        """
        super().__init__()
        # Coerce raw values so the identity check against DecoType.NO below
        # behaves the same for ints as for enum members.
        if not isinstance(deco_type, DecoType):
            deco_type = DecoType(deco_type)
        self.deco_type = deco_type
        self.training_mode = training_mode
        self.use_aap = use_aap
        # Pre-trained model
        pt_model, self.out_deco = get_pt_model(pt_model, output, pre_trained)
        self.last_layer_name = pt_model.last_layer_name
        # DECO if needed
        if self.deco_type is not DecoType.NO:
            self.deco = get_deco_model(self.deco_type, self.out_deco)
        else:
            self.deco = None
        self.pt_model = pt_model
        self.set_mode(training_mode)

    def set_mode(self, mode, print=True):
        """
        Apply a freezing strategy by toggling `requires_grad` on parameters.

        @:param mode : DecoNetMode or raw enum value
          Strategy to apply
        @:param print : bool
          If True, log the resulting `requires_grad` flag of every parameter
        """
        if not isinstance(mode, DecoNetMode):
            mode = DecoNetMode(mode)
        if mode == DecoNetMode.UNFREEZE_ALL:
            for param in self.parameters():
                param.requires_grad = True
        elif mode == DecoNetMode.FREEZE_DECO:
            self.set_mode(DecoNetMode.UNFREEZE_ALL, False)
            # Without a DECO module there is nothing to freeze.
            if self.deco is not None:
                for param in self.deco.parameters():
                    param.requires_grad = False
        elif mode == DecoNetMode.FREEZE_PTMODEL:
            self.set_mode(DecoNetMode.UNFREEZE_ALL, False)
            for param in self.pt_model.parameters():
                param.requires_grad = False
        elif mode == DecoNetMode.FREEZE_PTMODEL_NO_FC:
            self.set_mode(DecoNetMode.UNFREEZE_ALL, False)
            for name, param in self.pt_model.named_parameters():
                if self.last_layer_name not in name:
                    param.requires_grad = False
        elif mode == DecoNetMode.FREEZE_ALL:
            for param in self.parameters():
                param.requires_grad = False
        elif mode == DecoNetMode.FREEZE_ALL_NO_FC:
            self.set_mode(DecoNetMode.FREEZE_ALL, False)
            # Unfreeze last layer
            for name, param in self.pt_model.named_parameters():
                if self.last_layer_name in name:
                    param.requires_grad = True

        if print:
            log.info("#############################################")
            log.info("PARAMETERS STATUS:")
            for name, param in self.named_parameters():
                log.info("%s : %s", name, param.requires_grad)
            log.info("#############################################")

    def get_layer_weight(self, sel_name: str = ""):
        """
        Collect backbone parameters whose name contains `sel_name`.

        @:param sel_name : str
          Substring to look for; the empty string selects the last layer

        @:return list
          Matching parameters, in `named_parameters()` order
        """
        if sel_name == "":
            sel_name = self.last_layer_name
        return [param for name, param in self.pt_model.named_parameters() if sel_name in name]

    def forward(self, xb):
        """
        @:param xb : tensor
          Batch of input images

        @:return tensor
          Output of the pre-trained model for the batch
        """
        if self.deco is not None:
            xb = self.deco(xb)
            if self.use_aap:
                xb = F.adaptive_avg_pool2d(xb, (self.out_deco, self.out_deco))
        return self.pt_model(xb)

    def _last_layer(self):
        """Resolve the dotted `last_layer_name` (e.g. "classifier.6") to its module."""
        layer = self.pt_model
        for part in self.last_layer_name.split("."):
            # Children of nn.Sequential are reachable by their string index.
            layer = getattr(layer, part)
        return layer

    def clean_last_layer(self):
        """Re-initialize the parameters of the backbone's last layer."""
        # The previous implementation compared the backbone module with
        # PreTrainedModel members (never equal) and then reset the last child
        # module, which for VGG is a Sequential without reset_parameters().
        self._last_layer().reset_parameters()

        log.info("Last layer cleaned!")

    def last_layer_size(self):
        """
        @:return int
          Last dimension of the last layer's weight tensor (the number of
          input features for the nn.Linear installed by `get_pt_model`)
        """
        return self._last_layer().weight.shape[-1]

    def load_deco_state_dict(self, state_dict):
        """
        Load weights into the DECO module, creating the module if it is missing.

        @:param state_dict : dict
          State dict compatible with the DECO module

        @:return bool
          True when the weights were loaded, False if the DECO object does not
          expose `load_state_dict`
        """
        if self.deco is None:
            self.deco = get_deco_model(self.deco_type, self.out_deco)
        if hasattr(self.deco, "load_state_dict"):
            self.deco.load_state_dict(state_dict)
        else:
            return False
        # A newly created DECO has all parameters trainable; re-apply the mode
        # chosen at construction time.
        self.set_mode(self.training_mode)
        return True


def default_deco__weight_init(m):
    """
    Default per-module initializer, meant to be passed to `nn.Module.apply`.

    Conv2d weights get Xavier-uniform values (the bias is left untouched);
    BatchNorm2d is reset to weight 1 and bias 0. Other modules are ignored.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight)
        return
    if isinstance(m, nn.BatchNorm2d):
        scale, shift = m.weight.data, m.bias.data
        scale.fill_(1)
        shift.zero_()


def bn_weight_init(m):
    """Reset BatchNorm2d modules to weight 1 / bias 0; ignore everything else."""
    if not isinstance(m, nn.BatchNorm2d):
        return
    scale, shift = m.weight.data, m.bias.data
    scale.fill_(1)
    shift.zero_()


class BaseDECO(nn.Module):
    """
    Common base of the DECO modules.

    Stores the requested output size (`out_s`) and the code selecting the
    weight-initialization scheme (`init`).
    """

    def __init__(self, out=224, init=None):
        super().__init__()
        self.out_s = out
        self.init = init

    def set_output_size(self, out_s):
        """Overwrite the stored output size."""
        self.out_s = out_s

    def init_weights(self):
        """
        Apply the initializer selected by `self.init` to all sub-modules.

        0 selects `default_deco__weight_init`, 1 selects `bn_weight_init`;
        None or any other value leaves the weights untouched.
        """
        scheme = self.init
        if scheme is None:
            return
        if scheme == 0:
            initializer = default_deco__weight_init
        elif scheme == 1:
            initializer = bn_weight_init
        else:
            return
        self.apply(initializer)


class ResBlock(nn.Module):
    """
    Residual block: two `conv_layer` stages whose result is added to the input.

    The addition in `forward` requires the stages to produce a tensor that
    is shape-compatible with the input.
    """

    def __init__(self, ni, nf=None, kernel=3, stride=1, padding=1):
        super().__init__()
        width = ni if nf is None else nf
        conv_kwargs = dict(kernel=kernel, stride=stride, padding=padding)
        self.conv1 = conv_layer(ni, width, **conv_kwargs)
        self.conv2 = conv_layer(width, width, **conv_kwargs)

    def forward(self, x):
        residual = self.conv2(self.conv1(x))
        return x + residual


def conv_layer(in_layer, out_layer, kernel=3, stride=1, padding=1, instanceNorm=False):
    """
    Build a Conv2d -> normalization -> LeakyReLU stack.

    The normalization is InstanceNorm2d when `instanceNorm` is true and
    BatchNorm2d otherwise; the activation works in place.
    """
    conv = nn.Conv2d(in_layer, out_layer, kernel_size=kernel, stride=stride, padding=padding)
    if instanceNorm:
        norm = nn.InstanceNorm2d(out_layer)
    else:
        norm = nn.BatchNorm2d(out_layer)
    activation = nn.LeakyReLU(inplace=True)
    return nn.Sequential(conv, norm, activation)


def _make_res_layers(nl, ni, kernel=3, stride=1, padding=1):
    """Chain `nl` identically configured ResBlocks of width `ni` into a Sequential."""
    blocks = [
        ResBlock(ni, kernel=kernel, stride=stride, padding=padding)
        for _ in range(nl)
    ]
    return nn.Sequential(*blocks)


class StandardDECO(BaseDECO):
    """
    Standard DECO Module

    A strided convolution and a max-pool shrink the single-channel input,
    eight residual blocks process it, a 1x1 convolution maps it to three
    channels and the result is upsampled either with a transposed
    convolution or with nearest-neighbour resize + convolution.
    """

    def __init__(self, out=224, init=0, deconv=False):
        super().__init__(out, init)
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=2)
        self.bn1 = nn.BatchNorm2d(64)
        # The activation is applied functionally in forward().
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.resblocks = _make_res_layers(8, 64)
        self.conv_last = nn.Conv2d(64, 3, kernel_size=1)
        if deconv:
            # `deconv` holds the upsampling module itself in this mode.
            # TODO: Check if use "groups = 1"
            self.deconv = nn.ConvTranspose2d(in_channels=3, out_channels=3, kernel_size=8, padding=2,
                                             stride=4, groups=3, bias=False)
        else:
            # Keep the falsy flag so forward() selects the resize-conv path.
            self.deconv = deconv
            self.pad = nn.ReflectionPad2d(1)
            self.conv_up = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=0, stride=1)

        self.init_weights()

    def forward(self, xb):
        """
        @:param xb : Tensor
          Batch of input images

        @:return tensor
          A batch of output images
        """
        stem = F.leaky_relu(self.bn1(self.conv1(xb)))
        features = self.resblocks(self.maxpool(stem))
        rgb = self.conv_last(features)
        if self.deconv:
            # Ask the transposed convolution to match the input's spatial size.
            return self.deconv(rgb, output_size=xb.shape)
        upsampled = F.interpolate(rgb, scale_factor=4, mode='nearest')
        return self.conv_up(self.pad(upsampled))


def icnr(x, scale=4, init=nn.init.kaiming_normal_):
    """ ICNR init of `x`, with `scale` and `init` function.

        A tensor with `scale ** 2` times fewer entries along the first axis is
        initialized with `init`, replicated `scale ** 2` times and copied
        into `x` in place.

        Checkerboard artifact free sub-pixel convolution: https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf
    """
    n_first, n_second, height, width = x.shape
    copies = scale ** 2
    n_reduced = int(n_first / copies)
    seed = init(torch.zeros([n_reduced, n_second, height, width]))
    kernel = seed.transpose(0, 1).contiguous().view(n_reduced, n_second, -1)
    kernel = kernel.repeat(1, 1, copies).contiguous()
    kernel = kernel.view([n_second, n_first, height, width]).transpose(0, 1)
    x.data.copy_(kernel)


class PixelShuffle_ICNR(nn.Module):
    """ Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle` and `icnr` init,
        followed by a replication pad + average-pool blur.

        "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts":
        https://arxiv.org/abs/1806.02658

        @:param ni : int
          Number of input channels
        @:param nf : int
          Number of output channels after the shuffle (defaults to `ni`)
        @:param scale : int
          Upsampling factor of the pixel shuffle
        @:param icnr_init : bool
          If True, initialize the convolution weight with `icnr`
        @:param blur_k, blur_s : int
          Kernel size and stride of the average-pool blur
        @:param blur_pad : tuple
          Padding handed to `nn.ReplicationPad2d` before the blur
        @:param lrelu : bool
          True: `conv_layer` stack followed by LeakyReLU.
          False: Conv2d + BatchNorm2d followed by Hardtanh(-10000, 10000)
    """

    def __init__(self, ni: int, nf: int = None, scale: int = 4, icnr_init=True, blur_k=2, blur_s=1,
                 blur_pad=(1, 0, 1, 0), lrelu=True):
        super().__init__()
        nf = ni if nf is None else nf
        # PixelShuffle folds scale**2 channels into each output channel.
        n_shuffle = nf * (scale ** 2)
        if lrelu:
            self.conv = conv_layer(ni, n_shuffle, kernel=1, padding=0, stride=1)
        else:
            # This branch used to hard-code 64 input / 3 output channels and
            # silently ignored `ni` and `nf`.
            self.conv = nn.Sequential(nn.Conv2d(ni, n_shuffle, 1, 1, 0), nn.BatchNorm2d(n_shuffle))
        if icnr_init:
            icnr(self.conv[0].weight, scale=scale)
        self.act = nn.LeakyReLU(inplace=False) if lrelu else nn.Hardtanh(-10000, 10000)
        self.shuf = nn.PixelShuffle(scale)
        # Blurring over (h*w) kernel
        self.pad = nn.ReplicationPad2d(blur_pad)
        self.blur = nn.AvgPool2d(blur_k, stride=blur_s)

    def forward(self, x):
        """Convolve, activate, pixel-shuffle and blur `x`."""
        x = self.shuf(self.act(self.conv(x)))
        return self.blur(self.pad(x))


class PixelShuffle(BaseDECO):
    """
    PixelShuffle Module

    Same convolutional stem and residual blocks as StandardDECO, but the
    upsampling to three channels is done by `PixelShuffle_ICNR`.
    """

    def __init__(self, out=224, init=1, scale=4, lrelu=False):
        super().__init__(out, init)
        # Which value should I use for stride and padding?
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=2)
        self.bn1 = nn.BatchNorm2d(64)
        self.act1 = nn.LeakyReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.resblocks = _make_res_layers(8, 64)
        self.pixel_shuffle = PixelShuffle_ICNR(ni=64, nf=3, scale=scale, lrelu=lrelu)
        self.init_weights()

    def forward(self, xb):
        """
        @:param xb : Tensor
          Batch of input images

        @:return tensor
          A batch of output images
        """
        stem = self.act1(self.bn1(self.conv1(xb)))
        features = self.resblocks(self.maxpool(stem))
        return self.pixel_shuffle(features)


class ColorUDECO(BaseDECO):
    """
    ColorUDECO Module

    Three ColorDown stages followed by two ColorUp stages and a ColorOut
    stage; the outputs of the first two down stages are concatenated
    channel-wise with the matching up-stage inputs (skip connections).
    """

    def __init__(self, out=224, init=0, in_ch=1, out_ch=3):
        super().__init__(out, init)
        # NOTE(review): init_weights() is not invoked here, unlike in
        # StandardDECO and PixelShuffle, so `init` is stored but not applied
        # -- confirm this is intended.
        self.dw1 = ColorDown(in_ch, 16)
        self.dw2 = ColorDown(16, 32)
        self.dw3 = ColorDown(32, 64)
        self.up1 = ColorUp(64, 32)
        self.up2 = ColorUp(64, 16)
        self.out = ColorOut(32, 16, out_ch)

    def forward(self, x1):
        """
        @:param x1 : Tensor
          Batch of input images

        @:return tensor
          A batch of output images
        """
        down1 = self.dw1(x1)
        down2 = self.dw2(down1)
        bottom = self.dw3(down2)
        up1 = self.up1(bottom)
        # Skip connection: 32 + 32 channels feed the second up stage.
        up2 = self.up2(torch.cat([down2, up1], dim=1))
        # Skip connection: 16 + 16 channels feed the output stage.
        merged = torch.cat([down1, up2], dim=1)
        return self.out(merged)


class ColorDown(nn.Module):
    """
    Down-sampling stage: 3x3 convolution, stride-2 4x4 convolution, batch norm.

    Each convolution is followed by LeakyReLU, or by Hardtanh when `htan`
    is true.
    """

    def __init__(self, in_ch, out_ch, htan=False):
        super().__init__()
        make_act = nn.Hardtanh if htan else nn.LeakyReLU
        stages = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
            make_act(),
            nn.Conv2d(out_ch, out_ch, kernel_size=4, stride=2, padding=1),
            make_act(),
            nn.BatchNorm2d(out_ch),
        ]
        self.d = nn.Sequential(*stages)

    def forward(self, x):
        return self.d(x)


class ColorUp(nn.Module):
    """
    Up-sampling stage: stride-2 transposed convolution, two 3x3 convolutions,
    batch norm.

    Each convolution is followed by LeakyReLU, or by Hardtanh when `htan`
    is true.
    """

    def __init__(self, in_ch, out_ch, htan=False):
        super().__init__()
        make_act = nn.Hardtanh if htan else nn.LeakyReLU
        stages = [
            nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1),
            make_act(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
            make_act(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
            make_act(),
            nn.BatchNorm2d(out_ch),
        ]
        self.u = nn.Sequential(*stages)

    def forward(self, x):
        return self.u(x)


class ColorOut(nn.Module):
    """
    Output stage: stride-2 transposed convolution, 3x3 convolution and a
    final 1x1 convolution to `out_last` channels.

    The first two layers are followed by LeakyReLU, or by Hardtanh when
    `htan` is true; the final convolution has no activation.
    """

    def __init__(self, in_ch, out_ch, out_last, htan=False):
        super().__init__()
        make_act = nn.Hardtanh if htan else nn.LeakyReLU
        stages = [
            nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1),
            make_act(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
            make_act(),
            nn.Conv2d(out_ch, out_last, kernel_size=1, stride=1, padding=0),
        ]
        self.u = nn.Sequential(*stages)

    def forward(self, x):
        return self.u(x)