import datetime
import html
import logging

import datasets
import gradio as gr
import pandas as pd

from fetch_utils import (check_dataset_and_get_config,
                         check_dataset_and_get_split)
import leaderboard
logger = logging.getLogger(__name__)
# Time of the last leaderboard reload (naive local time). It starts at the
# epoch so that any elapsed-time check treats the records as stale until
# get_demo() records a real load; get_demo() and its reload callback reassign
# it through ``global update_time``.
update_time = datetime.datetime.fromtimestamp(0)

def get_records_from_dataset_repo(dataset_id):
    """Load the first split of the first config of a dataset repo as a DataFrame.

    Args:
        dataset_id: Hugging Face dataset repository id to load.

    Returns:
        A pandas DataFrame with the loaded records. An empty DataFrame is
        returned when no config or split is reported for the dataset, or when
        loading/conversion fails. Exceptions raised by the fetch_utils lookup
        helpers themselves still propagate.
    """
    dataset_config = check_dataset_and_get_config(dataset_id)
    logger.info(f"Dataset {dataset_id} has configs {dataset_config}")
    if not dataset_config:
        # Indexing [0] would raise here; honour the "empty on failure" contract.
        logger.warning(f"Dataset {dataset_id} has no configs; returning no records")
        return pd.DataFrame()
    config = dataset_config[0]

    dataset_split = check_dataset_and_get_split(dataset_id, config)
    logger.info(f"Dataset {dataset_id} has splits {dataset_split}")
    if not dataset_split:
        logger.warning(
            f"Dataset {dataset_id} with config {config} has no splits; returning no records"
        )
        return pd.DataFrame()
    split = dataset_split[0]

    try:
        ds = datasets.load_dataset(dataset_id, config, split=split)
        return ds.to_pandas()
    except Exception as e:
        # Best-effort load: callers get an empty DataFrame instead of an error.
        logger.warning(
            f"Failed to load dataset {dataset_id} with config {config} and split {split}: {e}"
        )
        return pd.DataFrame()

    
def get_model_ids(ds):
    """Return the choices for the model dropdown.

    Args:
        ds: DataFrame of leaderboard records with a ``model_id`` column.

    Returns:
        ``["Any", *ids]`` where ``ids`` are the distinct ``model_id`` values.
        The ids are sorted by their string form so the dropdown order is stable
        across restarts; ``key=str`` keeps sorting safe if a NaN/None slips in.
    """
    model_ids = sorted(set(ds["model_id"].tolist()), key=str)
    # Same logger object as the module-level ``logger`` (not the root logger).
    logging.getLogger(__name__).info(
        f"Found {len(model_ids)} distinct model ids: {model_ids}"
    )
    return ["Any", *model_ids]


def get_dataset_ids(ds):
    """Return the choices for the dataset dropdown.

    Args:
        ds: DataFrame of leaderboard records with a ``dataset_id`` column.

    Returns:
        ``["Any", *ids]`` where ``ids`` are the distinct ``dataset_id`` values.
        The ids are sorted by their string form for a stable dropdown order.
    """
    # Named ``dataset_ids`` rather than ``datasets`` to avoid shadowing the
    # imported ``datasets`` module.
    dataset_ids = sorted(set(ds["dataset_id"].tolist()), key=str)
    # Same logger object as the module-level ``logger`` (not the root logger).
    logging.getLogger(__name__).info(
        f"Found {len(dataset_ids)} distinct dataset ids: {dataset_ids}"
    )
    return ["Any", *dataset_ids]


def get_types(ds):
    """Map each column dtype of *ds* to a Gradio DataFrame ``datatype`` string.

    Rules, applied in order:
      * object and pandas string dtypes -> ``"markdown"`` (so the HTML links
        built by get_display_df render);
      * boolean dtypes -> ``"bool"``;
      * any other numeric dtype (signed/unsigned ints, floats, nullable
        ``Int64``/``Float64``) -> ``"number"``;
      * anything else keeps its pandas dtype name.

    The mapping works by dtype kind rather than substring replacement; the
    old approach turned ``uint64`` into ``"unumber"``.

    Args:
        ds: DataFrame whose columns are being displayed.

    Returns:
        A list with one datatype string per column, in column order.
    """
    gradio_types = []
    for dtype in ds.dtypes:
        if str(dtype) == "object" or isinstance(dtype, pd.StringDtype):
            gradio_types.append("markdown")
        elif pd.api.types.is_bool_dtype(dtype):
            gradio_types.append("bool")
        elif pd.api.types.is_numeric_dtype(dtype):
            gradio_types.append("number")
        else:
            gradio_types.append(str(dtype))
    return gradio_types


def get_display_df(df):
    """Return a copy of *df* with link-like columns rendered as HTML anchors.

    ``model_id`` links to its Hugging Face model page, ``dataset_id`` links to
    its Hugging Face dataset page, and ``report_link`` links to its own value.
    Any of these columns that is absent is skipped. Other columns, and the
    input DataFrame itself, are left unchanged.

    Cell values are HTML-escaped, both in ``href`` and in the link text, so a
    quote, ``<`` or ``&`` coming from the records cannot break or inject
    markup.

    Args:
        df: DataFrame of leaderboard records to display.

    Returns:
        A new DataFrame ready for a Gradio table with ``"markdown"`` columns.
    """
    # Column -> URL template; "{}" is replaced by the cell value.
    link_templates = {
        "model_id": "https://huggingface.co/{}",
        "dataset_id": "https://huggingface.co/datasets/{}",
        "report_link": "{}",
    }

    def to_anchor(value, url_template):
        href = html.escape(url_template.format(value))
        text = html.escape(format(value))
        return f'<a href="{href}" target="_blank" style="color:blue">🔗{text}</a>'

    display_df = df.copy()
    for column, url_template in link_templates.items():
        if column in display_df.columns:
            display_df[column] = display_df[column].apply(to_anchor, args=(url_template,))
    return display_df

def get_demo(leaderboard_tab):
    """Build the leaderboard UI and wire its events.

    Presumably called inside the caller's Gradio Blocks/Tab context, because
    the ``gr.Row``/``gr.on`` calls below need one. The function:

      * loads the records of ``leaderboard.LEADERBOARD`` into
        ``leaderboard.records`` and stamps the module-level ``update_time``;
      * lays out the info/issue column pickers, the task/model/dataset
        filters and the results table;
      * re-filters the table whenever a control changes;
      * when *leaderboard_tab* is selected, reloads the records at most once
        every 10 minutes.

    Args:
        leaderboard_tab: Gradio component whose ``select`` event triggers the
            throttled reload.
    """
    global update_time
    update_time = datetime.datetime.now()
    logger.info("Loading leaderboard records")
    leaderboard.records = get_records_from_dataset_repo(leaderboard.LEADERBOARD)
    records = leaderboard.records

    model_ids = get_model_ids(records)
    dataset_ids = get_dataset_ids(records)

    # NOTE(review): positional slicing assumes a fixed column layout in the
    # leaderboard dataset (issue columns first, info columns from index 15).
    column_names = records.columns.tolist()
    issue_columns = column_names[:11]
    info_columns = column_names[15:]
    default_columns = ["model_id", "dataset_id", "total_issues", "report_link"]
    default_task = "text_classification"
    # Apply the same task filter that filter_table applies, so the first
    # rendered table matches the initially selected task.
    default_df = records[records["task"] == default_task][default_columns]
    types = get_types(default_df)
    display_df = get_display_df(default_df)  # the styled dataframe to display

    with gr.Row():
        with gr.Column():
            # NOTE(review): default_columns is the initial selection but the
            # choices are column_names[15:]; confirm they overlap as intended.
            info_columns_select = gr.CheckboxGroup(
                label="Info Columns",
                choices=info_columns,
                value=default_columns,
                interactive=True,
            )
        with gr.Column():
            issue_columns_select = gr.CheckboxGroup(
                label="Issue Columns",
                choices=issue_columns,
                value=[],
                interactive=True,
            )

    with gr.Row():
        task_select = gr.Dropdown(
            label="Task",
            choices=[default_task],
            value=default_task,
            interactive=True,
        )
        model_select = gr.Dropdown(
            label="Model id", choices=model_ids, value=model_ids[0], interactive=True
        )
        dataset_select = gr.Dropdown(
            label="Dataset id",
            choices=dataset_ids,
            value=dataset_ids[0],
            interactive=True,
        )

    with gr.Row():
        leaderboard_df = gr.DataFrame(display_df, datatype=types, interactive=False)

    def update_leaderboard_records(model_id, dataset_id, issue_columns, info_columns, task):
        """Reload the records (throttled to once per 10 minutes), then re-filter.

        Returns a no-op ``gr.update()`` while the throttle window is open.
        """
        global update_time
        if datetime.datetime.now() - update_time < datetime.timedelta(minutes=10):
            return gr.update()
        update_time = datetime.datetime.now()
        logger.info("Updating leaderboard records")
        new_records = get_records_from_dataset_repo(leaderboard.LEADERBOARD)
        if new_records.columns.empty:
            # get_records_from_dataset_repo signals failure with a column-less
            # DataFrame. Keep serving the last good records; filter_table would
            # otherwise fail on the missing "task" column.
            logger.warning("Leaderboard reload returned no data; keeping previous records")
        else:
            leaderboard.records = new_records
        return filter_table(model_id, dataset_id, issue_columns, info_columns, task)

    leaderboard_tab.select(
        fn=update_leaderboard_records,
        inputs=[model_select, dataset_select, issue_columns_select, info_columns_select, task_select],
        outputs=[leaderboard_df])

    @gr.on(
        triggers=[
            model_select.change,
            dataset_select.change,
            issue_columns_select.change,
            info_columns_select.change,
            task_select.change,
        ],
        inputs=[model_select, dataset_select, issue_columns_select, info_columns_select, task_select],
        outputs=[leaderboard_df],
    )
    def filter_table(model_id, dataset_id, issue_columns, info_columns, task):
        """Return a table update restricted to the selected task/model/dataset and columns."""
        logger.info("Filtering leaderboard records")
        records = leaderboard.records
        # filter the table based on task
        df = records[records["task"] == task]
        # "Any" (or an empty selection) disables the model/dataset filter.
        if model_id and model_id != "Any":
            df = df[df["model_id"] == model_id]
        if dataset_id and dataset_id != "Any":
            df = df[df["dataset_id"] == dataset_id]

        # Info columns keep the user's selection order; issue columns are shown
        # sorted. sorted() avoids mutating the list Gradio passed in.
        df = df[info_columns + sorted(issue_columns)]
        types = get_types(df)
        display_df = get_display_df(df)
        return gr.update(value=display_df, datatype=types, interactive=False)