# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
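# Description: SVGDreamer pipeline - text-guided SVG generation / fine-tuning of multiple SVG particles with VPSD.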
import pathlib
from PIL import Image
from typing import AnyStr

import numpy as np
from tqdm.auto import tqdm
import torch
from torch.optim.lr_scheduler import LambdaLR
import torchvision
from torchvision import transforms

from pytorch_svgrender.libs.engine import ModelState
from pytorch_svgrender.libs.solver.optim import get_optimizer
from pytorch_svgrender.painter.svgdreamer import Painter, PainterOptimizer
from pytorch_svgrender.painter.svgdreamer.painter_params import CosineWithWarmupLRLambda
from pytorch_svgrender.painter.live import xing_loss_fn
from pytorch_svgrender.painter.svgdreamer import VectorizedParticleSDSPipeline
from pytorch_svgrender.plt import plot_img
from pytorch_svgrender.utils.color_attrs import init_tensor_with_color
from pytorch_svgrender.token2attn.ptp_utils import view_images
from pytorch_svgrender.diffusers_warp import model2res

import ImageReward as RM


class SVGDreamerPipeline(ModelState):

    def __init__(self, args):
        assert args.x.style in ["iconography", "pixelart", "low-poly", "painting", "sketch", "ink"]
        assert args.x.guidance.n_particle >= args.x.guidance.vsd_n_particle
        assert args.x.guidance.n_particle >= args.x.guidance.phi_n_particle
        assert args.x.guidance.n_phi_sample >= 1

        logdir_ = f"sd{args.seed}" \
                  f"-{'vpsd' if args.x.skip_sive else 'sive'}" \
                  f"-{args.x.model_id}" \
                  f"-{args.x.style}" \
                  f"-P{args.x.num_paths}" \
                  f"{'-RePath' if args.x.path_reinit.use else ''}"
        super().__init__(args, log_path_suffix=logdir_)

        # create log dir
        self.png_logs_dir = self.result_path / "png_logs"
        self.svg_logs_dir = self.result_path / "svg_logs"
        self.ft_png_logs_dir = self.result_path / "ft_png_logs"
        self.ft_svg_logs_dir = self.result_path / "ft_svg_logs"
        self.sd_sample_dir = self.result_path / 'sd_samples'
        self.reinit_dir = self.result_path / "reinit_logs"
        self.init_stage_two_dir = self.result_path / "stage_two_init_logs"
        self.phi_samples_dir = self.result_path / "phi_sampling_logs"

        if self.accelerator.is_main_process:
            self.png_logs_dir.mkdir(parents=True, exist_ok=True)
            self.svg_logs_dir.mkdir(parents=True, exist_ok=True)
            self.ft_png_logs_dir.mkdir(parents=True, exist_ok=True)
            self.ft_svg_logs_dir.mkdir(parents=True, exist_ok=True)
            self.sd_sample_dir.mkdir(parents=True, exist_ok=True)
            self.reinit_dir.mkdir(parents=True, exist_ok=True)
            self.init_stage_two_dir.mkdir(parents=True, exist_ok=True)
            self.phi_samples_dir.mkdir(parents=True, exist_ok=True)

        self.select_fpth = self.result_path / 'select_sample.png'

        # make video log
        self.make_video = self.args.mv
        if self.make_video:
            self.frame_idx = 0
            self.frame_log_dir = self.result_path / "frame_logs"
            self.frame_log_dir.mkdir(parents=True, exist_ok=True)

        self.g_device = torch.Generator(device=self.device).manual_seed(args.seed)

        self.pipeline = VectorizedParticleSDSPipeline(args, args.diffuser, self.x_cfg.guidance, self.device)

        # load reward model
        self.reward_model = None
        if self.x_cfg.guidance.phi_ReFL:
            self.reward_model = RM.load("ImageReward-v1.0", device=self.device, download_root=self.x_cfg.reward_path)

        self.style = self.x_cfg.style
        if self.style == "pixelart":
            self.x_cfg.lr_stage_one.lr_schedule = False
            self.x_cfg.lr_stage_two.lr_schedule = False

    def target_file_preprocess(self, tar_path: AnyStr):
        process_comp = transforms.Compose([
            transforms.Resize(size=(self.x_cfg.image_size, self.x_cfg.image_size)),
            transforms.ToTensor(),
            transforms.Lambda(lambda t: t.unsqueeze(0)),
        ])

        tar_pil = Image.open(tar_path).convert("RGB")  # open file
        target_img = process_comp(tar_pil)  # preprocess
        target_img = target_img.to(self.device)
        return target_img

    def SIVE_stage(self, text_prompt: str):
        # TODO: SIVE implementation
        pass

    def painterly_rendering(self, text_prompt: str, target_file: AnyStr = None):
        # log prompts
        self.print(f"prompt: {text_prompt}")
        self.print(f"neg_prompt: {self.args.neg_prompt}\n")

        # for convenience
        im_size = self.x_cfg.image_size
        guidance_cfg = self.x_cfg.guidance
        n_particle = self.x_cfg.guidance.n_particle
        total_step = self.x_cfg.guidance.num_iter
        path_reinit = self.x_cfg.path_reinit

        init_from_target = True if (target_file and pathlib.Path(target_file).exists()) else False
        # switch mode
        if self.x_cfg.skip_sive and not init_from_target:
            # mode 1: optimization with VPSD from scratch
            # randomly init
            self.print("optimization with VPSD from scratch...")
            if self.x_cfg.color_init == 'rand':
                target_img = torch.randn(1, 3, im_size, im_size)
                self.print("color: randomly init")
            else:
                target_img = init_tensor_with_color(self.x_cfg.color_init, 1, im_size, im_size)
                self.print(f"color: {self.x_cfg.color_init}")

            # log init target_img
            plot_img(target_img, self.result_path, fname='init_target_img')
            final_svg_path = None
        elif init_from_target:
            # mode 2: load the SVG file and finetune it
            self.print(f"load svg from {target_file} ...")
            self.print(f"SVG fine-tuning via VPSD...")
            final_svg_path = target_file
            if self.x_cfg.color_init == 'target_randn':
                # special order: init newly paths color use random color
                target_img = torch.randn(1, 3, im_size, im_size)
                self.print("color: randomly init")
            else:
                # load the SVG and init newly paths color use target_img
                # note: the target will be converted to png via pydiffvg when load_renderer called
                target_img = None
        else:
            # mode 3: text-to-img-to-svg (two stage)
            target_img, final_svg_path = self.SIVE_stage(text_prompt)
            self.x_cfg.path_svg = final_svg_path
            self.print("\n SVG fine-tuning via VPSD...")
            plot_img(target_img, self.result_path, fname='init_target_img')

        # create svg renderer
        renderers = [self.load_renderer(final_svg_path) for _ in range(n_particle)]

        # randomly initialize the particles
        if self.x_cfg.skip_sive or init_from_target:
            if target_img is None:
                target_img = self.target_file_preprocess(self.result_path / 'target_img.png')
            for render in renderers:
                render.component_wise_path_init(gt=target_img, pred=None, init_type='random')

        # log init images
        for i, r in enumerate(renderers):
            init_imgs = r.init_image(stage=0, num_paths=self.x_cfg.num_paths)
            plot_img(init_imgs, self.init_stage_two_dir, fname=f"init_img_stage_two_{i}")

        # init renderer optimizer
        optimizers = []
        for renderer in renderers:
            optim_ = PainterOptimizer(renderer,
                                      self.style,
                                      guidance_cfg.num_iter,
                                      self.x_cfg.lr_stage_two,
                                      self.x_cfg.trainable_bg)
            optim_.init_optimizers()
            optimizers.append(optim_)

        # init phi_model optimizer
        phi_optimizer = get_optimizer('adamW',
                                      self.pipeline.phi_params,
                                      guidance_cfg.phi_lr,
                                      guidance_cfg.phi_optim)
        # init phi_model lr scheduler
        phi_scheduler = None
        schedule_cfg = guidance_cfg.phi_schedule
        if schedule_cfg.use:
            phi_lr_lambda = CosineWithWarmupLRLambda(num_steps=schedule_cfg.total_step,
                                                     warmup_steps=schedule_cfg.warmup_steps,
                                                     warmup_start_lr=schedule_cfg.warmup_start_lr,
                                                     warmup_end_lr=schedule_cfg.warmup_end_lr,
                                                     cosine_end_lr=schedule_cfg.cosine_end_lr)
            phi_scheduler = LambdaLR(phi_optimizer, lr_lambda=phi_lr_lambda, last_epoch=-1)

        self.print(f"-> Painter point Params: {len(renderers[0].get_point_parameters())}")
        self.print(f"-> Painter color Params: {len(renderers[0].get_color_parameters())}")
        self.print(f"-> Painter width Params: {len(renderers[0].get_width_parameters())}")

        L_reward = torch.tensor(0.)

        self.step = 0  # reset global step
        self.print(f"\ntotal VPSD optimization steps: {total_step}")
        with tqdm(initial=self.step, total=total_step, disable=not self.accelerator.is_main_process) as pbar:
            while self.step < total_step:
                # set particles
                particles = [renderer.get_image() for renderer in renderers]
                raster_imgs = torch.cat(particles, dim=0)

                if self.make_video and (self.step % self.args.framefreq == 0 or self.step == total_step - 1):
                    plot_img(raster_imgs, self.frame_log_dir, fname=f"iter{self.frame_idx}")
                    self.frame_idx += 1

                L_guide, grad, latents, t_step = self.pipeline.variational_score_distillation(
                    raster_imgs,
                    self.step,
                    prompt=[text_prompt],
                    negative_prompt=self.args.neg_prompt,
                    grad_scale=guidance_cfg.grad_scale,
                    enhance_particle=guidance_cfg.particle_aug,
                    im_size=model2res(self.x_cfg.model_id)
                )

                # Xing Loss for Self-Interaction Problem
                L_add = torch.tensor(0.)
                if self.style == "iconography" or self.x_cfg.xing_loss.use:
                    for r in renderers:
                        L_add += xing_loss_fn(r.get_point_parameters()) * self.x_cfg.xing_loss.weight

                loss = L_guide + L_add

                # optimization
                for opt_ in optimizers:
                    opt_.zero_grad_()
                loss.backward()
                for opt_ in optimizers:
                    opt_.step_()

                # phi_model optimization
                for _ in range(guidance_cfg.phi_update_step):
                    L_lora = self.pipeline.train_phi_model(latents, guidance_cfg.phi_t, as_latent=True)

                    phi_optimizer.zero_grad()
                    L_lora.backward()
                    phi_optimizer.step()

                # reward learning
                if guidance_cfg.phi_ReFL and self.step % guidance_cfg.phi_sample_step == 0:
                    with torch.no_grad():
                        phi_outputs = []
                        phi_sample_paths = []
                        for idx in range(guidance_cfg.n_phi_sample):
                            phi_output = self.pipeline.sample(text_prompt,
                                                              num_inference_steps=guidance_cfg.phi_infer_step,
                                                              generator=self.g_device)
                            sample_path = (self.phi_samples_dir / f'iter{idx}.png').as_posix()
                            phi_output.images[0].save(sample_path)
                            phi_sample_paths.append(sample_path)

                            phi_output_np = np.array(phi_output.images[0])
                            phi_outputs.append(phi_output_np)
                        # save all samples
                        view_images(phi_outputs, save_image=True,
                                    num_rows=max(len(phi_outputs) // 6, 1),
                                    fp=self.phi_samples_dir / f'samples_iter{self.step}.png')

                    ranking, rewards = self.reward_model.inference_rank(text_prompt, phi_sample_paths)
                    self.print(f"ranking: {ranking}, reward score: {rewards}")

                    for k in range(guidance_cfg.n_phi_sample):
                        phi = self.target_file_preprocess(phi_sample_paths[ranking[k] - 1])
                        L_reward = self.pipeline.train_phi_model_refl(phi, weight=rewards[k])

                        phi_optimizer.zero_grad()
                        L_reward.backward()
                        phi_optimizer.step()

                # update the learning rate of the phi_model
                if phi_scheduler is not None:
                    phi_scheduler.step()

                # curve regularization
                for r in renderers:
                    r.clip_curve_shape()

                # re-init paths
                if self.step % path_reinit.freq == 0 and self.step < path_reinit.stop_step and self.step != 0:
                    for i, r in enumerate(renderers):
                        r.reinitialize_paths(path_reinit.use,  # on-off
                                             path_reinit.opacity_threshold,
                                             path_reinit.area_threshold,
                                             fpath=self.reinit_dir / f"reinit-{self.step}_p{i}.svg")

                # update lr
                if self.x_cfg.lr_stage_two.lr_schedule:
                    for opt_ in optimizers:
                        opt_.update_lr()

                # log pretrained model lr
                lr_str = ""
                for k, lr in optimizers[0].get_lr().items():
                    lr_str += f"{k}_lr: {lr:.4f}, "
                # log phi model lr
                cur_phi_lr = phi_optimizer.param_groups[0]['lr']
                lr_str += f"phi_lr: {cur_phi_lr:.3e}, "

                pbar.set_description(
                    lr_str +
                    f"t: {t_step.item():.2f}, "
                    f"L_total: {loss.item():.4f}, "
                    f"L_add: {L_add.item():.4e}, "
                    f"L_lora: {L_lora.item():.4f}, "
                    f"L_reward: {L_reward.item():.4f}, "
                    f"vpsd: {grad.item():.4e}"
                )

                if self.step % self.args.save_step == 0 and self.accelerator.is_main_process:
                    # save png
                    torchvision.utils.save_image(raster_imgs,
                                                 fp=self.ft_png_logs_dir / f'iter{self.step}.png')

                    # save svg
                    for i, r in enumerate(renderers):
                        r.pretty_save_svg(self.ft_svg_logs_dir / f"svg_iter{self.step}_p{i}.svg")

                self.step += 1
                pbar.update(1)

        # save final
        for i, r in enumerate(renderers):
            final_svg_path = self.result_path / f"finetune_final_p_{i}.svg"
            r.pretty_save_svg(final_svg_path)
        # save SVGs
        torchvision.utils.save_image(raster_imgs, fp=self.result_path / f'all_particles.png')

        if self.make_video:
            from subprocess import call
            call([
                "ffmpeg",
                "-framerate", f"{self.args.framerate}",
                "-i", (self.frame_log_dir / "iter%d.png").as_posix(),
                "-vb", "20M",
                (self.result_path / "svgdreamer_rendering.mp4").as_posix()
            ])

        self.close(msg="painterly rendering complete.")

    def load_renderer(self, path_svg=None):
        renderer = Painter(self.args.diffvg,
                           self.style,
                           self.x_cfg.num_segments,
                           self.x_cfg.segment_init,
                           self.x_cfg.radius,
                           self.x_cfg.image_size,
                           self.x_cfg.grid,
                           self.x_cfg.trainable_bg,
                           self.x_cfg.width,
                           path_svg=path_svg,
                           device=self.device)

        # if load a svg file, then rasterize it
        save_path = self.result_path / 'target_img.png'
        if path_svg is not None and (not save_path.exists()):
            canvas_width, canvas_height, shapes, shape_groups = renderer.load_svg(path_svg)
            render_img = renderer.render_image(canvas_width, canvas_height, shapes, shape_groups)
            torchvision.utils.save_image(render_img, fp=save_path)
        return renderer