import random

import cv2
import numpy as np
from albumentations import DualTransform, ImageOnlyTransform
from albumentations.augmentations import functional as F
from albumentations.core.serialization import SERIALIZABLE_REGISTRY
from albumentations.core.transforms_interface import to_tuple

from isegm.utils.misc import (clamp_bbox, expand_bbox, get_bbox_from_mask,
                              get_labels_with_sizes)


class UniformRandomResize(DualTransform):
    """Resize the input by a single factor drawn uniformly from ``scale_range``.

    The same factor is used for both axes, and the target size is derived
    from the shape of the ``image`` target.

    Args:
        scale_range: Two values unpacked into ``random.uniform`` to draw the
            scale factor.
        interpolation: OpenCV interpolation flag stored on the instance.
        always_apply: Forwarded to the base transform.
        p: Probability of applying the transform, forwarded to the base
            transform.
    """

    def __init__(
        self,
        scale_range=(0.9, 1.1),
        interpolation=cv2.INTER_LINEAR,
        always_apply=False,
        p=1,
    ):
        super().__init__(always_apply, p)
        self.interpolation = interpolation
        self.scale_range = scale_range

    @property
    def targets_as_params(self):
        """Targets that must be handed to ``get_params_dependent_on_targets``."""
        return ["image"]

    def get_transform_init_args_names(self):
        """Names of the constructor arguments recorded for serialization."""
        return ("scale_range", "interpolation")

    def get_params_dependent_on_targets(self, params):
        """Draw a scale factor and return the rounded target height and width."""
        factor = random.uniform(*self.scale_range)
        src_shape = params["image"].shape
        return {
            "new_height": int(round(src_shape[0] * factor)),
            "new_width": int(round(src_shape[1] * factor)),
        }

    def apply(
        self, img, new_height=0, new_width=0, interpolation=cv2.INTER_LINEAR, **params
    ):
        """Resize ``img`` to ``new_height`` x ``new_width``.

        NOTE(review): the ``interpolation`` argument is used here, not
        ``self.interpolation`` directly; presumably the albumentations base
        class injects the instance value into the params - confirm.
        """
        resize_kwargs = {
            "height": new_height,
            "width": new_width,
            "interpolation": interpolation,
        }
        return F.resize(img, **resize_kwargs)

    def apply_to_keypoint(self, keypoint, new_height=0, new_width=0, **params):
        """Scale a keypoint by the ratio of the new size to ``cols`` / ``rows``."""
        return F.keypoint_scale(
            keypoint,
            new_width / params["cols"],
            new_height / params["rows"],
        )

class ZoomIn(DualTransform):
    """Crop around one randomly chosen mask object and resize to a fixed size.

    A candidate object is picked from the ``mask`` target, its bounding box
    is expanded, jittered and clamped to the mask extent, and both image and
    mask are cropped to that box and resized to ``height`` x ``width``. In
    the returned mask every pixel that does not belong to the selected
    object is set to 0.

    Bounding boxes are ``(rmin, rmax, cmin, cmax)`` with inclusive bounds.

    Args:
        height: Output height after the crop is resized.
        width: Output width after the crop is resized.
        bbox_jitter: Passed through albumentations' ``to_tuple``; the
            resulting pair is unpacked into ``random.uniform`` to draw a
            per-edge offset, expressed as a fraction of the box size.
            (Presumably a scalar becomes a symmetric ``(-x, x)`` range -
            confirm against ``to_tuple``.)
        expansion_ratio: Ratio handed to ``expand_bbox``. A tuple or list is
            treated as a range and a value is drawn uniformly from it on
            every call; anything else is used as is.
        min_crop_size: Forwarded to ``expand_bbox``.
        min_area: Only objects whose area is strictly greater than this are
            candidates.
        always_resize: When no candidate object exists, resize the inputs to
            ``height`` x ``width`` instead of returning them untouched.
        always_apply: Forwarded to the base transform.
        p: Probability of applying the transform, forwarded to the base
            transform.
    """

    def __init__(
        self,
        height,
        width,
        bbox_jitter=0.1,
        expansion_ratio=1.4,
        min_crop_size=200,
        min_area=100,
        always_resize=False,
        always_apply=False,
        p=0.5,
    ):
        super().__init__(always_apply, p)
        self.height = height
        self.width = width
        self.bbox_jitter = to_tuple(bbox_jitter)
        self.expansion_ratio = expansion_ratio
        self.min_crop_size = min_crop_size
        self.min_area = min_area
        self.always_resize = always_resize

    def apply(self, img, selected_object, bbox, **params):
        """Crop ``img`` to ``bbox`` and resize it to the configured size.

        When no object was selected the image is returned unchanged, or only
        resized if ``always_resize`` is set.
        """
        if selected_object is None:
            if self.always_resize:
                img = F.resize(img, height=self.height, width=self.width)
            return img

        rmin, rmax, cmin, cmax = bbox
        # Bounds are inclusive, hence the +1 on the slice ends.
        img = img[rmin : rmax + 1, cmin : cmax + 1]
        img = F.resize(img, height=self.height, width=self.width)

        return img

    def apply_to_mask(self, mask, selected_object, bbox, **params):
        """Crop ``mask`` to ``bbox``, keep only the selected object, resize.

        Resizing uses nearest-neighbour interpolation. When no object was
        selected the mask is returned unchanged, or only resized if
        ``always_resize`` is set.
        """
        if selected_object is None:
            if self.always_resize:
                mask = F.resize(
                    mask,
                    height=self.height,
                    width=self.width,
                    interpolation=cv2.INTER_NEAREST,
                )
            return mask

        rmin, rmax, cmin, cmax = bbox
        mask = mask[rmin : rmax + 1, cmin : cmax + 1]
        if isinstance(selected_object, tuple):
            # Multi-layer mask: the selection is (layer index, object id).
            # Start from zeros so every other layer and object is dropped.
            layer_indx, mask_id = selected_object
            obj_mask = mask[:, :, layer_indx] == mask_id
            new_mask = np.zeros_like(mask)
            new_mask[:, :, layer_indx][obj_mask] = mask_id
        else:
            # Single-layer mask: zero everything except the selected id.
            obj_mask = mask == selected_object
            new_mask = mask.copy()
            new_mask[np.logical_not(obj_mask)] = 0

        new_mask = F.resize(
            new_mask,
            height=self.height,
            width=self.width,
            interpolation=cv2.INTER_NEAREST,
        )
        return new_mask

    def get_params_dependent_on_targets(self, params):
        """Pick an object from the mask and compute the crop box around it.

        Returns:
            A dict with ``selected_object`` (an object id, a
            ``(layer index, object id)`` tuple for masks with more than two
            dimensions, or ``None`` when there is no candidate) and ``bbox``
            (``(rmin, rmax, cmin, cmax)`` or ``None``).
        """
        instances = params["mask"]

        is_mask_layer = len(instances.shape) > 2
        candidates = []
        if is_mask_layer:
            for layer_indx in range(instances.shape[2]):
                labels, areas = get_labels_with_sizes(instances[:, :, layer_indx])
                candidates.extend(
                    (layer_indx, obj_id)
                    for obj_id, area in zip(labels, areas)
                    if area > self.min_area
                )
        else:
            labels, areas = get_labels_with_sizes(instances)
            candidates = [
                obj_id for obj_id, area in zip(labels, areas) if area > self.min_area
            ]

        selected_object = None
        bbox = None
        if candidates:
            selected_object = random.choice(candidates)
            if is_mask_layer:
                layer_indx, mask_id = selected_object
                obj_mask = instances[:, :, layer_indx] == mask_id
            else:
                obj_mask = instances == selected_object

            bbox = get_bbox_from_mask(obj_mask)
            bbox = expand_bbox(
                bbox, self._sample_expansion_ratio(), self.min_crop_size
            )
            bbox = self._jitter_bbox(bbox)
            # Jitter and expansion may push the box outside the mask.
            bbox = clamp_bbox(bbox, 0, obj_mask.shape[0] - 1, 0, obj_mask.shape[1] - 1)

        return {"selected_object": selected_object, "bbox": bbox}

    def _sample_expansion_ratio(self):
        """Return the expansion ratio to use for one call.

        A range may arrive as a list rather than a tuple (for example after
        the init args went through a JSON round trip), so both sequence
        types are sampled from; a scalar is returned unchanged.
        """
        if isinstance(self.expansion_ratio, (tuple, list)):
            return random.uniform(*self.expansion_ratio)
        return self.expansion_ratio

    def _jitter_bbox(self, bbox):
        """Shift each box edge by an independent random fraction of the box size."""
        rmin, rmax, cmin, cmax = bbox
        height = rmax - rmin + 1
        width = cmax - cmin + 1
        rmin = int(rmin + random.uniform(*self.bbox_jitter) * height)
        rmax = int(rmax + random.uniform(*self.bbox_jitter) * height)
        cmin = int(cmin + random.uniform(*self.bbox_jitter) * width)
        cmax = int(cmax + random.uniform(*self.bbox_jitter) * width)

        return rmin, rmax, cmin, cmax

    def apply_to_bbox(self, bbox, **params):
        """Bounding-box targets are not supported by this transform."""
        raise NotImplementedError

    def apply_to_keypoint(self, keypoint, **params):
        """Keypoint targets are not supported by this transform."""
        raise NotImplementedError

    @property
    def targets_as_params(self):
        """Targets that must be handed to ``get_params_dependent_on_targets``."""
        return ["mask"]

    def get_transform_init_args_names(self):
        """Names of the constructor arguments recorded for serialization."""
        return (
            "height",
            "width",
            "bbox_jitter",
            "expansion_ratio",
            "min_crop_size",
            "min_area",
            "always_resize",
        )


def remove_image_only_transforms(sdict):
    """Drop image-only transforms from a serialized transform dictionary.

    Entries of ``sdict["transforms"]`` that themselves contain a
    ``"transforms"`` key are treated as containers and processed
    recursively. Every other entry is resolved through
    ``SERIALIZABLE_REGISTRY`` by its ``"__class_fullname__"`` and kept only
    if its class is not a subclass of ``ImageOnlyTransform``.

    Args:
        sdict: Serialized transform dictionary. It is modified in place.

    Returns:
        The same ``sdict`` object, with its ``"transforms"`` list replaced by
        the filtered list. A dictionary without a ``"transforms"`` key is
        returned untouched.

    Raises:
        KeyError: If a non-container entry has no ``"__class_fullname__"``
            or names a class missing from the registry.
    """
    if "transforms" not in sdict:
        return sdict

    keep_transforms = []
    for tdict in sdict["transforms"]:
        if "transforms" in tdict:
            # Containers are always kept; their own class is never inspected,
            # so no registry lookup is needed (or allowed to fail) for them.
            keep_transforms.append(remove_image_only_transforms(tdict))
            continue

        cls = SERIALIZABLE_REGISTRY[tdict["__class_fullname__"]]
        if not issubclass(cls, ImageOnlyTransform):
            keep_transforms.append(tdict)
    sdict["transforms"] = keep_transforms

    return sdict