# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description: StyleCLIPDraw pipeline - optimizes an SVG drawing against a CLIP text prompt and a style image.
import shutil
from PIL import Image
from pathlib import Path

import torch
from torchvision import transforms
import clip
from tqdm.auto import tqdm
import numpy as np

from pytorch_svgrender.libs.engine import ModelState
from pytorch_svgrender.painter.style_clipdraw import (
    Painter, PainterOptimizer, VGG16Extractor, StyleLoss, sample_indices
)
from pytorch_svgrender.plt import plot_img, plot_couple


class StyleCLIPDrawPipeline(ModelState):
    """Optimize a vector drawing against a CLIP text prompt and a style image.

    The training loss combines the negative CLIP cosine similarity between
    the encoded prompt and augmented renderings with a VGG16-feature style
    loss that is weighted by ``x.style_strength``.
    """

    def __init__(self, args):
        """Create log directories and load the CLIP and style networks.

        Args:
            args: run configuration. The constructor reads ``args.seed``,
                ``args.mv`` and the method sub-config ``args.x``
                (``num_paths``, ``style_strength``, ``num_aug``).
        """
        logdir_ = f"sd{args.seed}" \
                  f"-P{args.x.num_paths}" \
                  f"-style{args.x.style_strength}" \
                  f"-n{args.x.num_aug}"
        super().__init__(args, log_path_suffix=logdir_)

        # create log dir
        self.png_logs_dir = self.result_path / "png_logs"
        self.svg_logs_dir = self.result_path / "svg_logs"
        if self.accelerator.is_main_process:
            self.png_logs_dir.mkdir(parents=True, exist_ok=True)
            self.svg_logs_dir.mkdir(parents=True, exist_ok=True)

        # make video log: frames are dumped during optimization and stitched
        # together with ffmpeg at the end of `painterly_rendering`
        self.make_video = self.args.mv
        if self.make_video:
            self.frame_idx = 0
            self.frame_log_dir = self.result_path / "frame_logs"
            self.frame_log_dir.mkdir(parents=True, exist_ok=True)

        self.clip, self.tokenize_fn = self.init_clip()

        self.style_extractor = VGG16Extractor(space="normal").to(self.device)
        self.style_loss = StyleLoss()

    def init_clip(self):
        """Load the CLIP ``ViT-B/32`` model on ``self.device``.

        Returns:
            A ``(model, tokenize_fn)`` tuple. The preprocessing transform that
            ``clip.load`` also returns is discarded; `drawing_augment` applies
            its own normalization.
        """
        model, _ = clip.load('ViT-B/32', self.device, jit=False)
        return model, clip.tokenize

    def drawing_augment(self, image):
        """Encode ``x_cfg.num_aug`` random augmentations of *image* with CLIP.

        Args:
            image: rendered image tensor.
                NOTE(review): presumably shaped (1, 3, H, W), since the
                augmented copies are concatenated along dim 0 -- confirm
                against ``Painter.get_image``.

        Returns:
            The CLIP image features of the concatenated augmented batch.
        """
        augment_trans = transforms.Compose([
            transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.5),
            transforms.RandomResizedCrop(224, scale=(0.7, 0.9)),
            # mean / std used by CLIP's own image preprocessing
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])

        # every call to `augment_trans` draws fresh random parameters
        img_augs = [augment_trans(image) for _ in range(self.x_cfg.num_aug)]
        im_batch = torch.cat(img_augs)
        # clip visual encoding
        image_features = self.clip.encode_image(im_batch)

        return image_features

    def style_file_preprocess(self, style_file):
        """Convert a PIL style image into a batched tensor on ``self.device``.

        Args:
            style_file: PIL image of the style reference.

        Returns:
            The image resized to 224x224, converted to a tensor with a leading
            batch dimension, rescaled by ``(t + 1) / 2`` and moved to
            ``self.device``.
        """
        process_comp = transforms.Compose([
            transforms.Resize(size=(224, 224)),
            transforms.ToTensor(),
            transforms.Lambda(lambda t: t.unsqueeze(0)),
            # NOTE(review): ToTensor yields values in [0, 1] for 8-bit input,
            # so this maps them into [0.5, 1]. Kept as-is; confirm that this
            # is the range VGG16Extractor(space="normal") expects.
            transforms.Lambda(lambda t: (t + 1) / 2),
        ])
        style_file = process_comp(style_file)
        style_file = style_file.to(self.device)
        return style_file

    def painterly_rendering(self, prompt, style_fpath):
        """Run the text- and style-guided optimization and save the results.

        Intermediate PNG/SVG snapshots are written every ``args.save_step``
        steps, the final image and SVG are written to ``self.result_path``
        and, when ``args.mv`` is set, the logged frames are assembled into a
        video with ffmpeg.

        Args:
            prompt: text prompt that is encoded with CLIP.
            style_fpath: path to the style reference image.

        Raises:
            FileNotFoundError: if ``style_fpath`` does not exist.
        """
        # load style file
        style_path = Path(style_fpath)
        # explicit raise instead of `assert`, which is stripped under `python -O`
        if not style_path.exists():
            raise FileNotFoundError(f"{style_fpath} is not exist!")
        self.print(f"load style file from: {style_path.as_posix()}")
        # `convert` returns an independent copy, so the file can be closed right away
        with Image.open(style_path.as_posix()) as style_src:
            style_pil = style_src.convert("RGB")
        style_img = self.style_file_preprocess(style_pil)
        shutil.copy(style_fpath, self.result_path)  # copy style file

        # extract style features from style image:
        # five sampling passes of the extractor, concatenated along dim 2
        with torch.no_grad():
            style_feats = [
                self.style_extractor.forward_samples_hypercolumn(style_img, samps=1000)
                for _ in range(5)
            ]
            feat_style = torch.cat(style_feats, dim=2)

        # text prompt encoding
        self.print(f"prompt: {prompt}")
        text_tokenize = self.tokenize_fn(prompt).to(self.device)
        with torch.no_grad():
            text_features = self.clip.encode_text(text_tokenize)

        renderer = Painter(self.x_cfg,
                           self.args.diffvg,
                           num_strokes=self.x_cfg.num_paths,
                           canvas_size=self.x_cfg.image_size,
                           device=self.device)
        img = renderer.init_image(stage=0)
        self.print("init_image shape: ", img.shape)
        plot_img(img, self.result_path, fname="init_img")

        optimizer = PainterOptimizer(renderer, self.x_cfg.lr, self.x_cfg.width_lr, self.x_cfg.color_lr)
        optimizer.init_optimizers()

        style_weight = 4 * (self.x_cfg.style_strength / 100)
        self.print(f'style_weight: {style_weight}')

        total_step = self.x_cfg.num_iter
        # the CLIP term is only applied during the first 90% of the steps
        clip_until = 0.9 * total_step
        # stays None when the loop body never runs (e.g. num_iter == 0)
        rendering = None
        with tqdm(initial=self.step, total=total_step, disable=not self.accelerator.is_main_process) as pbar:
            while self.step < total_step:
                rendering = renderer.get_image(self.step).to(self.device)

                if self.make_video and (self.step % self.args.framefreq == 0 or self.step == total_step - 1):
                    plot_img(rendering, self.frame_log_dir, fname=f"iter{self.frame_idx}")
                    self.frame_idx += 1

                loss = torch.tensor(0., device=self.device)

                # do clip optimization
                if self.step < clip_until:
                    # augment + encode only when the CLIP term is used, so the
                    # final style-only steps skip the CLIP forward pass
                    rendering_aug = self.drawing_augment(rendering)
                    for n in range(self.x_cfg.num_aug):
                        loss -= torch.cosine_similarity(text_features, rendering_aug[n:n + 1], dim=1).mean()

                # do style optimization
                # extract style features based on the approach from STROTSS [Kolkin et al., 2019].
                feat_content = self.style_extractor(rendering)

                xx, xy = sample_indices(feat_content[0], feat_style)

                np.random.shuffle(xx)
                np.random.shuffle(xy)

                L_style = self.style_loss.forward(feat_content, feat_content, feat_style, [xx, xy], 0)

                loss += L_style * style_weight

                pbar.set_description(
                    f"lr: {optimizer.get_lr():.3f}, "
                    f"L_train: {loss.item():.4f}, "
                    f"L_style: {L_style.item():.4f}"
                )

                # optimization
                optimizer.zero_grad_()
                loss.backward()
                optimizer.step_()

                renderer.clip_curve_shape()

                if self.x_cfg.lr_schedule:
                    optimizer.update_lr(self.step)

                if self.step % self.args.save_step == 0 and self.accelerator.is_main_process:
                    plot_couple(style_img,
                                rendering,
                                self.step,
                                prompt=prompt,
                                output_dir=self.png_logs_dir.as_posix(),
                                fname=f"iter{self.step}")
                    renderer.save_svg(self.svg_logs_dir.as_posix(), f"svg_iter{self.step}")

                self.step += 1
                pbar.update(1)

        if rendering is None:
            # no optimization step was executed: render the current state once
            # so the final plot below does not fail on an unbound name
            with torch.no_grad():
                rendering = renderer.get_image(self.step).to(self.device)

        plot_couple(style_img,
                    rendering,
                    self.step,
                    prompt=prompt,
                    output_dir=self.result_path.as_posix(),
                    fname="final_iter")
        renderer.save_svg(self.result_path.as_posix(), "final_svg")

        if self.make_video:
            from subprocess import call
            # stitch the logged frames (iter0.png, iter1.png, ...) into an mp4
            call([
                "ffmpeg",
                "-framerate", f"{self.args.framerate}",
                "-i", (self.frame_log_dir / "iter%d.png").as_posix(),
                "-vb", "20M",
                (self.result_path / "styleclipdraw_rendering.mp4").as_posix()
            ])

        self.close(msg="painterly rendering complete.")