import numpy as np
import torch
import torch.nn.functional as F
from typing import Iterable, Tuple, Union
from pathlib import Path
from torchvision import transforms
import kornia
from omegaconf import DictConfig

from src.FaceDetector.face_detector import Detection
from src.FaceAlign.face_align import align_face, inverse_transform_batch
from src.PostProcess.utils import SoftErosion
from src.model_loader import get_model
from src.Misc.types import CheckpointType, FaceAlignmentType
from src.Misc.utils import tensor2img


class SimSwap:
    def __init__(
        self,
        config: DictConfig,
        id_image: Union[np.ndarray, None] = None,
        specific_image: Union[np.ndarray, None] = None,
    ):
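        """Build the SimSwap pipeline: load the face detector, ArcFace ID
        encoder, BiSeNet parsing model, generator, and blending module
        according to `config`."""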

        self.id_image: Union[np.ndarray, None] = id_image
        self.id_latent: Union[torch.Tensor, None] = None
        self.specific_id_image: Union[np.ndarray, None] = specific_image
        self.specific_latent: Union[torch.Tensor, None] = None

        self.use_mask: Union[bool, None] = True
        self.crop_size: Union[int, None] = None
        self.checkpoint_type: Union[CheckpointType, None] = None
        self.face_alignment_type: Union[FaceAlignmentType, None] = None
        self.smooth_mask_iter: Union[int, None] = None
        self.smooth_mask_kernel_size: Union[int, None] = None
        self.smooth_mask_threshold: Union[float, None] = None
        self.face_detector_threshold: Union[float, None] = None
        self.specific_latent_match_threshold: Union[float, None] = None
        self.device = torch.device(config.device)

        self.set_parameters(config)

        # For BiSeNet and for official_224 SimSwap
        self.to_tensor_normalize = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

        # For SimSwap models trained with the updated code
        self.to_tensor = transforms.ToTensor()

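        # Face detector: returns bounding boxes, key points, and confidence scores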
        self.face_detector = get_model(
            "face_detector",
            device=self.device,
            load_state_dict=False,
            model_path=Path(config.face_detector_weights),
            det_thresh=self.face_detector_threshold,
            det_size=(640, 640),
            mode="ffhq",
        )

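        # ArcFace network that maps aligned face crops to identity latents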
        self.face_id_net = get_model(
            "arcface",
            device=self.device,
            load_state_dict=False,
            model_path=Path(config.face_id_weights),
        )

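        # BiSeNet face-parsing model (19 classes) used to build blending masks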
        self.bise_net = get_model(
            "parsing_model",
            device=self.device,
            load_state_dict=True,
            model_path=Path(config.parsing_model_weights),
            n_classes=19,
        )

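        # Pick the generator matching the crop size; 512 checkpoints use the deeper variant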
        gen_model = "generator_512" if self.crop_size == 512 else "generator_224"
        self.simswap_net = get_model(
            gen_model,
            device=self.device,
            load_state_dict=True,
            model_path=Path(config.simswap_weights),
            input_nc=3,
            output_nc=3,
            latent_size=512,
            n_blocks=9,
            deep=self.crop_size == 512,
            use_last_act=self.checkpoint_type == CheckpointType.OFFICIAL_224,
        )

        self.blend = get_model(
            "blend_module",
            device=self.device,
            load_state_dict=False,
            model_path=Path(config.blend_module_weights),
        )

        self.enhance_output = config.enhance_output
        if config.enhance_output:
            self.gfpgan_net = get_model(
                "gfpgan",
                device=self.device,
                load_state_dict=True,
                model_path=Path(config.gfpgan_weights),
            )

    def set_parameters(self, config: DictConfig) -> None:
        self.set_crop_size(config.crop_size)
        self.set_checkpoint_type(config.checkpoint_type)
        self.set_face_alignment_type(config.face_alignment_type)
        self.set_face_detector_threshold(config.face_detector_threshold)
        self.set_specific_latent_match_threshold(config.specific_latent_match_threshold)
        self.set_smooth_mask_kernel_size(config.smooth_mask_kernel_size)
        self.set_smooth_mask_threshold(config.smooth_mask_threshold)
        self.set_smooth_mask_iter(config.smooth_mask_iter)

    def set_crop_size(self, crop_size: int) -> None:
        if crop_size <= 0:
            raise ValueError("Invalid crop_size! Must be a positive value.")

        self.crop_size = crop_size

    def set_checkpoint_type(self, checkpoint_type: str) -> None:
        ckpt_type = CheckpointType(checkpoint_type)
        if ckpt_type not in (CheckpointType.OFFICIAL_224, CheckpointType.UNOFFICIAL):
            raise ValueError("Invalid checkpoint_type! Must be one of the predefined values.")

        self.checkpoint_type = ckpt_type

    def set_face_alignment_type(self, face_alignment_type: str) -> None:
        alignment_type = FaceAlignmentType(face_alignment_type)
        if alignment_type not in (
            FaceAlignmentType.FFHQ,
            FaceAlignmentType.DEFAULT,
        ):
            raise ValueError("Invalid face_alignment_type! Must be one of the predefined values.")

        self.face_alignment_type = alignment_type

    def set_face_detector_threshold(self, face_detector_threshold: float) -> None:
        if face_detector_threshold < 0.0 or face_detector_threshold > 1.0:
            raise ValueError("Invalid face_detector_threshold! Must be in the range [0.0, 1.0].")

        self.face_detector_threshold = face_detector_threshold

    def set_specific_latent_match_threshold(
        self, specific_latent_match_threshold: float
    ) -> None:
        if specific_latent_match_threshold < 0.0:
            raise ValueError("Invalid specific_latent_match_threshold! Must be a non-negative value.")

        self.specific_latent_match_threshold = specific_latent_match_threshold

    def re_initialize_soft_mask(self):
        # Skip until all three smoothing parameters have been set; set_parameters
        # calls the individual setters one by one during __init__.
        if None in (self.smooth_mask_kernel_size, self.smooth_mask_threshold, self.smooth_mask_iter):
            return
        self.smooth_mask = SoftErosion(kernel_size=self.smooth_mask_kernel_size,
                                       threshold=self.smooth_mask_threshold,
                                       iterations=self.smooth_mask_iter).to(self.device)

    def set_smooth_mask_kernel_size(self, smooth_mask_kernel_size: int) -> None:
        if smooth_mask_kernel_size <= 0:
            raise ValueError("Invalid smooth_mask_kernel_size! Must be a positive value.")
        # Bump even sizes to the next odd value so the kernel has a center element
        smooth_mask_kernel_size += 1 if smooth_mask_kernel_size % 2 == 0 else 0
        self.smooth_mask_kernel_size = smooth_mask_kernel_size
        self.re_initialize_soft_mask()

    def set_smooth_mask_threshold(self, smooth_mask_threshold: float) -> None:
        if smooth_mask_threshold < 0.0 or smooth_mask_threshold > 1.0:
            raise ValueError("Invalid smooth_mask_threshold! Must be within the [0.0, 1.0] range.")
        self.smooth_mask_threshold = smooth_mask_threshold
        self.re_initialize_soft_mask()

    def set_smooth_mask_iter(self, smooth_mask_iter: int) -> None:
        if smooth_mask_iter < 0:
            raise ValueError("Invalid smooth_mask_iter! Must be a non-negative value.")
        self.smooth_mask_iter = smooth_mask_iter
        self.re_initialize_soft_mask()

    def run_detect_align(
        self, image: np.ndarray, for_id: bool = False
    ) -> Tuple[
        Union[Iterable[np.ndarray], None],
        Union[Iterable[np.ndarray], None],
        np.ndarray,
    ]:
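        """Detect faces in `image` and align the crops to `self.crop_size`.

        Returns the aligned crops, their affine transforms, and the detection
        scores; the crops and transforms are None when no face is found.
        """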
        detection: Detection = self.face_detector(image)

        if detection.bbox is None:
            if for_id:
                raise ValueError("Can't detect a face! Please change the ID image!")
            return None, None, detection.score

        kps = detection.key_points

        if for_id:
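            # For the ID image keep only the single highest-scoring face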
            max_score_ind = np.argmax(detection.score, axis=0)
            kps = detection.key_points[max_score_ind]
            kps = kps[None, ...]

        align_imgs, transforms = align_face(
            image,
            kps,
            crop_size=self.crop_size,
            mode="ffhq"
            if self.face_alignment_type == FaceAlignmentType.FFHQ
            else "none",
        )

        return align_imgs, transforms, detection.score

    def __call__(self, att_image: np.ndarray) -> np.ndarray:
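        """Swap the ID face onto the face(s) detected in `att_image` and return the resulting frame."""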
        if self.id_latent is None:
            align_id_imgs, id_transforms, _ = self.run_detect_align(
                self.id_image, for_id=True
            )
            # normalize=True, because the official SimSwap model was trained with a normalized id latent
            self.id_latent: torch.Tensor = self.face_id_net(
                align_id_imgs, normalize=True
            )

        if self.specific_id_image is not None and self.specific_latent is None:
            align_specific_imgs, specific_transforms, _ = self.run_detect_align(
                self.specific_id_image, for_id=True
            )
            self.specific_latent: torch.Tensor = self.face_id_net(
                align_specific_imgs, normalize=False
            )

        # for_id=False, because we want to get all faces
        align_att_imgs, att_transforms, att_detection_score = self.run_detect_align(
            att_image, for_id=False
        )

        if align_att_imgs is None and att_transforms is None:
            return att_image

        # Select specific crop from the target image
        if self.specific_latent is not None:
            att_latent: torch.Tensor = self.face_id_net(align_att_imgs, normalize=False)
            latent_dist = torch.mean(
                F.mse_loss(
                    att_latent,
                    self.specific_latent.repeat(att_latent.shape[0], 1),
                    reduction="none",
                ),
                dim=-1,
            )

            att_detection_score = torch.tensor(
                att_detection_score, device=latent_dist.device
            )

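            # Weight the latent distances by detection confidence and pick the closest face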
            min_index = torch.argmin(latent_dist * att_detection_score)
            min_value = latent_dist[min_index]

            if min_value < self.specific_latent_match_threshold:
                align_att_imgs = [align_att_imgs[min_index]]
                att_transforms = [att_transforms[min_index]]
            else:
                return att_image

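        # Generate the swapped crops conditioned on the ID latent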
        swapped_img: torch.Tensor = self.simswap_net(align_att_imgs, self.id_latent)

        if self.enhance_output:
            swapped_img = self.gfpgan_net.enhance(swapped_img, weight=0.5)

        # Put all crops/transformations into a batch
        align_att_img_batch_for_parsing_model: torch.Tensor = torch.stack(
            [self.to_tensor_normalize(x) for x in align_att_imgs], dim=0
        )
        align_att_img_batch_for_parsing_model = (
            align_att_img_batch_for_parsing_model.to(self.device)
        )

        att_transforms: torch.Tensor = torch.stack(
            [torch.tensor(x).float() for x in att_transforms], dim=0
        )
        att_transforms = att_transforms.to(self.device, non_blocking=True)

        align_att_img_batch: torch.Tensor = torch.stack(
            [self.to_tensor(x) for x in align_att_imgs], dim=0
        )
        align_att_img_batch = align_att_img_batch.to(self.device, non_blocking=True)

        # Get face masks for the attribute image
        face_mask, ignore_mask_ids = self.bise_net.get_mask(
            align_att_img_batch_for_parsing_model, self.crop_size
        )

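        # Invert the alignment transforms so the crops can be pasted back into the frame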
        inv_att_transforms: torch.Tensor = inverse_transform_batch(att_transforms)

        soft_face_mask, _ = self.smooth_mask(face_mask)

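        # For crops flagged by the parsing model, fall back to the original (un-swapped) pixels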
        swapped_img[ignore_mask_ids, ...] = align_att_img_batch[ignore_mask_ids, ...]

        frame_size = (att_image.shape[0], att_image.shape[1])

        att_image = self.to_tensor(att_image).to(self.device, non_blocking=True).unsqueeze(0)

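        # Warp the swapped crops and the soft mask back into frame coordinates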
        target_image = kornia.geometry.transform.warp_affine(
            swapped_img,
            inv_att_transforms,
            frame_size,
            mode="bilinear",
            padding_mode="border",
            align_corners=True,
            fill_value=torch.zeros(3),
        )

        soft_face_mask = kornia.geometry.transform.warp_affine(
            soft_face_mask,
            inv_att_transforms,
            frame_size,
            mode="bilinear",
            padding_mode="zeros",
            align_corners=True,
            fill_value=torch.zeros(3),
        )

        result = self.blend(target_image, soft_face_mask, att_image)

        return tensor2img(result)
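

# Minimal usage sketch (illustrative only): assumes an OmegaConf config that
# provides the fields referenced above (device, weight paths, crop_size, etc.)
# and RGB numpy images. The file paths here are hypothetical.
#
#   from omegaconf import OmegaConf
#   import cv2
#
#   config = OmegaConf.load("configs/run_image.yaml")
#   id_img = cv2.cvtColor(cv2.imread("id.jpg"), cv2.COLOR_BGR2RGB)
#   target = cv2.cvtColor(cv2.imread("target.jpg"), cv2.COLOR_BGR2RGB)
#
#   swapper = SimSwap(config=config, id_image=id_img)
#   result = swapper(target)  # np.ndarray with the same spatial size as target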