'''
not exactly the same as the official repo but the results are good
'''
import sys
import os

from data_utils.lower_body import c_index_3d, c_index_6d

sys.path.append(os.getcwd())

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import math

from nets.base import TrainWrapperBaseClass
from nets.layers import SeqEncoder1D
from losses import KeypointLoss, L1Loss, KLLoss
from data_utils.utils import get_melspec, get_mfcc_psf, get_mfcc_ta
from nets.utils import denormalize

class Conv1d_tf(nn.Conv1d):
    """
    Conv1d with the padding behavior from TF
    modified from https://github.com/mlperf/inference/blob/482f6a3beb7af2fb0bd2d91d6185d5e71c22c55f/others/edge/object_detection/ssd_mobilenet/pytorch/utils.py
    """

    def __init__(self, *args, **kwargs):
        super(Conv1d_tf, self).__init__(*args, **kwargs)
        self.padding = kwargs.get("padding", "same")

    def _compute_padding(self, input, dim):
        input_size = input.size(dim + 2)
        filter_size = self.weight.size(dim + 2)
        effective_filter_size = (filter_size - 1) * self.dilation[dim] + 1
        out_size = (input_size + self.stride[dim] - 1) // self.stride[dim]
        total_padding = max(
            0, (out_size - 1) * self.stride[dim] + effective_filter_size - input_size
        )
        additional_padding = int(total_padding % 2 != 0)

        return additional_padding, total_padding

    def forward(self, input):
        if self.padding == "VALID":
            return F.conv1d(
                input,
                self.weight,
                self.bias,
                self.stride,
                padding=0,
                dilation=self.dilation,
                groups=self.groups,
            )
        rows_odd, padding_rows = self._compute_padding(input, dim=0)
        if rows_odd:
            input = F.pad(input, [0, rows_odd])

        return F.conv1d(
            input,
            self.weight,
            self.bias,
            self.stride,
            padding=(padding_rows // 2),
            dilation=self.dilation,
            groups=self.groups,
        )


def ConvNormRelu(in_channels, out_channels, type='1d', downsample=False, k=None, s=None, norm='bn', padding='valid'):
    if k is None and s is None:
        if not downsample:
            k = 3
            s = 1
        else:
            k = 4
            s = 2

    if type == '1d':
        conv_block = Conv1d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding)
        if norm == 'bn':
            norm_block = nn.BatchNorm1d(out_channels)
        elif norm == 'ln':
            norm_block = nn.LayerNorm(out_channels)
    elif type == '2d':
        conv_block = Conv2d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding)
        norm_block = nn.BatchNorm2d(out_channels)
    else:
        assert False

    return nn.Sequential(
        conv_block,
        norm_block,
        nn.LeakyReLU(0.2, True)
    )

class Decoder(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(Decoder, self).__init__()
        self.up1 = nn.Sequential(
            ConvNormRelu(in_ch // 2 + in_ch, in_ch // 2),
            ConvNormRelu(in_ch // 2, in_ch // 2),
            nn.Upsample(scale_factor=2, mode='nearest')
        )
        self.up2 = nn.Sequential(
            ConvNormRelu(in_ch // 4 + in_ch // 2, in_ch // 4),
            ConvNormRelu(in_ch // 4, in_ch // 4),
            nn.Upsample(scale_factor=2, mode='nearest')
        )
        self.up3 = nn.Sequential(
            ConvNormRelu(in_ch // 8 + in_ch // 4, in_ch // 8),
            ConvNormRelu(in_ch // 8, in_ch // 8),
            nn.Conv1d(in_ch // 8, out_ch, 1, 1)
        )

    def forward(self, x, x1, x2, x3):
        x = F.interpolate(x, x3.shape[2])
        x = torch.cat([x, x3], dim=1)
        x = self.up1(x)
        x = F.interpolate(x, x2.shape[2])
        x = torch.cat([x, x2], dim=1)
        x = self.up2(x)
        x = F.interpolate(x, x1.shape[2])
        x = torch.cat([x, x1], dim=1)
        x = self.up3(x)
        return x


class EncoderDecoder(nn.Module):
    def __init__(self, n_frames, each_dim):
        super().__init__()
        self.n_frames = n_frames

        self.down1 = nn.Sequential(
            ConvNormRelu(64, 64, '1d', False),
            ConvNormRelu(64, 128, '1d', False),
        )
        self.down2 = nn.Sequential(
            ConvNormRelu(128, 128, '1d', False),
            ConvNormRelu(128, 256, '1d', False),
        )
        self.down3 = nn.Sequential(
            ConvNormRelu(256, 256, '1d', False),
            ConvNormRelu(256, 512, '1d', False),
        )
        self.down4 = nn.Sequential(
            ConvNormRelu(512, 512, '1d', False),
            ConvNormRelu(512, 1024, '1d', False),
        )

        self.down = nn.MaxPool1d(kernel_size=2)
        self.up = nn.Upsample(scale_factor=2, mode='nearest')

        self.face_decoder = Decoder(1024, each_dim[0] + each_dim[3])
        self.body_decoder = Decoder(1024, each_dim[1])
        self.hand_decoder = Decoder(1024, each_dim[2])

    def forward(self, spectrogram, time_steps=None):
        if time_steps is None:
            time_steps = self.n_frames

        x1 = self.down1(spectrogram)
        x = self.down(x1)
        x2 = self.down2(x)
        x = self.down(x2)
        x3 = self.down3(x)
        x = self.down(x3)
        x = self.down4(x)
        x = self.up(x)

        face = self.face_decoder(x, x1, x2, x3)
        body = self.body_decoder(x, x1, x2, x3)
        hand = self.hand_decoder(x, x1, x2, x3)

        return face, body, hand


class Generator(nn.Module):
    def __init__(self,
                 each_dim,
                 training=False,
                 device=None
                 ):
        super().__init__()

        self.training = training
        self.device = device

        self.encoderdecoder = EncoderDecoder(15, each_dim)

    def forward(self, in_spec, time_steps=None):
        if time_steps is not None:
            self.gen_length = time_steps

        face, body, hand = self.encoderdecoder(in_spec)
        out = torch.cat([face, body, hand], dim=1)
        out = out.transpose(1, 2)

        return out


class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.net = nn.Sequential(
            ConvNormRelu(input_dim, 128, '1d'),
            ConvNormRelu(128, 256, '1d'),
            nn.MaxPool1d(kernel_size=2),
            ConvNormRelu(256, 256, '1d'),
            ConvNormRelu(256, 512, '1d'),
            nn.MaxPool1d(kernel_size=2),
            ConvNormRelu(512, 512, '1d'),
            ConvNormRelu(512, 1024, '1d'),
            nn.MaxPool1d(kernel_size=2),
            nn.Conv1d(1024, 1, 1, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.transpose(1, 2)

        out = self.net(x)
        return out


class TrainWrapper(TrainWrapperBaseClass):
    def __init__(self, args, config) -> None:
        self.args = args
        self.config = config
        self.device = torch.device(self.args.gpu)
        self.global_step = 0
        self.convert_to_6d = self.config.Data.pose.convert_to_6d
        self.init_params()

        self.generator = Generator(
            each_dim=self.each_dim,
            training=not self.args.infer,
            device=self.device,
        ).to(self.device)
        self.discriminator = Discriminator(
            input_dim=self.each_dim[1] + self.each_dim[2] + 64
        ).to(self.device)
        if self.convert_to_6d:
            self.c_index = c_index_6d
        else:
            self.c_index = c_index_3d
        self.MSELoss = KeypointLoss().to(self.device)
        self.L1Loss = L1Loss().to(self.device)
        super().__init__(args, config)

    def init_params(self):
        scale = 1

        global_orient = round(0 * scale)
        leye_pose = reye_pose = round(0 * scale)
        jaw_pose = round(3 * scale)
        body_pose = round((63 - 24) * scale)
        left_hand_pose = right_hand_pose = round(45 * scale)

        expression = 100

        b_j = 0
        jaw_dim = jaw_pose
        b_e = b_j + jaw_dim
        eye_dim = leye_pose + reye_pose
        b_b = b_e + eye_dim
        body_dim = global_orient + body_pose
        b_h = b_b + body_dim
        hand_dim = left_hand_pose + right_hand_pose
        b_f = b_h + hand_dim
        face_dim = expression

        self.dim_list = [b_j, b_e, b_b, b_h, b_f]
        self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim
        self.pose = int(self.full_dim / round(3 * scale))
        self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]

    def __call__(self, bat):
        assert (not self.args.infer), "infer mode"
        self.global_step += 1

        loss_dict = {}

        aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
        expression = bat['expression'].to(self.device).to(torch.float32)
        jaw = poses[:, :3, :]
        poses = poses[:, self.c_index, :]

        pred = self.generator(in_spec=aud)

        D_loss, D_loss_dict = self.get_loss(
            pred_poses=pred.detach(),
            gt_poses=poses,
            aud=aud,
            mode='training_D',
        )

        self.discriminator_optimizer.zero_grad()
        D_loss.backward()
        self.discriminator_optimizer.step()

        G_loss, G_loss_dict = self.get_loss(
            pred_poses=pred,
            gt_poses=poses,
            aud=aud,
            expression=expression,
            jaw=jaw,
            mode='training_G',
        )
        self.generator_optimizer.zero_grad()
        G_loss.backward()
        self.generator_optimizer.step()

        total_loss = None
        loss_dict = {}
        for key in list(D_loss_dict.keys()) + list(G_loss_dict.keys()):
            loss_dict[key] = G_loss_dict.get(key, 0) + D_loss_dict.get(key, 0)

        return total_loss, loss_dict

    def get_loss(self,
                 pred_poses,
                 gt_poses,
                 aud=None,
                 jaw=None,
                 expression=None,
                 mode='training_G',
                 ):
        loss_dict = {}
        aud = aud.transpose(1, 2)
        gt_poses = gt_poses.transpose(1, 2)
        gt_aud = torch.cat([gt_poses, aud], dim=2)
        pred_aud = torch.cat([pred_poses[:, :, 103:], aud], dim=2)

        if mode == 'training_D':
            dis_real = self.discriminator(gt_aud)
            dis_fake = self.discriminator(pred_aud)
            dis_error = self.MSELoss(torch.ones_like(dis_real).to(self.device), dis_real) + self.MSELoss(
                torch.zeros_like(dis_fake).to(self.device), dis_fake)
            loss_dict['dis'] = dis_error

            return dis_error, loss_dict
        elif mode == 'training_G':
            jaw_loss = self.L1Loss(pred_poses[:, :, :3], jaw.transpose(1, 2))
            face_loss = self.MSELoss(pred_poses[:, :, 3:103], expression.transpose(1, 2))
            body_loss = self.L1Loss(pred_poses[:, :, 103:142], gt_poses[:, :, :39])
            hand_loss = self.L1Loss(pred_poses[:, :, 142:], gt_poses[:, :, 39:])
            l1_loss = jaw_loss + face_loss + body_loss + hand_loss

            dis_output = self.discriminator(pred_aud)
            gen_error = self.MSELoss(torch.ones_like(dis_output).to(self.device), dis_output)
            gen_loss = self.config.Train.weights.keypoint_loss_weight * l1_loss + self.config.Train.weights.gan_loss_weight * gen_error

            loss_dict['gen'] = gen_error
            loss_dict['jaw_loss'] = jaw_loss
            loss_dict['face_loss'] = face_loss
            loss_dict['body_loss'] = body_loss
            loss_dict['hand_loss'] = hand_loss
            return gen_loss, loss_dict
        else:
            raise ValueError(mode)

    def infer_on_audio(self, aud_fn, fps=30, initial_pose=None, norm_stats=None, id=None, B=1, **kwargs):
        output = []
        assert self.args.infer, "train mode"
        self.generator.eval()

        if self.config.Data.pose.normalization:
            assert norm_stats is not None
            data_mean = norm_stats[0]
            data_std = norm_stats[1]

        pre_length = self.config.Data.pose.pre_pose_length
        generate_length = self.config.Data.pose.generate_length
        # assert pre_length == initial_pose.shape[-1]
        # pre_poses = initial_pose.permute(0, 2, 1).to(self.device).to(torch.float32)
        # B = pre_poses.shape[0]

        aud_feat = get_mfcc_ta(aud_fn, sr=22000, fps=fps, smlpx=True, type='mfcc').transpose(1, 0)
        num_poses_to_generate = aud_feat.shape[-1]
        aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0)
        aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.device)

        with torch.no_grad():
            pred_poses = self.generator(aud_feat)
            pred_poses = pred_poses.cpu().numpy()
        output = pred_poses.squeeze()

        return output

    def generate(self, aud, id):
        self.generator.eval()
        pred_poses = self.generator(aud)
        return pred_poses


if __name__ == '__main__':
    from trainer.options import parse_args

    parser = parse_args()
    args = parser.parse_args(
        ['--exp_name', '0', '--data_root', '0', '--speakers', '0', '--pre_pose_length', '4', '--generate_length', '64',
         '--infer'])

    generator = TrainWrapper(args)

    aud_fn = '../sample_audio/jon.wav'
    initial_pose = torch.randn(64, 108, 4)
    norm_stats = (np.random.randn(108), np.random.randn(108))
    output = generator.infer_on_audio(aud_fn, initial_pose, norm_stats)

    print(output.shape)