"""
Baseline: based on most-common answer
"""

import pandas as pd
import numpy as np
from tqdm import tqdm
from .metrics import mapk, rank_biased_overlap
from .plots import plot_ranks
import logging
from typing import List, Callable, Optional
from rouge_score import rouge_scorer as rs
from collections import Counter
import random

logger = logging.getLogger(__name__)
tol = 0.001


class MCARank:
    """
    Baseline method: based on most common answer
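
    Minimal usage sketch (illustrative only: the model names, `benchmark_df`, and
    the dummy evaluator below are placeholders; `benchmark_df` is assumed to hold
    one column of outputs per model):

        ranker = MCARank(
            MODELS=['model_a', 'model_b'],
            evaluator=lambda *args, **kwargs: None,  # placeholder evaluator
            true_ranking=['model_a', 'model_b'],
        )
        ranking = ranker.fit(benchmark_df, measure='equality')  # best to worst
        score = ranker.measure(metric='rbo')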
    """

    def __init__(
        self,
        MODELS: List,
        evaluator: Callable,
        true_ranking: Optional[List] = None,
        show_progress: bool = False,
    ):
        self.MODELS = MODELS
        self.N = len(MODELS)
        self.evaluate = evaluator
        self.true_ranking = true_ranking
        self.show_progress = show_progress


    def fit(self, df: pd.DataFrame, measure: str = 'equality', p: float = 0.0):
        """
        df: DataFrame where each row is a benchmark instance and there is
        one column with the output of each model.

        measure: how the most common answer is determined and compared against
        ('equality', 'noisy_equality', or 'rouge').
        p: noise level to inject (only used for 'noisy_equality').
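
        Illustrative df layout (placeholder model names and answers):

                  model_a  model_b  model_c
            q1    "Paris"  "Paris"  "Rome"
            q2    "blue"   "red"    "blue"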
        """

        assert set(self.MODELS) == set(df.columns), "Benchmark data models inconsistent with models to be ranked."

        if measure == 'equality':

            # Select the most common answer per question
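            # (if several answers tie for most common, the first mode column is kept)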
            mca = df.mode(axis=1).iloc[:, 0]

            # Count all the times each model answered the most common one
            wins = df.eq(mca, axis=0).astype(int)

            self.ranking = wins.sum().sort_values(ascending=False).index.to_list()
        
        elif measure == 'noisy_equality':

            # Most common answer
            mca = df.mode(axis=1).iloc[:, 0]

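            # Noisy equality: the per-question agreement indicator (answer equals
            # the most common answer) is flipped with probability p.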
            def perturb(agrees):
                return not agrees if random.random() <= p else agrees

            def __noisy_equality(x, mca):
                return (x == mca).apply(perturb)

            # axis=0: apply __noisy_equality to each column (one model's answers)
            wins = df.apply(__noisy_equality, axis=0, args=(mca,))

            self.ranking = wins.sum().sort_values(ascending=False).index.to_list()

        elif measure == 'rouge':

            MODELS = df.columns.to_list()
            SIZE = 256
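
            # The pooled 'most common answer' keeps only the SIZE most frequent
            # bigrams. Note: rs._create_ngrams / rs._score_ngrams are private
            # helpers of rouge_score.rouge_scorer, so this relies on their
            # current interface.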

            def __mca(x):
                """ Most Commmon Answer, as the top k bigrams across all outputs """

                cs = [rs._create_ngrams(x[m], n=2) for m in MODELS]
                c = sum(cs, Counter())
                return Counter(dict(c.most_common(SIZE)))

            def __score_mca(x):
                """ Rouge score computed relative to most-common-answer """

                res = {}
                for m in MODELS:
                    p_n = rs._create_ngrams(x[m], n=2)
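                    # fmeasure: F1 of the bigram overlap between this model's
                    # output and the pooled most-common-answer bigrams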
                    res[m] = rs._score_ngrams(x.mca, p_n).fmeasure
                return pd.Series(res)
                
            # Work on a copy so the caller's DataFrame is not mutated
            scored = df.assign(mca=df.apply(__mca, axis=1))

            # Winning model based on best ROUGE score for each question
            win_rates = scored.apply(__score_mca, axis=1).idxmax(axis=1).value_counts()
            win_rate_rank = win_rates.index.tolist()

            # Append models with no wins at the bottom (order among them is arbitrary)
            no_wins = list(set(MODELS) - set(win_rate_rank))

            self.ranking = win_rate_rank + no_wins

        else:
            raise ValueError(f"Measure {measure} not understood.")


        logger.info(f"Estimated ranks (best to worst): {self.ranking}")
        logger.info(f"True ranking: {self.true_ranking}")
        logger.info(f"RBO measure: {self.measure()}")
        return self.ranking # Best to worst


    def measure(self, metric='rbo', k=5, p=0.95) -> float:
        """
        Report a metric comparing the estimated ranking against the true ranking.
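
        metric: 'rbo' (rank-biased overlap) or 'mapk' (mean average precision at k).
        k: cutoff used for 'mapk'.
        p: persistence parameter used for 'rbo'.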
        """
        if metric not in ['rbo', 'mapk']:
            raise ValueError(f"Metric {metric} not supported (use 'rbo'/'mapk').")
        if not hasattr(self, 'ranking'):
            raise ValueError("Ranking not estimated. Run 'fit' first.")
        if self.true_ranking is None:
            raise ValueError("True ranking not available for metric calculation.")

        if metric == 'mapk':
            if k > len(self.true_ranking):
                logger.warning(f"MAP@k is limited to k={len(self.true_ranking)}, not k={k}.")
            actual = [self.true_ranking[:k]]
            pred = [self.ranking[:k]]
            return mapk(actual, pred, k=k)

        return rank_biased_overlap(self.true_ranking, self.ranking, p=p)


    def plot(self, caselabel="output"):
        if hasattr(self, 'ranking') and (self.true_ranking is not None):
            plot_ranks(self.true_ranking, self.ranking, "actual", "estimated", caselabel)