import torch
import torch.nn as nn


def gumbel_sigmoid(logits: torch.Tensor, tau: float = 1, hard: bool = False):
    """Draw a relaxed Bernoulli (Gumbel-Sigmoid) sample, optionally binarized.

    Args:
        logits (torch.Tensor): Unnormalized log-odds; noise of the same shape
            and dtype is drawn via ``torch.rand_like``.
        tau (float): Temperature that divides the perturbed logits before the
            sigmoid.
        hard (bool): If True, return a 0/1 tensor whose gradient is taken from
            the soft sample (straight-through estimator).

    Returns:
        torch.Tensor: Tensor shaped like ``logits``. Values lie in [0, 1] when
        ``hard`` is False and are exactly 0 or 1 when ``hard`` is True.

    References:
        - https://github.com/yandexdataschool/gumbel_dpg/blob/master/gumbel.py
        - https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax
    Note:
        The difference of two independent Gumbel(0, 1) variables follows
        Logistic(0, 1), so a single logistic draw replaces two Gumbel draws.
    """
    # Inverse-CDF sampling: log(u / (1 - u)) with u ~ U[0, 1) is Logistic(0, 1).
    uniform = torch.rand_like(logits)
    noise = torch.log(uniform / (1. - uniform))

    # Perturb the logits, apply the temperature, and squash to (0, 1).
    y_soft = torch.sigmoid((logits + noise) / tau)

    if not hard:
        # Reparametrization trick: the soft sample is differentiable as is.
        return y_soft

    # Straight-through: the forward value is the thresholded sample, while
    # the backward pass only sees y_soft. The out-of-place comparison is used
    # on purpose; an in-place gt_() would overwrite y_soft itself.
    y_hard = (y_soft > 0.5).to(y_soft.dtype)
    return y_hard - y_soft.detach() + y_soft


class Sim2Mask(nn.Module):
    def __init__(self, init_w: float = 1.0, init_b: float = 0.0, gumbel_tau: float = 1.0, learnable: bool = True):
        """
        Sim2Mask module that turns similarity scores into binary masks.

        The input is mapped to logits by the affine transform ``x * w + b``.

        Args:
            init_w (float): Initial value for the scale ``w``.
            init_b (float): Initial value for the offset ``b``.
            gumbel_tau (float): Temperature passed to ``gumbel_sigmoid``.
            learnable (bool): If True, ``w`` and ``b`` are registered as scalar
                ``nn.Parameter`` objects; otherwise the raw initial values are
                stored as plain attributes.

        Reference:
            "Learning to Generate Text-grounded Mask for Open-world Semantic Segmentation from Only Image-Text Pairs" CVPR 2023
            - https://github.com/kakaobrain/tcl
            - https://arxiv.org/abs/2212.00785
        """
        super().__init__()
        # Keep the constructor arguments around; extra_repr() reports them.
        self.init_w, self.init_b = init_w, init_b
        self.gumbel_tau = gumbel_tau
        self.learnable = learnable

        # init_w and init_b must be given together or omitted together.
        assert (init_w is None) == (init_b is None)

        if not learnable:
            self.w, self.b = init_w, init_b
            return

        # 0-dim tensors, i.e. one scalar scale and one scalar offset.
        self.w = nn.Parameter(torch.full((), float(init_w)))
        self.b = nn.Parameter(torch.full((), float(init_b)))

    def forward(self, x, deterministic=False):
        """
        Compute a hard and a soft mask from ``x``.

        Args:
            x: Input scores; transformed to logits as ``x * w + b``.
            deterministic (bool): If True, the hard mask is the soft mask
                thresholded at 0.5. Otherwise it is sampled with
                ``gumbel_sigmoid(..., hard=True)``.

        Returns:
            tuple: ``(hard_mask, soft_mask)`` where ``soft_mask`` is the
            sigmoid of the logits.
        """
        logits = x * self.w + self.b
        soft_mask = logits.sigmoid()

        if deterministic:
            hard_mask = (soft_mask > 0.5).to(logits.dtype)
        else:
            hard_mask = gumbel_sigmoid(logits, tau=self.gumbel_tau, hard=True)

        return hard_mask, soft_mask

    def extra_repr(self):
        """Return the constructor settings for the module's printed repr."""
        return f'init_w={self.init_w}, init_b={self.init_b}, learnable={self.learnable}, gumbel_tau={self.gumbel_tau}'


def norm_img_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """
    Normalize image tensor to the range [0, 1].

    The minimum and maximum are taken over dimensions 2 and 3 (kept for
    broadcasting), so every slice along the leading dimensions is rescaled
    independently.

    Args:
        tensor (torch.Tensor): Input image tensor; it must have dimensions 2
            and 3 to reduce over.

    Returns:
        torch.Tensor: Normalized image tensor with the same shape as the
        input. A slice whose values are all equal maps to 0.5.
    """
    eps = 1e-7
    vmin = tensor.amin((2, 3), keepdim=True)
    vmax = tensor.amax((2, 3), keepdim=True)

    # Algebraically this equals (t - (vmin - eps)) / ((vmax + eps) - (vmin - eps)),
    # but the padding is added to the differences instead of to vmin / vmax.
    # Added to the extrema directly, eps is rounded away once their magnitude
    # exceeds a few units in float32, which turns a constant slice into
    # 0 / 0 = NaN.
    numerator = (tensor - vmin) + eps
    denominator = (vmax - vmin) + 2 * eps
    return numerator / denominator


class ImageMasker(Sim2Mask):
    def forward(self, x: torch.Tensor, infer: bool = False) -> torch.Tensor:
        """
        Produce an image-level mask from ``x``.

        Args:
            x (torch.Tensor): Input tensor.
            infer (bool): Request the inference-only path. It is honored only
                when the module is not in training mode.

        Returns:
            torch.Tensor: On the inference-only path, ``sigmoid(x + b / w)``.
            Otherwise the hard mask (first element) returned by
            ``Sim2Mask.forward`` called with ``deterministic=False``.

        Reference:
            "Can CLIP Help Sound Source Localization?" WACV 2024
            - https://arxiv.org/abs/2311.04066
        """
        inference_only = infer and not self.training
        if inference_only:
            # Soft mask: the offset is rescaled by w and added to the raw input.
            return torch.sigmoid(x + self.b / self.w)

        # Training, or infer not requested: keep only the sampled hard mask.
        hard_mask, _ = super().forward(x, False)
        return hard_mask


class FeatureMasker(nn.Module):
    def __init__(self, thr: float = 0.5, tau: float = 0.07):
        """
        Masker module that produces soft feature-level masks.

        Args:
            thr (float): Threshold subtracted from the normalized input.
            tau (float): Temperature dividing the thresholded value before
                the sigmoid.

        Reference:
            "Can CLIP Help Sound Source Localization?" WACV 2024
            - https://arxiv.org/abs/2311.04066
        """
        super().__init__()
        self.thr, self.tau = thr, tau

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Generate a feature-level mask.

        The input is rescaled with ``norm_img_tensor``, shifted by ``thr``,
        divided by ``tau`` and passed through a sigmoid.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Generated mask.
        """
        normalized = norm_img_tensor(x)
        scaled = (normalized - self.thr) / self.tau
        return torch.sigmoid(scaled)