import json
from collections import OrderedDict
from math import exp

from .Common import *


# +++++++++++++++++++++++++++++++++++++
#           FP16 Training
# -------------------------------------
#  Modified from Nvidia/Apex
# https://github.com/NVIDIA/apex/blob/master/apex/fp16_utils/fp16util.py


class tofp16(nn.Module):
    """Input-casting module for fp16 inference.

    CUDA tensors are cast to half precision; CPU tensors are kept in
    float32, since PyTorch 1.0 has no fp16 kernels on CPU.
    """

    def __init__(self):
        super(tofp16, self).__init__()

    def forward(self, input):
        # Only cast on GPU — CPU fp16 is unsupported in PyTorch 1.0.
        return input.half() if input.is_cuda else input.float()


def BN_convert_float(module):
    """Restore every BatchNorm layer inside *module* to float32.

    BatchNorm statistics are numerically fragile in fp16, so after calling
    ``.half()`` on a network the norm layers are flipped back to float.
    The module is modified in place and also returned for chaining.
    """
    # module.modules() yields the module itself plus all descendants,
    # which covers the same set the original recursion visited.
    for m in module.modules():
        if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
            m.float()
    return module


def network_to_half(network):
    """Wrap *network* for fp16 inference.

    Inputs are cast by a leading ``tofp16`` stage, the network weights are
    halved, and BatchNorm layers are restored to float32 for stability.
    """
    half_body = BN_convert_float(network.half())
    return nn.Sequential(tofp16(), half_body)


# warnings.simplefilter('ignore')

# +++++++++++++++++++++++++++++++++++++
#           DCSCN
# -------------------------------------


class DCSCN(BaseModule):
    """Deep CNN with Skip Connection and Network-in-Network (DCSCN).

    https://github.com/jiny2001/dcscn-super-resolution

    Submodule names ("Feature N", "A", "B", "Conv2d_block", ...) are kept
    stable so that pretrained state_dict checkpoints keep loading.
    """

    def __init__(
        self,
        color_channel=3,
        up_scale=2,
        feature_layers=12,
        first_feature_filters=196,
        last_feature_filters=48,
        reconstruction_filters=128,
        up_sampler_filters=32,
    ):
        super(DCSCN, self).__init__()
        # Filled in by the block builders below.
        self.total_feature_channels = 0
        self.total_reconstruct_filters = 0
        self.upscale = up_scale

        # One shared SELU instance is reused by every conv block.
        self.act_fn = nn.SELU(inplace=False)
        self.feature_block = self.make_feature_extraction_block(
            color_channel, feature_layers, first_feature_filters, last_feature_filters
        )
        self.reconstruction_block = self.make_reconstruction_block(
            reconstruction_filters
        )
        self.up_sampler = self.make_upsampler(up_sampler_filters, color_channel)
        self.selu_init_params()

    def selu_init_params(self):
        # SELU-style init: N(0, 1/numel) on conv weights, zero biases.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0.0, 1.0 / sqrt(m.weight.numel()))
                if m.bias is not None:
                    m.bias.data.fill_(0)

    def conv_block(self, in_channel, out_channel, kernel_size):
        """Build a "same"-padded conv followed by the shared activation."""
        layers = OrderedDict()
        layers["Conv2d"] = nn.Conv2d(
            in_channel,
            out_channel,
            kernel_size=kernel_size,
            padding=(kernel_size - 1) // 2,
        )
        layers["Activation"] = self.act_fn
        return nn.Sequential(layers)

    def make_feature_extraction_block(
        self, color_channel, num_layers, first_filters, last_filters
    ):
        """Stack of 3x3 conv blocks with exponentially decaying widths."""
        # Filter counts decay exponentially from first_filters down to
        # (approximately) last_filters across num_layers layers.
        decay = log(first_filters / last_filters) / (num_layers - 1)
        filter_nums = [
            round(first_filters * exp(-decay * i)) for i in range(num_layers)
        ]
        # All intermediate outputs are concatenated in forward(), so the
        # reconstruction stage needs the total channel count.
        self.total_feature_channels = sum(filter_nums)

        blocks = OrderedDict()
        blocks["Feature 1"] = self.conv_block(color_channel, first_filters, 3)
        for i in range(num_layers - 1):
            blocks["Feature {}".format(i + 2)] = self.conv_block(
                filter_nums[i], filter_nums[i + 1], 3
            )
        return nn.Sequential(blocks)

    def make_reconstruction_block(self, num_filters):
        """Two parallel paths over the concatenated features.

        Path "A" is a single 1x1 conv; path "B" is a 1x1 -> 3x3 bottleneck.
        Their outputs are concatenated in forward(), hence
        total_reconstruct_filters = 2 * num_filters.
        """
        path_a = self.conv_block(self.total_feature_channels, num_filters, 1)
        path_b = nn.Sequential(
            self.conv_block(self.total_feature_channels, num_filters // 2, 1),
            self.conv_block(num_filters // 2, num_filters, 3),
        )
        self.total_reconstruct_filters = num_filters * 2
        return nn.Sequential(OrderedDict([("A", path_a), ("B", path_b)]))

    def make_upsampler(self, out_channel, color_channel):
        """Conv -> PixelShuffle -> conv head producing the residual image."""
        # PixelShuffle(scale) consumes out_channel * scale^2 input channels.
        shuffle_channels = out_channel * self.upscale**2
        stages = OrderedDict()
        stages["Conv2d_block"] = self.conv_block(
            self.total_reconstruct_filters, shuffle_channels, kernel_size=3
        )
        stages["PixelShuffle"] = nn.PixelShuffle(self.upscale)
        stages["Conv2d"] = nn.Conv2d(
            out_channel, color_channel, kernel_size=3, padding=1, bias=False
        )
        return nn.Sequential(stages)

    def forward(self, x):
        """Run on a (low-res, upscaled-low-res) pair; residual learning."""
        lr, lr_up = x

        # Collect every feature layer's output for the skip concatenation.
        feature_maps = []
        out = lr
        for layer in self.feature_block.children():
            out = layer(out)
            feature_maps.append(out)
        stacked = torch.cat(feature_maps, dim=1)

        # Both reconstruction paths see the full feature stack.
        recon = torch.cat(
            [path(stacked) for path in self.reconstruction_block.children()],
            dim=1,
        )

        # The network predicts only the residual over the naive upscale.
        return self.up_sampler(recon) + lr_up


# +++++++++++++++++++++++++++++++++++++
#           CARN
# -------------------------------------


class CARN_Block(BaseModule):
    """Cascading residual block used by CARN.

    Runs residual sub-blocks in sequence; after each one, the running
    concatenation of outputs is squeezed back to *channels* by a "single"
    conv (the cascading connections of the CARN paper).
    """

    def __init__(
        self,
        channels,
        kernel_size=3,
        padding=1,
        dilation=1,
        groups=1,
        activation=nn.SELU(),
        repeat=3,
        SEBlock=False,
        conv=nn.Conv2d,
        single_conv_size=1,
        single_conv_group=1,
    ):
        super(CARN_Block, self).__init__()
        body = []
        for _ in range(repeat):
            body.append(
                ResidualFixBlock(
                    channels,
                    channels,
                    kernel_size=kernel_size,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                    activation=activation,
                    conv=conv,
                )
            )
            if SEBlock:
                body.append(
                    SpatialChannelSqueezeExcitation(channels, reduction=channels)
                )
        self.blocks = nn.Sequential(*body)

        # The i-th single conv fuses the entry tensor plus (i+1) outputs.
        self.singles = nn.Sequential(
            *(
                ConvBlock(
                    channels * (i + 2),
                    channels,
                    kernel_size=single_conv_size,
                    padding=(single_conv_size - 1) // 2,
                    groups=single_conv_group,
                    activation=activation,
                    conv=conv,
                )
                for i in range(repeat)
            )
        )

    def forward(self, x):
        # NOTE(review): when SEBlock=True, self.blocks holds 2*repeat
        # children while self.singles holds repeat, so zip() truncates and
        # the trailing children of self.blocks never run here — confirm
        # this matches the pretrained checkpoints' training-time behavior.
        cascade = x
        for block, single in zip(self.blocks, self.singles):
            cascade = torch.cat([cascade, block(x)], dim=1)
            x = single(cascade)
        return x


class CARN(BaseModule):
    """Cascading residual network for super-resolution.

    "Fast, Accurate, and Lightweight Super-Resolution with Cascading
    Residual Network" — https://github.com/nmhkahn/CARN-pytorch
    """

    def __init__(
        self,
        color_channels=3,
        mid_channels=64,
        scale=2,
        activation=nn.SELU(),
        num_blocks=3,
        conv=nn.Conv2d,
    ):
        super(CARN, self).__init__()

        self.color_channels = color_channels
        self.mid_channels = mid_channels
        self.scale = scale

        # 3x3 conv lifting the input image into feature space.
        self.entry_block = ConvBlock(
            color_channels,
            mid_channels,
            kernel_size=3,
            padding=1,
            activation=activation,
            conv=conv,
        )
        self.blocks = nn.Sequential(
            *(
                CARN_Block(
                    mid_channels,
                    kernel_size=3,
                    padding=1,
                    activation=activation,
                    conv=conv,
                    single_conv_size=1,
                    single_conv_group=1,
                )
                for _ in range(num_blocks)
            )
        )
        # 1x1 convs that fuse the growing cascade back to mid_channels.
        self.singles = nn.Sequential(
            *(
                ConvBlock(
                    mid_channels * (i + 2),
                    mid_channels,
                    kernel_size=1,
                    padding=0,
                    activation=activation,
                    conv=conv,
                )
                for i in range(num_blocks)
            )
        )

        self.upsampler = UpSampleBlock(
            mid_channels, scale=scale, activation=activation, conv=conv
        )
        self.exit_conv = conv(mid_channels, color_channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.entry_block(x)
        # Cascade: concatenate each block's output, then squeeze back down.
        cascade = x
        for block, single in zip(self.blocks, self.singles):
            cascade = torch.cat([cascade, block(x)], dim=1)
            x = single(cascade)
        x = self.upsampler(x)
        return self.exit_conv(x)


class CARN_V2(CARN):
    """CARN variant with SE blocks and a global residual connection.

    Replaces the parent's cascading blocks/singles with configurable
    versions and adds a skip connection around the whole cascade body.
    """

    def __init__(
        self,
        color_channels=3,
        mid_channels=64,
        scale=2,
        activation=nn.LeakyReLU(0.1),
        SEBlock=True,
        conv=nn.Conv2d,
        atrous=(1, 1, 1),
        repeat_blocks=3,
        single_conv_size=3,
        single_conv_group=1,
    ):
        super(CARN_V2, self).__init__(
            color_channels=color_channels,
            mid_channels=mid_channels,
            scale=scale,
            activation=activation,
            conv=conv,
        )

        # NOTE(review): only len(atrous) is used (the number of cascading
        # blocks); the rates themselves are ignored and dilation stays 1.
        # Confirm whether per-block dilation was intended here.
        num_blocks = len(atrous)
        self.blocks = nn.Sequential(
            *(
                CARN_Block(
                    mid_channels,
                    kernel_size=3,
                    padding=1,
                    dilation=1,
                    activation=activation,
                    SEBlock=SEBlock,
                    conv=conv,
                    repeat=repeat_blocks,
                    single_conv_size=single_conv_size,
                    single_conv_group=single_conv_group,
                )
                for _ in range(num_blocks)
            )
        )

        self.singles = nn.Sequential(
            *(
                ConvBlock(
                    mid_channels * (i + 2),
                    mid_channels,
                    kernel_size=single_conv_size,
                    padding=(single_conv_size - 1) // 2,
                    groups=single_conv_group,
                    activation=activation,
                    conv=conv,
                )
                for i in range(num_blocks)
            )
        )

    def forward(self, x):
        x = self.entry_block(x)
        skip = x
        cascade = x
        for block, single in zip(self.blocks, self.singles):
            cascade = torch.cat([cascade, block(x)], dim=1)
            x = single(cascade)
        # Global residual around the cascade body (absent in parent CARN).
        x = x + skip
        x = self.upsampler(x)
        return self.exit_conv(x)


# +++++++++++++++++++++++++++++++++++++
#           original Waifu2x model
# -------------------------------------


class UpConv_7(BaseModule):
    """Original waifu2x "upconv_7" model: 6 unpadded 3x3 convs + 2x deconv.

    https://github.com/nagadomi/waifu2x/blob/3c46906cb78895dbd5a25c3705994a1b2e873199/lib/srcnn.lua#L311

    Sequential indices (conv at 0,2,...,10; deconv at 12) are preserved so
    pretrained weight loading stays compatible.
    """

    def __init__(self):
        super(UpConv_7, self).__init__()
        self.act_fn = nn.LeakyReLU(0.1, inplace=False)
        # All convs are unpadded, so the input is zero-padded up front;
        # offset is the per-border shrinkage relative to the 2x target.
        self.offset = 7  # because of 0 padding
        from torch.nn import ZeroPad2d

        self.pad = ZeroPad2d(self.offset)

        channel_plan = [(3, 16), (16, 32), (32, 64), (64, 128), (128, 128), (128, 256)]
        layers = []
        for cin, cout in channel_plan:
            layers.append(nn.Conv2d(cin, cout, 3, 1, 0))
            layers.append(self.act_fn)
        # Final 2x upscale via transposed conv (no bias in the Lua original).
        layers.append(
            nn.ConvTranspose2d(256, 3, kernel_size=4, stride=2, padding=3, bias=False)
        )
        self.Sequential = nn.Sequential(*layers)

    def load_pre_train_weights(self, json_file):
        """Load weights exported from the original Lua waifu2x as JSON.

        The JSON is a list of layer dicts; each layer's "weight" then
        "bias" are copied into this model's state_dict in order.
        """
        with open(json_file) as f:
            layers = json.load(f)
        flat = []
        for layer in layers:
            flat.append(layer["weight"])
            flat.append(layer["bias"])
        own_state = self.state_dict()
        for index, name in enumerate(own_state):
            own_state[name].copy_(torch.FloatTensor(flat[index]))

    def forward(self, x):
        # Pad once so every unpadded conv still sees enough context.
        return self.Sequential(self.pad(x))


class Vgg_7(UpConv_7):
    """waifu2x "vgg_7" model: 7 unpadded 3x3 convs, no upscaling stage.

    Inherits padding and JSON weight loading from UpConv_7; only the layer
    stack differs. Sequential indices (conv at 0,2,...,12) are preserved
    for pretrained-weight compatibility.
    """

    def __init__(self):
        # The parent constructor builds UpConv_7's layer stack first; it is
        # replaced below. self.pad (ZeroPad2d(7)) is reused as-is.
        super(Vgg_7, self).__init__()
        self.act_fn = nn.LeakyReLU(0.1, inplace=False)
        self.offset = 7

        channel_plan = [(3, 32), (32, 32), (32, 64), (64, 64), (64, 128), (128, 128)]
        layers = []
        for cin, cout in channel_plan:
            layers.append(nn.Conv2d(cin, cout, 3, 1, 0))
            layers.append(self.act_fn)
        # Output conv back to RGB, with no trailing activation.
        layers.append(nn.Conv2d(128, 3, 3, 1, 0))
        self.Sequential = nn.Sequential(*layers)