import math

import torch
import transformers
from transformers import LogitsWarper, is_torch_xpu_available
from transformers.generation.logits_process import (
    LogitNormalization,
    LogitsProcessor,
    LogitsProcessorList,
    TemperatureLogitsWarper
)

from modules import shared

global_scores = None


class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
    def __init__(self, temperature: float, dynamic_temperature: bool, dynatemp_low: float, dynatemp_high: float, dynatemp_exponent: float):
        if not isinstance(temperature, float) or not (temperature > 0):
            except_msg = (
                f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
                "scores will be invalid."
            )
            if isinstance(temperature, float) and temperature == 0.0:
                except_msg += " If you're looking for greedy decoding strategies, set `do_sample=False`."

            raise ValueError(except_msg)

        self.temperature = temperature
        self.dynamic_temperature = dynamic_temperature
        self.dynatemp_low = dynatemp_low
        self.dynatemp_high = dynatemp_high
        self.dynatemp_exponent = dynatemp_exponent

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:

        # Regular temperature
        if not self.dynamic_temperature:
            scores = scores / self.temperature
            return scores

        # Dynamic temperature
        else:
            min_temp = self.dynatemp_low
            max_temp = self.dynatemp_high
            exponent_val = self.dynatemp_exponent

            # Convert logits to probabilities
            probs = torch.softmax(scores, dim=-1)

            # Calculate entropy of the softmax probabilities
            entropy = -1.0 * torch.where(probs > 0, probs * torch.log(probs), torch.zeros_like(probs)).sum()

            # Guard against future possible division by zero
            entropy = max(entropy, torch.tensor(1e-10))  # Ensures entropy is slightly greater than 0

            # Any logits which are not -Infinity will be considered for calculating max entropy.
            num_valid_tokens = torch.sum(scores > -float('inf')).item()

            # Now, calculate the max entropy by using only the valid tokens' count
            max_entropy = math.log(num_valid_tokens)

            # Guard against future possible division by zero
            max_entropy = max_entropy if max_entropy > 0.0 else 1e-10

            # Normalize the entropy
            normalized_entropy = entropy / max_entropy

            # Map the normalized entropy to the desired temperature range using the power function
            dyn_temp = min_temp + (max_temp - min_temp) * (normalized_entropy.pow(exponent_val))

            # Apply the dynamically calculated temperature scaling
            scores = scores / dyn_temp

            # print("----------------------\nTemperature from generation_config:", self.temperature)
            # print("min_temp:", min_temp)
            # print("max_temp:", max_temp)
            # print("Entropy:", entropy.item())
            # print("Max Possible Entropy considering valid tokens only:", max_entropy)
            # print("Normalized Entropy:", normalized_entropy.item())
            # print("Dynamic Temperature (dyn_temp):", dyn_temp.item())
            # print("----------------------")

            # max_prob_token_id = torch.argmax(scores, dim=-1)  # Get the token ID with the highest probability
            # max_prob_token = shared.tokenizer.convert_ids_to_tokens(int(max_prob_token_id))  # Convert ID to token
            # print("--- T=", float(dyn_temp), "token=", max_prob_token, "min=", min_temp, "max=", max_temp, "exponent=", exponent_val)

            return scores


class MinPLogitsWarper(LogitsWarper):
    def __init__(self, min_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        if min_p < 0 or min_p > 1.0:
            raise ValueError(f"`min_p` has to be a float >= 0 and <= 1, but is {min_p}")
        self.min_p = min_p
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Convert logits to probabilities
        probs = torch.softmax(scores, dim=-1)
        # Get the probability of the top token for each sequence in the batch
        top_probs, _ = probs.max(dim=-1, keepdim=True)
        # Calculate the actual min_p threshold by scaling min_p with the top token's probability
        scaled_min_p = self.min_p * top_probs
        # Create a mask for tokens that have a probability less than the scaled min_p
        tokens_to_remove = probs < scaled_min_p

        sorted_indices = torch.argsort(scores, descending=True, dim=-1)
        sorted_indices_to_remove = torch.gather(tokens_to_remove, dim=-1, index=sorted_indices)

        if self.min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep
            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = False

        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores


class TailFreeLogitsWarper(LogitsWarper):
    def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        tfs = float(tfs)
        if tfs < 0 or tfs > 1.0:
            raise ValueError(f"`tfs` has to be a float >= 0 and <= 1, but is {tfs}")
        self.tfs = tfs
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        sorted_logits, sorted_indices = torch.sort(scores, descending=True)
        probs = sorted_logits.softmax(dim=-1)

        # Compute second derivative normalized CDF
        d2 = probs.diff().diff().abs()
        normalized_d2 = d2 / d2.sum(dim=-1, keepdim=True)
        normalized_d2_cdf = normalized_d2.cumsum(dim=-1)

        # Remove tokens with CDF value above the threshold (token with 0 are kept)
        sorted_indices_to_remove = normalized_d2_cdf > self.tfs

        # Centre the distribution around the cutoff as in the original implementation of the algorithm
        sorted_indices_to_remove = torch.cat(
            (
                torch.zeros(scores.shape[0], 1, dtype=torch.bool, device=scores.device),
                sorted_indices_to_remove,
                torch.ones(scores.shape[0], 1, dtype=torch.bool, device=scores.device),
            ),
            dim=-1,
        )

        if self.min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep
            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0

        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores


class TopALogitsWarper(LogitsWarper):
    def __init__(self, top_a: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        top_a = float(top_a)
        if top_a < 0 or top_a > 1.0:
            raise ValueError(f"`top_a` has to be a float >= 0 and <= 1, but is {top_a}")
        self.top_a = top_a
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        sorted_logits, sorted_indices = torch.sort(scores, descending=True)
        probs = sorted_logits.softmax(dim=-1)

        # Remove tokens with probability less than top_a*(max(probs))^2 (token with 0 are kept)
        probs_max = probs[..., 0, None]
        sorted_indices_to_remove = probs < probs_max * probs_max * self.top_a

        if self.min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep
            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0

        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores


class MirostatLogitsWarper(LogitsWarper):
    def __init__(self, mirostat_mode: int, mirostat_tau: float, mirostat_eta: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        if mirostat_mode not in [2]:
            raise ValueError(f"`mirostat` has to be a an integer 2, but is {mirostat_mode}")
        self.mirostat_mode = mirostat_mode
        self.mirostat_eta = mirostat_eta
        self.mirostat_tau = mirostat_tau
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep
        self.mu = 2 * self.mirostat_tau
        self.e = 0

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        logits = scores[0]
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        prob_original = torch.softmax(sorted_logits, dim=-1).tolist()  # candidates

        # Truncate the words with surprise values greater than mu
        for i, candidate in enumerate(prob_original):
            if candidate > 0 and -math.log2(candidate) > self.mu:
                if (i == 0):
                    sorted_logits = sorted_logits[:1]
                else:
                    sorted_logits = sorted_logits[:i]
                break

        # Normalize the probabilities of the remaining words
        if is_torch_xpu_available():
            prob_topk = torch.softmax(sorted_logits, dim=0).to("xpu")
            prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True).to("xpu")
        else:
            prob_topk = torch.softmax(sorted_logits, dim=0).to('cuda')
            prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True).to('cuda')

        observed_surprise = -math.log2(prob_topk[prev_i])
        self.e = observed_surprise - self.mirostat_tau

        # Update mu using the learning rate and error
        self.mu -= self.mirostat_eta * self.e

        sorted_indices_to_remove = torch.ones_like(scores[0], dtype=torch.bool)
        sorted_indices_to_remove[prev_i] = False

        indices_to_remove = sorted_indices_to_remove.unsqueeze(0).scatter(1, sorted_indices.unsqueeze(0), sorted_indices_to_remove.unsqueeze(0))
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores


class SpyLogitsWarper(LogitsWarper):
    def __init__(self):
        pass

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        global global_scores
        global_scores = scores
        return scores


class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
    '''
    Copied from the transformers library
    '''

    def __init__(self, penalty: float, presence_penalty: float, frequency_penalty: float, _range: int):
        if not (penalty > 0):
            raise ValueError(f"`penalty` has to be strictly positive, but is {penalty}")

        self.penalty = penalty
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self._range = _range

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        input_ids = input_ids[:, -self._range:]

        # We loop here because torch.unique() needs to process each row separately in the
        # case that batch_size > 1.
        for input_ids_row, scores_row in zip(input_ids, scores):
            unique_ids, counts = torch.unique(input_ids_row, return_counts=True)
            score = torch.gather(scores_row, 0, unique_ids)

            # multiplicative repetition penalty
            # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
            score = torch.where(score < 0, score * self.penalty, score / self.penalty)
            scores_row.scatter_(0, unique_ids, score)

            # presence_penalty and frequency_penalty
            raw_presence_penalty = (counts > 0).to(scores.dtype)
            raw_frequency_penalty = counts.to(scores.dtype)
            additive_penalty = raw_presence_penalty * self.presence_penalty + raw_frequency_penalty * self.frequency_penalty
            scores_row.scatter_add_(0, unique_ids, -additive_penalty)

        return scores


def get_logits_warper_patch(self, generation_config):
    # Make sure that temperature is float and not int
    if isinstance(generation_config.temperature, int):
        generation_config.temperature = float(generation_config.temperature)

    temperature = generation_config.temperature
    if generation_config.dynamic_temperature:
        # Make sure TemperatureLogitsWarper will be created by temporarily
        # setting temperature to a value != 1.
        generation_config.temperature = 1.1

    warpers = self._get_logits_warper_old(generation_config)
    for i in range(len(warpers)):
        if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
            warpers[i] = TemperatureLogitsWarperWithDynatemp(
                temperature,
                generation_config.dynamic_temperature,
                generation_config.dynatemp_low,
                generation_config.dynatemp_high,
                generation_config.dynatemp_exponent
            )

    warpers_to_add = LogitsProcessorList()
    min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1

    if generation_config.mirostat_mode is not None and generation_config.mirostat_mode == 2:
        warpers_to_add.append(MirostatLogitsWarper(mirostat_mode=generation_config.mirostat_mode, mirostat_eta=generation_config.mirostat_eta, mirostat_tau=generation_config.mirostat_tau, min_tokens_to_keep=min_tokens_to_keep))
        # We need to disable samplers other than temperature
        for warper in warpers:
            if not isinstance(warper, TemperatureLogitsWarper):
                warpers.remove(warper)
    else:
        if generation_config.tfs is not None and 0.0 <= generation_config.tfs < 1.0:
            warpers_to_add.append(TailFreeLogitsWarper(tfs=generation_config.tfs, min_tokens_to_keep=min_tokens_to_keep))
        if generation_config.top_a is not None and 0.0 < generation_config.top_a <= 1.0:
            warpers_to_add.append(TopALogitsWarper(top_a=generation_config.top_a, min_tokens_to_keep=min_tokens_to_keep))
        if generation_config.min_p is not None and 0.0 < generation_config.min_p <= 1.0:
            warpers_to_add.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))

    if len(warpers) > 0 and isinstance(warpers[-1], LogitNormalization):
        normalize = warpers.pop(-1)
    else:
        normalize = None

    warpers += warpers_to_add
    if generation_config.temperature_last:
        temperature_idx = None
        for i in range(len(warpers)):
            if warpers[i].__class__.__name__ in ['TemperatureLogitsWarper', 'TemperatureLogitsWarperWithDynatemp']:
                temperature_idx = i
                break

        if temperature_idx is not None:
            warpers.append(warpers.pop(temperature_idx))

    if normalize is not None:
        warpers.append(normalize)

    warpers.append(SpyLogitsWarper())
    warpers = LogitsProcessorList(warpers)
    # for i in range(len(warpers)):
    #     print(warpers[i].__class__.__name__)
    return warpers


def get_logits_processor_patch(self, **kwargs):
    repetition_penalty = kwargs['generation_config'].repetition_penalty
    presence_penalty = kwargs['generation_config'].presence_penalty
    frequency_penalty = kwargs['generation_config'].frequency_penalty
    repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range
    do_rep_pen_hijack = (repetition_penalty > 1) or (presence_penalty != 0) or (frequency_penalty != 0)
    if do_rep_pen_hijack:
        # Make sure that a RepetitionPenaltyLogitsProcessor will be created
        kwargs['generation_config'].repetition_penalty = 1.1  # must set to some value > 1

    result = self._get_logits_processor_old(**kwargs)

    if do_rep_pen_hijack:
        for i in range(len(result)):
            if result[i].__class__.__name__ == 'RepetitionPenaltyLogitsProcessor':
                result[i] = RepetitionPenaltyLogitsProcessorWithRange(repetition_penalty, presence_penalty, frequency_penalty, repetition_penalty_range)

    return result


def generation_config_init_patch(self, **kwargs):
    self.__init___old(**kwargs)
    self.min_p = kwargs.pop("min_p", 0.0)
    self.dynamic_temperature = kwargs.pop("dynamic_temperature", False)
    self.dynatemp_low = kwargs.pop("dynatemp_low", 1)
    self.dynatemp_high = kwargs.pop("dynatemp_high", 1)
    self.dynatemp_exponent = kwargs.pop("dynatemp_exponent", 1)
    self.tfs = kwargs.pop("tfs", 1.0)
    self.top_a = kwargs.pop("top_a", 0.0)
    self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
    self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1)
    self.mirostat_tau = kwargs.pop("mirostat_tau", 5)
    self.repetition_penalty_range = kwargs.pop("repetition_penalty_range", 0)
    self.presence_penalty = kwargs.pop("presence_penalty", 0)
    self.frequency_penalty = kwargs.pop("frequency_penalty", 0)
    self.temperature_last = kwargs.pop("temperature_last", False)


def hijack_samplers():
    transformers.GenerationMixin._get_logits_warper_old = transformers.GenerationMixin._get_logits_warper
    transformers.GenerationMixin._get_logits_warper = get_logits_warper_patch

    transformers.GenerationMixin._get_logits_processor_old = transformers.GenerationMixin._get_logits_processor
    transformers.GenerationMixin._get_logits_processor = get_logits_processor_patch

    transformers.GenerationConfig.__init___old = transformers.GenerationConfig.__init__
    transformers.GenerationConfig.__init__ = generation_config_init_patch