# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
from collections import OrderedDict

import torch
from tqdm import trange

from scepter.modules.model.registry import (DIFFUSION_SAMPLERS, DIFFUSIONS,
                                            NOISE_SCHEDULERS)
from scepter.modules.utils.config import Config, dict_to_yaml
from scepter.modules.utils.distribute import we
from scepter.modules.utils.file_system import FS


@DIFFUSIONS.register_class()
class ACEDiffusion(object):
    """Diffusion wrapper pairing a training noise scheduler with a sampler.

    Training noise is added by the scheduler built from ``NOISE_SCHEDULER``
    (see :meth:`loss`). Inference uses the scheduler built from
    ``SAMPLER_SCHEDULER`` (falling back to ``NOISE_SCHEDULER``) together with
    a sampler looked up in ``DIFFUSION_SAMPLERS`` (see :meth:`sample`).
    """
    para_dict = {
        'NOISE_SCHEDULER': {},
        'SAMPLER_SCHEDULER': {},
        'MIN_SNR_GAMMA': {
            'value': None,
            'description': 'The minimum SNR gamma value for the loss function.'
        },
        'PREDICTION_TYPE': {
            'value': 'eps',
            'description':
            'The type of prediction to use for the loss function.'
        }
    }

    def __init__(self, cfg, logger=None):
        """Store the config/logger and build the schedulers.

        Args:
            cfg: Config object holding the keys listed in ``para_dict``.
            logger: Optional logger forwarded to the built schedulers/samplers.
        """
        super(ACEDiffusion, self).__init__()
        self.logger = logger
        self.cfg = cfg
        self.init_params()

    def init_params(self):
        """Read hyper-parameters from ``self.cfg`` and build both schedulers.

        On rank 0, when ``WORK_DIR`` is configured, a plot of each schedule
        is written to ``noise_schedule.png`` / ``sampler_schedule.png``.
        """
        self.min_snr_gamma = self.cfg.get('MIN_SNR_GAMMA', None)
        self.prediction_type = self.cfg.get('PREDICTION_TYPE', 'eps')
        self.noise_scheduler = NOISE_SCHEDULERS.build(self.cfg.NOISE_SCHEDULER,
                                                      logger=self.logger)
        # Inference may use a different schedule; default to the training one.
        self.sampler_scheduler = NOISE_SCHEDULERS.build(self.cfg.get(
            'SAMPLER_SCHEDULER', self.cfg.NOISE_SCHEDULER),
                                                        logger=self.logger)
        self.num_timesteps = self.noise_scheduler.num_timesteps
        if self.cfg.have('WORK_DIR') and we.rank == 0:
            schedule_visualization = os.path.join(self.cfg.WORK_DIR,
                                                  'noise_schedule.png')
            with FS.put_to(schedule_visualization) as local_path:
                self.noise_scheduler.plot_noise_sampling_map(local_path)
            schedule_visualization = os.path.join(self.cfg.WORK_DIR,
                                                  'sampler_schedule.png')
            with FS.put_to(schedule_visualization) as local_path:
                self.sampler_scheduler.plot_noise_sampling_map(local_path)

    def sample(self,
               noise,
               model,
               model_kwargs=None,
               steps=20,
               sampler=None,
               use_dynamic_cfg=False,
               guide_scale=None,
               guide_rescale=None,
               show_progress=False,
               return_intermediate=None,
               intermediate_callback=None,
               **kwargs):
        """Run the reverse diffusion process starting from ``noise``.

        Args:
            noise: Initial latent passed to the sampler's setup.
            model: Callable invoked as ``model(x=..., t=..., **kwargs)``.
            model_kwargs: Extra keyword arguments for ``model``. When
                classifier-free guidance is active (``guide_scale`` is not
                None and != 1.0) this must be indexable:
                ``model_kwargs[0]`` holds the conditional kwargs and
                ``model_kwargs[1]`` the unconditional ones. Defaults to no
                extra kwargs.
            steps: Number of sampler steps.
            sampler: Sampler name registered in ``DIFFUSION_SAMPLERS``, or a
                dict/Config describing it.
            use_dynamic_cfg: If True, the guidance scale follows a cosine
                ramp computed from the current timestep.
            guide_scale: Classifier-free guidance scale.
            guide_rescale: If > 0 (and guidance is active), rescale the
                guided prediction so its per-sample std is blended toward the
                std of the conditional prediction.
            show_progress: Whether to display a tqdm progress bar.
            return_intermediate: None, 'x0'/'x_0' or 'xt'/'x_t' — which
                per-step tensor to collect.
            intermediate_callback: Optional callable invoked after every step
                with the collected tensor (``x_0`` when nothing is collected).
            **kwargs: Ignored.

        Returns:
            The final ``x_0``, or ``(x_0, intermediates)`` when
            ``return_intermediate`` is not None.

        Raises:
            ValueError: if ``sampler`` is a name not registered in
                ``DIFFUSION_SAMPLERS``.
        """
        # Accept both spellings; internally use the sampler_output attribute
        # names. (Previously the loop compared against 'x_0'/'x_t' while only
        # 'x0'/'xt' passed validation, so nothing was ever collected.)
        intermediate_attrs = {
            None: None,
            'x0': 'x_0',
            'x_0': 'x_0',
            'xt': 'x_t',
            'x_t': 'x_t'
        }
        assert isinstance(steps, (int, torch.LongTensor))
        assert return_intermediate in intermediate_attrs
        assert isinstance(sampler, (str, dict, Config))
        if model_kwargs is None:
            model_kwargs = {}
        intermediate_attr = intermediate_attrs[return_intermediate]
        intermediates = []

        def callback_fn(x_t, t, sigma=None, alpha=None):
            """Predict x_0 from ``x_t`` at timestep ``t`` (used by the sampler)."""
            timestamp = t
            # Broadcast the scalar timestep / coefficients over the batch.
            t = t.repeat(len(x_t)).round().long().to(x_t.device)
            sigma = sigma.repeat(len(x_t), *([1] * (len(sigma.shape) - 1)))
            alpha = alpha.repeat(len(x_t), *([1] * (len(alpha.shape) - 1)))

            if guide_scale is None or guide_scale == 1.0:
                out = model(x=x_t, t=t, **model_kwargs)
            else:
                if use_dynamic_cfg:
                    # NOTE(review): ``timestamp`` is the scheduler timestep
                    # value, not the step index, while ``steps`` is the step
                    # count — formula kept as-is; confirm this is intended.
                    guidance_scale = 1 + guide_scale * (
                        (1 - math.cos(math.pi * (
                            (steps - timestamp.item()) / steps)**5.0)) / 2)
                else:
                    guidance_scale = guide_scale
                y_out = model(x=x_t, t=t, **model_kwargs[0])
                u_out = model(x=x_t, t=t, **model_kwargs[1])
                out = u_out + guidance_scale * (y_out - u_out)
                # Rescale only makes sense with guidance: it needs the
                # conditional prediction ``y_out``.
                if guide_rescale is not None and guide_rescale > 0.0:
                    ratio = (y_out.flatten(1).std(dim=1) /
                             (out.flatten(1).std(dim=1) + 1e-12)).view(
                                 (-1, ) + (1, ) * (y_out.ndim - 1))
                    out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0

            # Convert the raw model output into an x_0 estimate.
            if self.prediction_type == 'x0':
                x0 = out
            elif self.prediction_type == 'eps':
                x0 = (x_t - sigma * out) / alpha
            elif self.prediction_type == 'v':
                x0 = alpha * x_t - sigma * out
            else:
                raise NotImplementedError(
                    f'prediction_type {self.prediction_type} not implemented')

            return x0

        sampler_ins = self.get_sampler(sampler)
        if sampler_ins is None:
            raise ValueError(f'Unknown sampler: {sampler}')

        # this is ignored for schnell
        # (``preprare_sampler`` is the sampler API's spelling.)
        sampler_output = sampler_ins.preprare_sampler(
            noise,
            steps=steps,
            prediction_type=self.prediction_type,
            scheduler_ins=self.sampler_scheduler,
            callback_fn=callback_fn)

        with trange(steps, disable=not show_progress) as pbar:
            for _ in pbar:
                msg = sampler_output.msg
                if msg:
                    pbar.set_description(str(msg))
                sampler_output = sampler_ins.step(sampler_output)
                current = getattr(sampler_output, intermediate_attr or 'x_0')
                if intermediate_attr is not None:
                    intermediates.append(current)
                if intermediate_callback is not None:
                    intermediate_callback(current)

        if return_intermediate is not None:
            return sampler_output.x_0, intermediates
        return sampler_output.x_0

    def loss(self,
             x_0,
             model,
             model_kwargs=None,
             reduction='mean',
             noise=None,
             **kwargs):
        """Compute the diffusion training loss for clean samples ``x_0``.

        Args:
            x_0: Clean samples; dim 0 is treated as the batch dimension.
            model: Callable invoked as ``model(x=x_t, t=t, **model_kwargs)``.
            model_kwargs: Extra keyword arguments for ``model``.
            reduction: 'mean' averages the squared error over all non-batch
                dims (one value per sample); anything else keeps the
                element-wise squared error.
            noise: Optional noise; ``torch.randn_like(x_0)`` when None.
            **kwargs: Forwarded to ``noise_scheduler.add_noise``.

        Returns:
            The (optionally min-SNR weighted) squared-error loss.

        Raises:
            NotImplementedError: for an unsupported ``PREDICTION_TYPE``.
        """
        if model_kwargs is None:
            model_kwargs = {}
        # use noise scheduler to add noise
        if noise is None:
            noise = torch.randn_like(x_0)
        schedule_output = self.noise_scheduler.add_noise(x_0, noise, **kwargs)
        x_t, t = schedule_output.x_t, schedule_output.t
        sigma, alpha = schedule_output.sigma, schedule_output.alpha
        out = model(x=x_t, t=t, **model_kwargs)

        # Regression target matching the model's prediction type; only the
        # needed one is computed.
        if self.prediction_type == 'eps':
            target = noise
        elif self.prediction_type == 'x0':
            target = x_0
        elif self.prediction_type == 'v':
            target = alpha * noise - sigma * x_0
        else:
            raise NotImplementedError(
                f'prediction_type {self.prediction_type} not implemented')

        # mse loss
        loss = (out - target).pow(2)
        if reduction == 'mean':
            loss = loss.flatten(1).mean(dim=1)

        if self.min_snr_gamma is not None:
            # NOTE(review): this equals alpha^2 / sigma^2 only if
            # ``noise_scheduler.alphas`` already stores the squared scale
            # (sigmas are squared explicitly) — confirm against the scheduler.
            alphas = self.noise_scheduler.alphas.to(x_0.device)[t]
            sigmas = self.noise_scheduler.sigmas.pow(2).to(x_0.device)[t]
            snrs = (alphas / sigmas).clamp(min=1e-20)
            min_snrs = snrs.clamp(max=self.min_snr_gamma)
            weights = min_snrs / snrs
            # Per-sample weights must align with the batch dim of ``loss``;
            # without trailing singleton dims an unreduced loss would
            # broadcast them against its last axis instead.
            extra_dims = max(loss.ndim - weights.ndim, 0)
            weights = weights.reshape(
                tuple(weights.shape) + (1, ) * extra_dims)
        else:
            weights = 1

        loss = loss * weights
        return loss

    def get_sampler(self, sampler):
        """Build a sampler instance from a name or a config.

        Args:
            sampler: Registered sampler name, or a dict/OrderedDict/Config
                describing it (dict keys are upper-cased).

        Returns:
            The built sampler, or None (after logging) when ``sampler`` is a
            name that is not registered in ``DIFFUSION_SAMPLERS``.

        Raises:
            NotImplementedError: for any other ``sampler`` type.
        """
        if isinstance(sampler, str):
            if sampler not in DIFFUSION_SAMPLERS.class_map:
                message = (f'{sampler} not in the defined samplers list '
                           f'{DIFFUSION_SAMPLERS.class_map.keys()}')
                if self.logger is not None:
                    self.logger.info(message)
                else:
                    print(message)
                return None
            sampler_cfg = Config(cfg_dict={'NAME': sampler}, load=False)
            return DIFFUSION_SAMPLERS.build(sampler_cfg, logger=self.logger)
        if isinstance(sampler, (Config, dict, OrderedDict)):
            if isinstance(sampler, (dict, OrderedDict)):
                # Config keys are upper-case by convention.
                sampler = Config(
                    cfg_dict={k.upper(): v
                              for k, v in dict(sampler).items()},
                    load=False)
            return DIFFUSION_SAMPLERS.build(sampler, logger=self.logger)
        raise NotImplementedError

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}' + ' ' + super().__repr__()

    @staticmethod
    def get_config_template():
        """Return the YAML config template generated from ``para_dict``."""
        return dict_to_yaml('DIFFUSIONS',
                            __class__.__name__,
                            ACEDiffusion.para_dict,
                            set_name=True)