from argparse import Namespace
import numpy as np
import torch

from models.StyleGANControler import StyleGANControler


class Model:
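    """Inference wrapper around StyleGANControler for interactive latent editing.

    Provides random sampling, drag-style transforms, style resampling, and reset;
    latents are passed between calls as numpy arrays inside a plain dict.
    """
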
    def __init__(
        self, checkpoint_path, truncation=0.5, use_average_code_as_input=False
    ):
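        """Load a trained checkpoint and build the network on the GPU in eval mode."""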
        self.truncation = truncation
        self.use_average_code_as_input = use_average_code_as_input
        # Rebuild the training-time options stored in the checkpoint, pointing
        # them at the checkpoint being loaded.
        ckpt = torch.load(checkpoint_path, map_location="cpu")
        opts = ckpt["opts"]
        opts["checkpoint_path"] = checkpoint_path
        self.opts = Namespace(**opts)
        self.net = StyleGANControler(self.opts)
        self.net.eval()
        self.net.cuda()
        self.target_layers = [0, 1, 2, 3, 4, 5]  # W+ layers the encoder may edit

    def random_sample(self):
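        """Sample a random latent code and decode it.

        Returns the image as an (H, W, 3) float array in [0, 255] with the
        channel order reversed (RGB -> BGR), plus a dict holding the latent
        code ("w1") and a copy used as the editing reference ("w1_initial"),
        both as numpy arrays.
        """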
        z1 = torch.randn(1, 512).to("cuda")
        x1, w1, f1 = self.net.decoder(
            [z1],
            input_is_latent=False,
            randomize_noise=False,
            return_feature_map=True,
            return_latents=True,
            truncation=self.truncation,
            truncation_latent=self.net.latent_avg[0],
        )
        w1_initial = w1.clone()
        x1 = self.net.face_pool(x1)
        # Scale pixel values from [-1, 1] to [0, 255] and reverse channels (RGB -> BGR).
        image = (
            ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
        )
        return (
            image,
            {
                "w1": w1.cpu().detach().numpy(),
                "w1_initial": w1_initial.cpu().detach().numpy(),
            },
        )  # return the latent codes along with the image

    def latents_to_tensor(self, latents):
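        """Move stored numpy latents back to the GPU and re-run the decoder.

        Returns (w1, w1_initial, f1); f1 is the feature map of the initial
        latent and is later sampled at user-selected points.
        """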
        w1 = latents["w1"]
        w1_initial = latents["w1_initial"]

        w1 = torch.tensor(w1).to("cuda")
        w1_initial = torch.tensor(w1_initial).to("cuda")

        # Re-run the decoder for both codes; the feature map f1 kept below comes
        # from the initial latent and is later sampled at user-selected points.
        x1, w1, f1 = self.net.decoder(
            [w1],
            input_is_latent=True,
            randomize_noise=False,
            return_feature_map=True,
            return_latents=True,
        )
        x1, w1_initial, f1 = self.net.decoder(
            [w1_initial],
            input_is_latent=True,
            randomize_noise=False,
            return_feature_map=True,
            return_latents=True,
        )

        return (w1, w1_initial, f1)

    def transform(
        self,
        latents,
        dz,
        dxy,
        sxsy=(0, 0),
        stop_points=(),
        zoom_in=False,
        zoom_out=False,
    ):
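        """Apply a drag-style edit to the stored latents.

        dxy is the drag vector, sxsy the drag start point (pixel coordinates in
        the 256x256 interpolated feature map), dz the third component of the
        motion vector, and stop_points a list of anchor points that should stay
        fixed. The encoder predicts new values for the target W+ layers from
        the motion vector and the sampled feature vectors. zoom_in and zoom_out
        are accepted but not used in this method.
        """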
        w1, w1_initial, f1 = self.latents_to_tensor(latents)
        w1 = w1_initial.clone()  # always edit from the initial (reference) latent

        # Build the 3D motion vector (dx, dy, dz): normalize its xy part to a unit
        # direction and pass the original magnitude (scaled) to the encoder as alpha.
        dxyz = np.array([dxy[0], dxy[1], dz], dtype=np.float32)
        dxy_norm = np.linalg.norm(dxyz[:2], ord=2)
        dxy_norm = dxy_norm + 1e-8  # avoid division by zero when there is no xy drag
        dxyz[:2] = dxyz[:2] / dxy_norm
        vec_num = dxy_norm / 10

        # x: (1, 1, 3) motion vector; y: feature vector at the drag start point,
        # sampled from the feature map upsampled to the 256x256 editing grid.
        x = torch.from_numpy(np.array([[dxyz]], dtype=np.float32)).cuda()
        f1 = torch.nn.functional.interpolate(f1, (256, 256))
        y = f1[:, :, sxsy[1], sxsy[0]].unsqueeze(0)

        if len(stop_points) > 0:
            # Stop points act as anchors that should not move: pair zero motion
            # vectors with the feature vectors sampled at those locations.
            x = torch.cat(
                [x, torch.zeros(x.shape[0], len(stop_points), x.shape[2]).cuda()], dim=1
            )
            tmp = [f1[:, :, sp[1], sp[0]].unsqueeze(1) for sp in stop_points]
            y = torch.cat([y, torch.cat(tmp, dim=1)], dim=1)

        if not self.use_average_code_as_input:
            # Predict new target-layer codes directly from the current latent.
            w_hat = self.net.encoder(
                w1[:, self.target_layers].detach(),
                x.detach(),
                y.detach(),
                alpha=vec_num,
            )
            w1 = w1.clone()
            w1[:, self.target_layers] = w_hat
        else:
            # Predict the edit relative to the average latent and apply it as an
            # offset on top of the current code.
            w_hat = self.net.encoder(
                self.net.latent_avg.unsqueeze(0)[:, self.target_layers].detach(),
                x.detach(),
                y.detach(),
                alpha=vec_num,
            )
            w1 = w1.clone()
            w1[:, self.target_layers] = (
                w1.clone()[:, self.target_layers]
                + w_hat
                - self.net.latent_avg.unsqueeze(0)[:, self.target_layers]
            )

        x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False)

        x1 = self.net.face_pool(x1)
        result = (
            ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
        )
        return (
            result,
            {
                "w1": w1.cpu().detach().numpy(),
                "w1_initial": w1_initial.cpu().detach().numpy(),
            },
        )

    def change_style(self, latents):
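        """Resample the appearance of the current image.

        Keeps W+ layers 0-5 of the stored latent and replaces layers 6 and up
        with the style of a freshly sampled code; the returned latents become
        the new editing reference.
        """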
        w1, w1_initial, f1 = self.latents_to_tensor(latents)
        w1 = w1_initial.clone()

        z1 = torch.randn(1, 512).to("cuda")
        x1, w2 = self.net.decoder(
            [z1],
            input_is_latent=False,
            randomize_noise=False,
            return_latents=True,
            truncation=self.truncation,
            truncation_latent=self.net.latent_avg[0],
        )
        w1[:, 6:] = w2.detach()[:, 0]  # copy the new style into the fine layers (6+)
        x1, w1_new = self.net.decoder(
            [w1],
            input_is_latent=True,
            randomize_noise=False,
            return_latents=True,
        )
        result = (
            ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
        )
        return (
            result,
            {
                "w1": w1_new.cpu().detach().numpy(),
                "w1_initial": w1_new.cpu().detach().numpy(),
            },
        )

    def reset(self, latents):
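        """Discard edits and decode the stored initial latent again."""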
        w1, w1_initial, f1 = self.latents_to_tensor(latents)
        x1, w1_new, f1 = self.net.decoder(
            [w1_initial],
            input_is_latent=True,
            randomize_noise=False,
            return_feature_map=True,
            return_latents=True,
        )
        result = (
            ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
        )
        return (
            result,
            {
                "w1": w1_new.cpu().detach().numpy(),
                "w1_initial": w1_new.cpu().detach().numpy(),
            },
        )
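

# ---------------------------------------------------------------------------
# Minimal usage sketch (assumptions: a CUDA-capable GPU and a trained
# checkpoint at the hypothetical path below; the path, drag coordinates, and
# offsets are illustrative only and not part of this module).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = Model("checkpoints/latent_transformer.pt")  # hypothetical path

    # Sample a random image and keep its latents for later edits.
    image, latents = model.random_sample()

    # Drag the point at (128, 128) by 20 pixels to the right in the 256x256 view.
    edited, latents = model.transform(
        latents, dz=0.0, dxy=[20.0, 0.0], sxsy=[128, 128]
    )

    # Resample the fine style layers while keeping the coarse layout.
    restyled, latents = model.change_style(latents)

    # Undo the edits by decoding the stored initial latent again.
    original, latents = model.reset(latents)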