# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Adapted from https://github.com/zhenye234/CoMoSpeech"""

import torch
import torch.nn as nn
import copy
import numpy as np
import math
from tqdm.auto import tqdm

from utils.ssim import SSIM

from models.svc.transformer.conformer import Conformer, BaseModule
from models.svc.diffusion.diffusion_wrapper import DiffusionWrapper
from models.svc.comosvc.utils import slice_segments, rand_ids_segments


class Consistency(nn.Module):
    """EDM (Karras) diffusion decoder with optional consistency distillation.

    * Teacher (``distill=False``): trained with ``EDMLoss`` and sampled with
      ``edm_sampler``.
    * Student (``distill=True``): trained with ``CTLoss_D`` against an EMA copy
      and a frozen pretrained copy of the denoiser (both created by
      ``init_consistency_training``) and sampled with ``CT_sampler``.

    Args:
        cfg: full experiment config. The denoiser is built from it; all
            diffusion hyper-parameters are read from ``cfg.model.comosvc``.
        distill: True when this module is the consistency-distilled student.
    """

    def __init__(self, cfg, distill=False):
        super().__init__()
        # The denoiser needs the full config; everything else (and self.cfg)
        # uses the comosvc sub-config.
        self.denoise_fn = DiffusionWrapper(cfg)
        self.cfg = cfg.model.comosvc
        self.teacher = not distill
        self.P_mean = self.cfg.P_mean
        self.P_std = self.cfg.P_std
        self.sigma_data = self.cfg.sigma_data
        self.sigma_min = self.cfg.sigma_min
        self.sigma_max = self.cfg.sigma_max
        self.rho = self.cfg.rho
        self.N = self.cfg.n_timesteps
        self.ssim_loss = SSIM()

        # Karras time-step discretization used by CTLoss_D: N noise levels
        # ascending from sigma_min to sigma_max, prefixed with a 0, so
        # self.t_steps has N + 1 entries (index 0 is sigma = 0).
        step_indices = torch.arange(self.N)
        t_steps = (
            self.sigma_min ** (1 / self.rho)
            + step_indices
            / (self.N - 1)
            * (self.sigma_max ** (1 / self.rho) - self.sigma_min ** (1 / self.rho))
        ) ** self.rho
        self.t_steps = torch.cat(
            [torch.zeros_like(t_steps[:1]), self.round_sigma(t_steps)]
        )

    def init_consistency_training(self):
        """Create the EMA target and the frozen pretrained teacher copies.

        Both are deep copies of the current ``denoise_fn`` and are read by
        ``CTLoss_D`` (and ``CT_sampler`` uses the EMA copy).
        """
        self.denoise_fn_ema = copy.deepcopy(self.denoise_fn)
        self.denoise_fn_pretrained = copy.deepcopy(self.denoise_fn)

    def EDMPrecond(self, x, sigma, cond, denoise_fn, mask, spk=None):
        """
        Karras (EDM) preconditioned denoiser: D(x; sigma) = c_skip * x + c_out * F(c_in * x, c_noise).

        Args:
            x: noisy mel-spectrogram [B x n_mel x L]
            sigma: noise level; any shape with B elements, reshaped to [B x 1 x 1]
            cond: output of conformer encoder [B x n_mel x L]
            denoise_fn: denoiser network called as denoise_fn(x_in, c_noise, cond)
                with x_in and cond transposed to [B x L x n_mel]
            mask: mask of padded frames [B x n_mel x L] (currently unused)
            spk: unused, kept for interface compatibility

        Returns:
            denoised mel-spectrogram [B x n_mel x L]
        """
        sigma = sigma.reshape(-1, 1, 1)

        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
        c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
        c_noise = sigma.log() / 4

        # The denoiser works on [B x L x n_mel]; transpose in and back out.
        x_in = (c_in * x).transpose(1, 2)
        x = x.transpose(1, 2)
        cond = cond.transpose(1, 2)
        # NOTE: squeeze() turns c_noise into [B]; for B == 1 it becomes 0-d.
        F_x = denoise_fn(x_in, c_noise.squeeze(), cond)
        D_x = c_skip * x + c_out * F_x
        return D_x.transpose(1, 2)

    def EDMLoss(self, x_start, cond, mask):
        """
        Compute the EDM denoising loss for the teacher model.

        Args:
            x_start: ground truth mel-spectrogram [B x n_mel x L]
            cond: output of conformer encoder [B x n_mel x L]
            mask: mask of padded frames [B x n_mel x L]

        Returns:
            scalar loss: weighted squared error averaged over unmasked elements
        """
        # log(sigma) ~ Normal(P_mean, P_std), one sigma per batch item.
        rnd_normal = torch.randn([x_start.shape[0], 1, 1], device=x_start.device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2

        # follow Grad-TTS, start from Gaussian noise with mean cond and std I
        noise = (torch.randn_like(x_start) + cond) * sigma
        D_yn = self.EDMPrecond(x_start + noise, sigma, cond, self.denoise_fn, mask)
        loss = weight * ((D_yn - x_start) ** 2)
        return torch.sum(loss * mask) / torch.sum(mask)

    def round_sigma(self, sigma):
        """Convert ``sigma`` to a tensor (no rounding is applied)."""
        return torch.as_tensor(sigma)

    def edm_sampler(
        self,
        latents,
        cond,
        nonpadding,
        num_steps=50,
        sigma_min=0.002,
        sigma_max=80,
        rho=7,
        S_churn=0,
        S_min=0,
        S_max=float("inf"),
        S_noise=1,
        # Other settings that have been tried:
        # S_churn=40 ,S_min=0.05,S_max=50,S_noise=1.003,
        # S_churn=30 ,S_min=0.01,S_max=30,S_noise=1.007,
        # S_churn=30 ,S_min=0.01,S_max=1,S_noise=1.007,
        # S_churn=80 ,S_min=0.05,S_max=50,S_noise=1.003,
    ):
        """
        Karras diffusion sampler (Euler steps with optional stochastic churn).

        Args:
            latents: noisy mel-spectrogram [B x n_mel x L], scaled by the
                first noise level before sampling
            cond: output of conformer encoder [B x n_mel x L]
            nonpadding: mask of padded frames [B x n_mel x L]
            num_steps: number of steps for diffusion inference
            sigma_min, sigma_max, rho: Karras schedule parameters
            S_churn, S_min, S_max, S_noise: stochastic churn parameters

        Returns:
            denoised mel-spectrogram [B x n_mel x L]
        """
        # Time step discretization.
        step_indices = torch.arange(num_steps, device=latents.device)

        # NOTE(review): num_steps is incremented before dividing, so the
        # schedule stops one interval short of sigma_min before the final 0.
        num_steps = num_steps + 1
        t_steps = (
            sigma_max ** (1 / rho)
            + step_indices
            / (num_steps - 1)
            * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
        ) ** rho
        t_steps = torch.cat([self.round_sigma(t_steps), torch.zeros_like(t_steps[:1])])

        # Main sampling loop (tqdm only renders a progress bar).
        x_next = latents * t_steps[0]
        for t_cur, t_next in tqdm(zip(t_steps[:-1], t_steps[1:])):
            x_cur = x_next
            # Increase noise temporarily.
            gamma = (
                min(S_churn / num_steps, np.sqrt(2) - 1)
                if S_min <= t_cur <= S_max
                else 0
            )
            t_hat = self.round_sigma(t_cur + gamma * t_cur)
            t = torch.zeros((x_cur.shape[0], 1, 1), device=x_cur.device)
            t[:, 0, 0] = t_hat
            t_hat = t
            x_hat = x_cur + (
                t_hat**2 - t_cur**2
            ).sqrt() * S_noise * torch.randn_like(x_cur)
            # Euler step.
            denoised = self.EDMPrecond(x_hat, t_hat, cond, self.denoise_fn, nonpadding)
            d_cur = (x_hat - denoised) / t_hat
            x_next = x_hat + (t_next - t_hat) * d_cur

        return x_next

    def CTLoss_D(self, y, cond, mask):
        """
        Compute the consistency-distillation loss.

        Side effect: first updates ``denoise_fn_ema`` in place as an
        exponential moving average (decay 0.95) of ``denoise_fn``.

        Args:
            y: ground truth mel-spectrogram [B x n_mel x L]
            cond: output of conformer encoder [B x n_mel x L]
            mask: mask of padded frames [B x n_mel x L]

        Returns:
            scalar SSIM loss between the online prediction at t_{n+1} and the
            EMA prediction at t_n, averaged over unmasked elements
        """
        with torch.no_grad():
            mu = 0.95
            for p, ema_p in zip(
                self.denoise_fn.parameters(), self.denoise_fn_ema.parameters()
            ):
                ema_p.mul_(mu).add_(p, alpha=1 - mu)

        # n in [1, N-1], so both n and n + 1 index into self.t_steps (N + 1 long).
        n = torch.randint(1, self.N, (y.shape[0],))
        z = torch.randn_like(y) + cond

        tn_1 = self.t_steps[n + 1].reshape(-1, 1, 1).to(y.device)
        f_theta = self.EDMPrecond(y + tn_1 * z, tn_1, cond, self.denoise_fn, mask)

        with torch.no_grad():
            tn = self.t_steps[n].reshape(-1, 1, 1).to(y.device)

            # One Euler step of the frozen teacher from t_{n+1} down to t_n.
            x_hat = y + tn_1 * z
            denoised = self.EDMPrecond(
                x_hat, tn_1, cond, self.denoise_fn_pretrained, mask
            )
            d_cur = (x_hat - denoised) / tn_1
            y_tn = x_hat + (tn - tn_1) * d_cur

            f_theta_ema = self.EDMPrecond(y_tn, tn, cond, self.denoise_fn_ema, mask)

        loss = self.ssim_loss(f_theta, f_theta_ema.detach())
        return torch.sum(loss * mask) / torch.sum(mask)

    def get_t_steps(self, N):
        """Return N + 1 Karras noise levels, descending from sigma_max to sigma_min."""
        N = N + 1
        step_indices = torch.arange(N)
        t_steps = (
            self.sigma_min ** (1 / self.rho)
            + step_indices
            / (N - 1)
            * (self.sigma_max ** (1 / self.rho) - self.sigma_min ** (1 / self.rho))
        ) ** self.rho

        return t_steps.flip(0)

    def CT_sampler(self, latents, cond, nonpadding, t_steps=1):
        """
        Consistency-model sampler using the EMA denoiser.

        Args:
            latents: noisy mel-spectrogram [B x n_mel x L], scaled by the
                first noise level before the first evaluation
            cond: output of conformer encoder [B x n_mel x L]
            nonpadding: mask of padded frames [B x n_mel x L]
            t_steps: number of steps; 1 means a single evaluation at sigma = 80,
                otherwise the schedule from ``get_t_steps(t_steps)`` is used

        Returns:
            denoised mel-spectrogram [B x n_mel x L]
        """
        if t_steps == 1:
            # one-step
            sigmas = [80]
        else:
            # multi-step
            sigmas = self.get_t_steps(t_steps)
        sigmas = torch.as_tensor(sigmas).to(latents.device)

        x = latents * sigmas[0]
        t = torch.zeros((x.shape[0], 1, 1), device=x.device)
        # Only the first noise level: assigning the whole schedule here would
        # fail to broadcast into the [B] slot whenever len(sigmas) > 1.
        t[:, 0, 0] = sigmas[0]
        x = self.EDMPrecond(x, t, cond, self.denoise_fn_ema, nonpadding)

        # Re-noise to each intermediate level and denoise again; the final
        # entry (sigma_min) is skipped.
        for sigma in sigmas[1:-1]:
            z = torch.randn_like(x) + cond
            x_tn = x + (sigma**2 - self.sigma_min**2).sqrt() * z
            t = torch.zeros((x.shape[0], 1, 1), device=x.device)
            t[:, 0, 0] = sigma
            x = self.EDMPrecond(x_tn, t, cond, self.denoise_fn_ema, nonpadding)
        return x

    def forward(self, x, nonpadding, cond, t_steps=1, infer=False):
        """
        Calculate the training loss or sample a mel-spectrogram.

        Args:
            x:
                training: ground truth mel-spectrogram [B x n_mel x L]
                inference: only its device is used [B x n_mel x L]
            nonpadding: mask of padded frames [B x n_mel x L]
            cond: output of conformer encoder [B x n_mel x L]
            t_steps: number of sampling steps (teacher inference only)
            infer: sample instead of computing the loss

        Returns:
            scalar loss when ``infer`` is False, otherwise the sampled
            mel-spectrogram [B x n_mel x L]
        """
        if self.teacher:  # teacher model -- karras diffusion
            if not infer:
                return self.EDMLoss(x, cond, nonpadding)
            shape = (cond.shape[0], self.cfg.n_mel, cond.shape[2])
            x = torch.randn(shape, device=x.device) + cond
            return self.edm_sampler(x, cond, nonpadding, t_steps)

        # Consistency distillation
        if not infer:
            return self.CTLoss_D(x, cond, nonpadding)
        shape = (cond.shape[0], self.cfg.n_mel, cond.shape[2])
        x = torch.randn(shape, device=x.device) + cond
        # NOTE(review): the caller's t_steps is deliberately not forwarded; the
        # student always samples in one step here. Call CT_sampler directly
        # for multi-step sampling.
        return self.CT_sampler(x, cond, nonpadding, t_steps=1)


class ComoSVC(BaseModule):
    """CoMoSVC acoustic model: a Conformer prior encoder feeding a
    diffusion / consistency decoder (``Consistency``).

    Args:
        cfg: full experiment config. ``cfg.model.comosvc.n_mel`` is set from
            ``cfg.preprocess.n_mel`` here, and ``cfg.model.comosvc.distill``
            selects teacher (False) or distilled student (True) decoding.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        como_cfg = cfg.model.comosvc
        como_cfg.n_mel = cfg.preprocess.n_mel
        self.distill = como_cfg.distill
        # Registration order (encoder, decoder, ssim_loss) is kept stable.
        self.encoder = Conformer(como_cfg)
        self.decoder = Consistency(cfg, distill=self.distill)
        self.ssim_loss = SSIM()

    @torch.no_grad()
    def forward(self, x_mask, x, n_timesteps, temperature=1.0):
        """
        Generate a mel-spectrogram without tracking gradients.

        Args:
            x_mask : mask of padded frames in mel-spectrogram. [B x L x n_mel]
            x : output of encoder framework. [B x L x d_condition]
            n_timesteps : number of reverse-diffusion steps passed to the decoder.
            temperature : currently unused; kept for interface compatibility.

        Returns:
            (encoder_outputs, decoder_outputs), both in the same [B x L x ...]
            layout as the encoder output.
        """
        encoder_outputs = self.encoder(x, x_mask)

        # The decoder works channel-first, so swap the last two axes.
        cond = encoder_outputs.transpose(1, 2)
        mask = x_mask.transpose(1, 2)

        # The encoder output is passed both as the decoder's `x` argument
        # and as its condition.
        sampled = self.decoder(cond, mask, cond, t_steps=n_timesteps, infer=True)
        return encoder_outputs, sampled.transpose(1, 2)

    def compute_loss(self, x_mask, x, mel, out_size=None, skip_diff=False):
        """
        Compute the training losses.

        Args:
            x_mask : mask of padded frames in mel-spectrogram. [B x L x n_mel]
            x : output of encoder framework. [B x L x d_condition]
            mel : ground truth mel-spectrogram. [B x L x n_mel]
            out_size : currently unused; kept for interface compatibility.
            skip_diff : for the teacher only, return a zero diffusion loss
                instead of running the decoder.

        Returns:
            (ssim_loss, prior_loss, diff_loss):
                ssim_loss  - masked SSIM loss between encoder output and mel.
                prior_loss - masked Gaussian (unit variance) negative
                             log-likelihood of mel under the encoder output.
                diff_loss  - decoder loss (EDM loss for the teacher,
                             consistency-distillation loss for the student).
        """
        mu_x = self.encoder(x, x_mask)
        n_mel = self.cfg.model.comosvc.n_mel

        # prior loss: per-element 0.5 * ((mel - mu)^2 + log(2*pi)), masked.
        nll = 0.5 * ((mel - mu_x) ** 2 + math.log(2 * math.pi))
        prior_loss = torch.sum(nll * x_mask) / (torch.sum(x_mask) * n_mel)

        # ssim loss
        ssim_map = self.ssim_loss(mu_x, mel)
        ssim_loss = torch.sum(ssim_map * x_mask) / torch.sum(x_mask)

        if skip_diff and not self.distill:
            # Zero tensor with the same dtype/device as prior_loss.
            diff_loss = prior_loss.clone()
            diff_loss.fill_(0)
            return ssim_loss, prior_loss, diff_loss

        # The decoder expects channel-first tensors.
        mel_cf = mel.transpose(1, 2)
        mask_cf = x_mask.transpose(1, 2)
        mu_cf = mu_x.transpose(1, 2)
        # When distilling, detach so the decoder loss does not back-propagate
        # into the encoder.
        cond = mu_cf.detach() if self.distill else mu_cf
        diff_loss = self.decoder(mel_cf, mask_cf, cond, infer=False)
        return ssim_loss, prior_loss, diff_loss