import numpy as np
import torch
import torch.nn.functional as F
from scipy.optimize import fmin_l_bfgs_b

from .base import BasePredictor


class BRSBasePredictor(BasePredictor):
    """Common state and helpers shared by the BRS predictors in this module.

    Holds the optimisation functor, the click-count threshold above which
    subclasses run the optimiser, the current optimisation variables
    (``opt_data``) and a cache slot (``input_data``) for subclasses.
    """

    def __init__(self, model, device, opt_functor, optimize_after_n_clicks=1, **kwargs):
        """Store the optimisation settings and reset the cached state.

        Args:
            model: Forwarded to ``BasePredictor.__init__``.
            device: Forwarded to ``BasePredictor.__init__``.
            opt_functor: Object stored as ``self.opt_functor``; subclasses pass
                it to the optimiser as the objective.
            optimize_after_n_clicks: Subclasses run the optimiser only when the
                number of clicks is strictly greater than this value.
            **kwargs: Forwarded to ``BasePredictor.__init__``.
        """
        super().__init__(model, device, **kwargs)
        self.optimize_after_n_clicks = optimize_after_n_clicks
        self.opt_functor = opt_functor

        self.opt_data = None
        self.input_data = None

    def set_input_image(self, image):
        """Set a new image via the base class and drop all cached state."""
        super().set_input_image(image)
        self.opt_data = None
        self.input_data = None

    def _get_clicks_maps_nd(self, clicks_lists, image_shape, radius=1):
        """Rasterise clicks into positive and negative square-marker maps.

        Args:
            clicks_lists: One sequence of clicks per batch item. Each click
                must expose ``coords`` (unpacked as ``y, x``) and
                ``is_positive``.
            image_shape: Spatial shape of the maps, e.g. ``(H, W)``.
            radius: Half-size of the square drawn around each click; the
                square covers ``2 * radius + 1`` pixels per axis before it is
                clipped to the image.

        Returns:
            Tuple ``(pos_clicks_map, neg_clicks_map)`` of float32 tensors of
            shape ``(len(clicks_lists), 1, *image_shape)`` on ``self.device``,
            holding 1.0 inside the click squares and 0.0 elsewhere.
        """
        maps_shape = (len(clicks_lists), 1) + tuple(image_shape)
        pos_clicks_map = np.zeros(maps_shape, dtype=np.float32)
        neg_clicks_map = np.zeros(maps_shape, dtype=np.float32)

        for list_indx, clicks_list in enumerate(clicks_lists):
            for click in clicks_list:
                y, x = click.coords
                y, x = int(round(y)), int(round(x))
                # Clamp every bound at 0: a negative slice bound is counted
                # from the end of the axis by NumPy, which would either drop a
                # click near the top/left border (empty slice) or mark pixels
                # at the opposite border. Bounds past the image size need no
                # clamping because slicing already clips them.
                y1, x1 = max(y - radius, 0), max(x - radius, 0)
                y2, x2 = max(y + radius + 1, 0), max(x + radius + 1, 0)

                target_map = pos_clicks_map if click.is_positive else neg_clicks_map
                target_map[list_indx, 0, y1:y2, x1:x2] = 1.0

        pos_clicks_map = torch.from_numpy(pos_clicks_map).to(self.device)
        neg_clicks_map = torch.from_numpy(neg_clicks_map).to(self.device)

        return pos_clicks_map, neg_clicks_map

    def get_states(self):
        """Return the transform states and the current ``opt_data`` as a dict.

        ``opt_data`` is returned by reference, not copied.
        """
        return {
            "transform_states": self._get_transform_states(),
            "opt_data": self.opt_data,
        }

    def set_states(self, states):
        """Restore state from a dict shaped like the one from ``get_states``."""
        self._set_transform_states(states["transform_states"])
        self.opt_data = states["opt_data"]


class FeatureBRSPredictor(BRSBasePredictor):
    """BRS predictor that optimises a per-channel scale and bias on cached features.

    ``insertion_mode`` selects which intermediate tensor of
    ``model.feature_extractor`` is cached and modulated. Supported values are
    ``"after_deeplab"``, ``"after_c4"`` and ``"after_aspp"``.
    """

    def __init__(
        self, model, device, opt_functor, insertion_mode="after_deeplab", **kwargs
    ):
        """Record the insertion mode and derive the number of modulated channels.

        Raises:
            NotImplementedError: If ``insertion_mode`` is not one of the
                supported values.
        """
        super().__init__(model, device, opt_functor=opt_functor, **kwargs)
        self.insertion_mode = insertion_mode
        # Skip-connection features cached by _get_head_input in "after_c4" mode.
        self._c1_features = None

        if self.insertion_mode == "after_deeplab":
            self.num_channels = model.feature_extractor.ch
        elif self.insertion_mode == "after_c4":
            self.num_channels = model.feature_extractor.aspp_in_channels
        elif self.insertion_mode == "after_aspp":
            self.num_channels = model.feature_extractor.ch + 32
        else:
            raise NotImplementedError

    def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
        """Compute prediction logits, refining scale/bias with L-BFGS-B if enabled.

        Args:
            image_nd: Batched input tensor; its first dimension is treated as
                twice the logical batch size when ``self.with_flip`` is set.
            clicks_lists: One list of clicks per batch item; the length of the
                first list is used as the click count.
            is_image_changed: Forces recomputation of the cached features.

        Returns:
            ``self.opt_functor.best_prediction`` when it is not ``None``,
            otherwise logits computed from the current ``opt_data`` and resized
            to the spatial size of ``image_nd``.
        """
        points_nd = self.get_points_nd(clicks_lists)
        pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])

        num_clicks = len(clicks_lists[0])
        bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]

        # (Re)allocate the flat parameter vector when it is missing or was
        # sized for a different batch size.
        if (
            self.opt_data is None
            or self.opt_data.shape[0] // (2 * self.num_channels) != bs
        ):
            self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32)

        # Reuse the cached features unless the network should still see the
        # new clicks, the image changed, or nothing has been cached yet.
        if (
            num_clicks <= self.net_clicks_limit
            or is_image_changed
            or self.input_data is None
        ):
            self.input_data = self._get_head_input(image_nd, points_nd)

        def get_prediction_logits(scale, bias):
            """Apply scale/bias to the cached features and run the remaining layers."""
            scale = scale.view(bs, -1, 1, 1)
            bias = bias.view(bs, -1, 1, 1)
            if self.with_flip:
                # Both halves of the doubled batch share the same parameters.
                scale = scale.repeat(2, 1, 1, 1)
                bias = bias.repeat(2, 1, 1, 1)

            scaled_backbone_features = self.input_data * scale
            scaled_backbone_features = scaled_backbone_features + bias
            if self.insertion_mode == "after_c4":
                x = self.net.feature_extractor.aspp(scaled_backbone_features)
                x = F.interpolate(
                    x,
                    mode="bilinear",
                    size=self._c1_features.size()[2:],
                    align_corners=True,
                )
                x = torch.cat((x, self._c1_features), dim=1)
                scaled_backbone_features = self.net.feature_extractor.head(x)
            elif self.insertion_mode == "after_aspp":
                scaled_backbone_features = self.net.feature_extractor.head(
                    scaled_backbone_features
                )

            pred_logits = self.net.head(scaled_backbone_features)
            pred_logits = F.interpolate(
                pred_logits,
                size=image_nd.size()[2:],
                mode="bilinear",
                align_corners=True,
            )
            return pred_logits

        self.opt_functor.init_click(
            get_prediction_logits, pos_mask, neg_mask, self.device
        )
        if num_clicks > self.optimize_after_n_clicks:
            opt_result = fmin_l_bfgs_b(
                func=self.opt_functor,
                x0=self.opt_data,
                **self.opt_functor.optimizer_params
            )
            self.opt_data = opt_result[0]

        with torch.no_grad():
            if self.opt_functor.best_prediction is not None:
                opt_pred_logits = self.opt_functor.best_prediction
            else:
                # fmin_l_bfgs_b hands back a float64 array, so opt_data may be
                # float64 after an earlier optimisation; cast to float32 to
                # match the dtype opt_data is initialised with above.
                opt_data_nd = torch.from_numpy(self.opt_data).float().to(self.device)
                opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd)
                opt_pred_logits = get_prediction_logits(*opt_vars)

        return opt_pred_logits

    def _get_head_input(self, image_nd, points):
        """Run the network up to the insertion point and return those features.

        In ``"after_c4"`` mode the projected ``c1`` features are also cached in
        ``self._c1_features`` for use during prediction.

        Raises:
            NotImplementedError: If the network has neither a non-``None``
                ``rgb_conv`` nor a ``maps_transform`` attribute.
        """
        with torch.no_grad():
            image_nd, prev_mask = self.net.prepare_input(image_nd)
            coord_features = self.net.get_coord_features(image_nd, prev_mask, points)

            if self.net.rgb_conv is not None:
                x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1))
                additional_features = None
            elif hasattr(self.net, "maps_transform"):
                x = image_nd
                additional_features = self.net.maps_transform(coord_features)
            else:
                # Previously fell through and failed later on an unbound local.
                raise NotImplementedError(
                    "network must provide either rgb_conv or maps_transform"
                )

            if self.insertion_mode == "after_c4" or self.insertion_mode == "after_aspp":
                c1, _, c3, c4 = self.net.feature_extractor.backbone(
                    x, additional_features
                )
                c1 = self.net.feature_extractor.skip_project(c1)

                if self.insertion_mode == "after_aspp":
                    x = self.net.feature_extractor.aspp(c4)
                    x = F.interpolate(
                        x, size=c1.size()[2:], mode="bilinear", align_corners=True
                    )
                    x = torch.cat((x, c1), dim=1)
                    backbone_features = x
                else:
                    backbone_features = c4
                    self._c1_features = c1
            else:
                backbone_features = self.net.feature_extractor(x, additional_features)[
                    0
                ]

        return backbone_features


class HRNetFeatureBRSPredictor(BRSBasePredictor):
    """BRS predictor that optimises a per-channel scale and bias on HRNet features.

    ``insertion_mode`` selects the modulated tensor: ``"A"`` uses the output of
    ``feature_extractor.compute_hrnet_feats``; ``"C"`` uses the output of
    ``feature_extractor.ocr_distri_head``.
    """

    def __init__(self, model, device, opt_functor, insertion_mode="A", **kwargs):
        """Record the insertion mode and derive the number of modulated channels.

        Raises:
            NotImplementedError: If ``insertion_mode`` is neither ``"A"`` nor
                ``"C"``.
        """
        super().__init__(model, device, opt_functor=opt_functor, **kwargs)
        self.insertion_mode = insertion_mode
        # Not read anywhere in this class; kept so the attribute still exists.
        self._c1_features = None

        if self.insertion_mode == "A":
            self.num_channels = sum(
                k * model.feature_extractor.width for k in [1, 2, 4, 8]
            )
        elif self.insertion_mode == "C":
            self.num_channels = 2 * model.feature_extractor.ocr_width
        else:
            raise NotImplementedError

    def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
        """Compute prediction logits, refining scale/bias with L-BFGS-B if enabled.

        Args:
            image_nd: Batched input tensor; its first dimension is treated as
                twice the logical batch size when ``self.with_flip`` is set.
            clicks_lists: One list of clicks per batch item; the length of the
                first list is used as the click count.
            is_image_changed: Forces recomputation of the cached features.

        Returns:
            ``self.opt_functor.best_prediction`` when it is not ``None``,
            otherwise logits computed from the current ``opt_data`` and resized
            to the spatial size of ``image_nd``.
        """
        points_nd = self.get_points_nd(clicks_lists)
        pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])
        num_clicks = len(clicks_lists[0])
        bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]

        # (Re)allocate the flat parameter vector when it is missing or was
        # sized for a different batch size.
        if (
            self.opt_data is None
            or self.opt_data.shape[0] // (2 * self.num_channels) != bs
        ):
            self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32)

        # Reuse the cached features unless the network should still see the
        # new clicks, the image changed, or nothing has been cached yet.
        if (
            num_clicks <= self.net_clicks_limit
            or is_image_changed
            or self.input_data is None
        ):
            self.input_data = self._get_head_input(image_nd, points_nd)

        def get_prediction_logits(scale, bias):
            """Apply scale/bias to the cached features and run the remaining layers."""
            scale = scale.view(bs, -1, 1, 1)
            bias = bias.view(bs, -1, 1, 1)
            if self.with_flip:
                # Both halves of the doubled batch share the same parameters.
                scale = scale.repeat(2, 1, 1, 1)
                bias = bias.repeat(2, 1, 1, 1)

            scaled_backbone_features = self.input_data * scale
            scaled_backbone_features = scaled_backbone_features + bias
            if self.insertion_mode == "A":
                if self.net.feature_extractor.ocr_width > 0:
                    out_aux = self.net.feature_extractor.aux_head(
                        scaled_backbone_features
                    )
                    feats = self.net.feature_extractor.conv3x3_ocr(
                        scaled_backbone_features
                    )

                    context = self.net.feature_extractor.ocr_gather_head(feats, out_aux)
                    feats = self.net.feature_extractor.ocr_distri_head(feats, context)
                else:
                    feats = scaled_backbone_features
                pred_logits = self.net.feature_extractor.cls_head(feats)
            elif self.insertion_mode == "C":
                pred_logits = self.net.feature_extractor.cls_head(
                    scaled_backbone_features
                )
            else:
                raise NotImplementedError

            pred_logits = F.interpolate(
                pred_logits,
                size=image_nd.size()[2:],
                mode="bilinear",
                align_corners=True,
            )
            return pred_logits

        self.opt_functor.init_click(
            get_prediction_logits, pos_mask, neg_mask, self.device
        )
        if num_clicks > self.optimize_after_n_clicks:
            opt_result = fmin_l_bfgs_b(
                func=self.opt_functor,
                x0=self.opt_data,
                **self.opt_functor.optimizer_params
            )
            self.opt_data = opt_result[0]

        with torch.no_grad():
            if self.opt_functor.best_prediction is not None:
                opt_pred_logits = self.opt_functor.best_prediction
            else:
                # fmin_l_bfgs_b hands back a float64 array, so opt_data may be
                # float64 after an earlier optimisation; cast to float32 to
                # match the dtype opt_data is initialised with above.
                opt_data_nd = torch.from_numpy(self.opt_data).float().to(self.device)
                opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd)
                opt_pred_logits = get_prediction_logits(*opt_vars)

        return opt_pred_logits

    def _get_head_input(self, image_nd, points):
        """Run the network up to the insertion point and return those features.

        Raises:
            NotImplementedError: If the network has neither a non-``None``
                ``rgb_conv`` nor a ``maps_transform`` attribute, or if
                ``insertion_mode`` is unsupported.
        """
        with torch.no_grad():
            image_nd, prev_mask = self.net.prepare_input(image_nd)
            coord_features = self.net.get_coord_features(image_nd, prev_mask, points)

            if self.net.rgb_conv is not None:
                x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1))
                additional_features = None
            elif hasattr(self.net, "maps_transform"):
                x = image_nd
                additional_features = self.net.maps_transform(coord_features)
            else:
                # Previously fell through and failed later on an unbound local.
                raise NotImplementedError(
                    "network must provide either rgb_conv or maps_transform"
                )

            feats = self.net.feature_extractor.compute_hrnet_feats(
                x, additional_features
            )

            if self.insertion_mode == "A":
                backbone_features = feats
            elif self.insertion_mode == "C":
                out_aux = self.net.feature_extractor.aux_head(feats)
                feats = self.net.feature_extractor.conv3x3_ocr(feats)

                context = self.net.feature_extractor.ocr_gather_head(feats, out_aux)
                backbone_features = self.net.feature_extractor.ocr_distri_head(
                    feats, context
                )
            else:
                raise NotImplementedError

        return backbone_features


class InputBRSPredictor(BRSBasePredictor):
    """BRS predictor that optimises an additive bias on the network input.

    ``optimize_target`` selects what the bias is added to: ``"rgb"`` adds it to
    the prepared input image, ``"dmaps"`` adds it to the coordinate features,
    and ``"all"`` adds it to the output of ``net.rgb_conv`` (only on the
    ``rgb_conv`` path).
    """

    def __init__(self, model, device, opt_functor, optimize_target="rgb", **kwargs):
        """Store the optimisation target; other arguments go to the base class."""
        super().__init__(model, device, opt_functor=opt_functor, **kwargs)
        self.optimize_target = optimize_target

    def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
        """Compute prediction logits, refining the input bias with L-BFGS-B if enabled.

        Args:
            image_nd: Batched input tensor; its first dimension is treated as
                twice the logical batch size when ``self.with_flip`` is set.
            clicks_lists: One list of clicks per batch item; the length of the
                first list is used as the click count.
            is_image_changed: Forces re-initialisation of the bias tensor.

        Returns:
            ``self.opt_functor.best_prediction`` when it is not ``None``,
            otherwise logits computed from the current ``opt_data`` and resized
            to the spatial size of ``image_nd``.
        """
        points_nd = self.get_points_nd(clicks_lists)
        pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])
        num_clicks = len(clicks_lists[0])

        if self.opt_data is None or is_image_changed:
            if self.optimize_target == "dmaps":
                # With a previous-mask channel present, only the remaining
                # coordinate-feature channels are optimised.
                opt_channels = (
                    self.net.coord_feature_ch - 1
                    if self.net.with_prev_mask
                    else self.net.coord_feature_ch
                )
            else:
                opt_channels = 3
            bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]
            self.opt_data = torch.zeros(
                (bs, opt_channels, image_nd.shape[2], image_nd.shape[3]),
                device=self.device,
                dtype=torch.float32,
            )

        def get_prediction_logits(opt_bias):
            """Run the full network with ``opt_bias`` added to the selected input."""
            input_image, prev_mask = self.net.prepare_input(image_nd)
            dmaps = self.net.get_coord_features(input_image, prev_mask, points_nd)

            if self.optimize_target == "rgb":
                input_image = input_image + opt_bias
            elif self.optimize_target == "dmaps":
                if self.net.with_prev_mask:
                    # Channel 0 is left untouched; the bias goes to the rest.
                    dmaps[:, 1:, :, :] = dmaps[:, 1:, :, :] + opt_bias
                else:
                    dmaps = dmaps + opt_bias

            if self.net.rgb_conv is not None:
                x = self.net.rgb_conv(torch.cat((input_image, dmaps), dim=1))
                if self.optimize_target == "all":
                    x = x + opt_bias
                coord_features = None
            elif hasattr(self.net, "maps_transform"):
                x = input_image
                coord_features = self.net.maps_transform(dmaps)
            else:
                # Previously fell through and failed later on an unbound local.
                raise NotImplementedError(
                    "network must provide either rgb_conv or maps_transform"
                )

            pred_logits = self.net.backbone_forward(x, coord_features=coord_features)[
                "instances"
            ]
            pred_logits = F.interpolate(
                pred_logits,
                size=image_nd.size()[2:],
                mode="bilinear",
                align_corners=True,
            )

            return pred_logits

        self.opt_functor.init_click(
            get_prediction_logits,
            pos_mask,
            neg_mask,
            self.device,
            shape=self.opt_data.shape,
        )
        if num_clicks > self.optimize_after_n_clicks:
            opt_result = fmin_l_bfgs_b(
                func=self.opt_functor,
                x0=self.opt_data.cpu().numpy().ravel(),
                **self.opt_functor.optimizer_params
            )

            # fmin_l_bfgs_b hands back a float64 array; cast to float32 so the
            # stored bias keeps the dtype it was initialised with.
            self.opt_data = (
                torch.from_numpy(opt_result[0])
                .float()
                .view(self.opt_data.shape)
                .to(self.device)
            )

        with torch.no_grad():
            if self.opt_functor.best_prediction is not None:
                opt_pred_logits = self.opt_functor.best_prediction
            else:
                # .float() is a no-op for float32 data and guards against a
                # float64 tensor restored through set_states.
                opt_vars, _ = self.opt_functor.unpack_opt_params(self.opt_data.float())
                opt_pred_logits = get_prediction_logits(*opt_vars)

        return opt_pred_logits