import json
import logging

import datasets
import huggingface_hub
import pandas as pd
from transformers import pipeline
import requests
import os

# Module logger; named after the module so it can be configured hierarchically.
logger = logging.getLogger(__name__)
# Name of the environment variable read for the Hugging Face token that is
# sent to the Inference API (per its name, a token with write access).
HF_WRITE_TOKEN = "HF_WRITE_TOKEN"
# Hub endpoint queried with a Bearer token; HTTP 200 means the token is valid.
AUTH_CHECK_URL = "https://huggingface.co/api/whoami-v2"

class HuggingFaceInferenceAPIResponse:
    """Human-readable message about an Inference API call.

    Returned by ``get_example_prediction`` in place of a prediction dict when
    the API reports an error or an estimated waiting time.
    """

    def __init__(self, message):
        # text meant for the user, e.g. "Inference Error: ..."
        self.message = message

    def __repr__(self):
        return f"{type(self).__name__}(message={self.message!r})"


def get_labels_and_features_from_dataset(ds):
    """Split a dataset's columns into label names and input feature names.

    Columns whose name starts with ``"label"`` are label columns; only the
    first one is used to read the label names.

    Returns:
        ``(labels, features)``:

        - no label column: both items are lists of every column name, left
          for the caller to post-process;
        - the label column is a ``datasets.ClassLabel``, or wraps one through
          a ``.feature`` attribute (e.g. a sequence of class labels): its
          ``names`` and the list of non-label column names;
        - ``(None, None)`` if the label column exposes no class names or any
          step fails; a warning is logged in that case.
    """
    try:
        dataset_features = ds.features
        label_keys = [key for key in dataset_features.keys() if key.startswith("label")]
        if not label_keys:
            # no labels found: return everything for post processing
            return list(dataset_features.keys()), list(dataset_features.keys())

        label_column = label_keys[0]
        label_feature = dataset_features[label_column]
        if isinstance(label_feature, datasets.ClassLabel):
            labels = label_feature.names
        else:
            # Wrapper features expose their element type as `.feature`; the
            # labels are only usable if that element type carries names.
            labels = getattr(getattr(label_feature, "feature", None), "names", None)
            if labels is None:
                logging.warning(
                    f"Get Labels/Features Failed for dataset: "
                    f"label column {label_column} has no class names"
                )
                return None, None

        features = [key for key in dataset_features.keys() if not key.startswith("label")]
        return labels, features
    except Exception as e:
        # best effort: callers treat (None, None) as "could not infer"
        logging.warning(
            f"Get Labels/Features Failed for dataset: {e}"
        )
        return None, None

def check_model_task(model_id):
    """Return the Hub ``pipeline_tag`` of *model_id*, or None.

    None is returned both when the model has no pipeline tag and when the
    Hub lookup raises for any reason (e.g. the model does not exist).
    """
    try:
        return huggingface_hub.model_info(model_id).pipeline_tag
    except Exception:
        return None

def get_model_labels(model_id, example_input):
    """Query the Inference API with *example_input* and collect its labels.

    The token is read from the environment variable named by
    ``HF_WRITE_TOKEN`` (empty string if unset).

    Returns:
        Every ``"label"`` value found anywhere in the (possibly nested) JSON
        response, or None when the API returned an ``"error"`` object or a
        payload that is neither a dict nor a list.
    """
    hf_token = os.environ.get(HF_WRITE_TOKEN, default="")
    payload = {"inputs": example_input, "options": {"use_cache": True}}
    response = hf_inference_api(model_id, hf_token, payload)
    if isinstance(response, dict):
        if "error" in response:
            return None
    elif not isinstance(response, list):
        # JSON scalar (string/number/null): not a prediction payload. The
        # previous `"error" in response` raised TypeError on null/numbers.
        return None
    return extract_from_response(response, "label")

def extract_from_response(data, key):
    """Recursively collect every non-None value stored under *key*.

    Dicts contribute their own ``data[key]`` (when present and not None)
    before the results from each of their values; lists contribute the
    results of their elements in order; any other value yields nothing.
    """
    if isinstance(data, dict):
        own_value = data.get(key)
        found = [] if own_value is None else [own_value]
        children = data.values()
    elif isinstance(data, list):
        found = []
        children = data
    else:
        return []
    return found + [
        item for child in children for item in extract_from_response(child, key)
    ]

def hf_inference_api(model_id, hf_token, payload, timeout=None):
    """POST *payload* to the Hugging Face Inference API for *model_id*.

    The base URL comes from the ``HF_INFERENCE_ENDPOINT`` environment
    variable, defaulting to ``https://api-inference.huggingface.co``.

    Args:
        model_id: Hub model id, appended to ``{endpoint}/models/``.
        hf_token: token sent as a ``Bearer`` authorization header.
        payload: JSON-serializable request body.
        timeout: optional ``requests`` timeout in seconds (or a
            ``(connect, read)`` tuple). The default None waits indefinitely.

    Returns:
        The decoded JSON body (any JSON type), also for non-200 responses.
        If the body is not valid JSON, ``{"error": <raw response bytes>}``.

    Raises:
        requests.RequestException: on connection failure or timeout.
    """
    hf_inference_api_endpoint = os.environ.get(
        "HF_INFERENCE_ENDPOINT", default="https://api-inference.huggingface.co"
    )
    url = f"{hf_inference_api_endpoint}/models/{model_id}"
    headers = {"Authorization": f"Bearer {hf_token}"}
    response = requests.post(url, headers=headers, json=payload, timeout=timeout)
    if getattr(response, "status_code", None) != 200:
        # Still decode the body: callers inspect "error" / "estimated_time"
        # keys in the JSON returned alongside non-200 statuses.
        logger.warning("Request to inference API returns %s", response)
    try:
        return response.json()
    except ValueError:
        # JSON decode errors from requests subclass ValueError
        return {"error": response.content}
    
def preload_hf_inference_api(model_id):
    """Send a fixed dummy input to the Inference API for *model_id*.

    The response is discarded; presumably the point is to make the API start
    loading the model before real requests arrive.
    """
    token = os.environ.get(HF_WRITE_TOKEN, default="")
    hf_inference_api(
        model_id,
        token,
        {"inputs": "This is a test", "options": {"use_cache": True}},
    )

def check_model_pipeline(model_id):
    """Build a ``transformers`` pipeline for *model_id* using its Hub task.

    Returns the pipeline, or None if either the Hub lookup or the pipeline
    construction raises.
    """
    try:
        hub_task = huggingface_hub.model_info(model_id).pipeline_tag
        return pipeline(task=hub_task, model=model_id)
    except Exception:
        return None


def text_classificaiton_match_label_case_unsensative(id2label_mapping, label):
    """Find the key of *id2label_mapping* equal to *label*, ignoring case.

    Both sides are compared with ``str.upper()``; the first matching key in
    iteration order wins.

    Returns:
        ``(matching_key, label)``, or ``(None, label)`` when nothing matches.
    """
    matched = next(
        (
            candidate
            for candidate in id2label_mapping.keys()
            if candidate.upper() == label.upper()
        ),
        None,
    )
    return matched, label


def text_classification_map_model_and_dataset_labels(id2label, dataset_features):
    """Match a model's output labels against a dataset's ClassLabel names.

    Every ``datasets.ClassLabel`` feature with exactly as many names as the
    model has labels is examined. Each dataset label is paired with the model
    label equal to it, or, failing that, equal to it ignoring case.

    Args:
        id2label: dict of label id -> model label (e.g. the model config's
            ``id2label``).
        dataset_features: mapping of column name -> ``datasets`` feature.

    Returns:
        ``(id2label_mapping, dataset_labels)``. ``id2label_mapping`` maps each
        model label to its matched dataset label (None when unmatched).
        ``dataset_labels`` is the names list of the last compatible ClassLabel
        feature, or None if there was none.
    """
    id2label_mapping = {model_label: None for model_label in id2label.values()}
    dataset_labels = None
    for feature in dataset_features.values():
        if not isinstance(feature, datasets.ClassLabel):
            continue
        if len(feature.names) != len(id2label_mapping):
            continue

        dataset_labels = feature.names
        # Try to match labels
        for label in feature.names:
            if label in id2label_mapping:
                model_label = label
            else:
                # Try to find a case-insensitive match
                model_label, label = text_classificaiton_match_label_case_unsensative(
                    id2label_mapping, label
                )
            if model_label is not None:
                id2label_mapping[model_label] = label
            else:
                # Diagnostics go to the module logger, not stdout.
                logger.warning(f"Label {label} is not found in model labels")

    return id2label_mapping, dataset_labels


"""
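params:
    column_mapping: str
    JSON-encoded dict. Only its optional "data" entry is checked: a list of
    [user_label, model_label] pairs, e.g.
    '{"data": [["NEGATIVE", "NEGATIVE"], ["POSITIVE", "POSITIVE"]]}'
    (the mapping is valid when the set of user labels, the set of model
    labels and the model's id2label values are all equal).
    Other helpers in this module work with the decoded dict form, e.g. {
        "text": "sentences",
        "label": {
            "label0": "LABEL_0",
            "label1": "LABEL_1"
        }
    }
    ppl: pipeline whose model.config.id2label lists the model labels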
"""


def check_column_mapping_keys_validity(column_mapping, ppl):
    """Check that a user-supplied label mapping covers exactly the model's labels.

    Args:
        column_mapping: JSON text (str/bytes) or an already-decoded dict. Only
            its optional ``"data"`` entry is inspected: a list of
            ``[user_label, model_label]`` pairs.
        ppl: pipeline whose ``model.config.id2label`` holds the model labels.

    Returns:
        True when there is no ``"data"`` entry, or when the set of first pair
        elements, the set of second pair elements and the set of model labels
        are all equal; False otherwise.
    """
    if isinstance(column_mapping, (str, bytes, bytearray)):
        column_mapping = json.loads(column_mapping)
    if "data" not in column_mapping:
        return True

    pairs = column_mapping["data"]
    user_labels = {pair[0] for pair in pairs}
    model_labels = {pair[1] for pair in pairs}
    original_labels = set(ppl.model.config.id2label.values())

    return user_labels == model_labels == original_labels


"""
params:
    column_mapping: dict
    dataset_features: dict
    example: {
        'text': Value(dtype='string', id=None), 
        'label': ClassLabel(names=['negative', 'neutral', 'positive'], id=None)
    }
"""


def infer_text_input_column(column_mapping, dataset_features):
    """Resolve which dataset column feeds the model's ``text`` input.

    If ``column_mapping["text"]`` names an existing dataset column it is kept.
    Otherwise the first feature whose ``dtype`` is ``"string"`` is chosen and
    written into ``column_mapping["text"]``.

    Args:
        column_mapping: dict, updated in place and also returned.
        dataset_features: mapping of column name -> feature object, or None.
            Features without a ``dtype`` attribute are never chosen.

    Returns:
        ``(column_mapping, feature_map_df)``. ``feature_map_df`` is a one-row
        DataFrame pairing the resolved dataset column with ``"text"``, or None
        when no text column could be resolved (no features, or no string
        feature to fall back on).
    """
    if not dataset_features:
        # nothing to map against (e.g. the dataset exposes no features)
        return column_mapping, None

    text_column = None
    if "text" in column_mapping:
        provided_column = column_mapping["text"]
        if provided_column in dataset_features:
            text_column = provided_column
        else:
            logging.warning(f"Provided {provided_column} is not in Dataset columns")

    if text_column is None:
        # Try to retrieve one: the first string-typed feature
        candidates = [
            name
            for name, feature in dataset_features.items()
            if getattr(feature, "dtype", None) == "string"
        ]
        if not candidates:
            return column_mapping, None
        logging.debug(f"Candidates are {candidates}")
        text_column = candidates[0]
        column_mapping["text"] = text_column

    # Built for provided and inferred columns alike: callers treat a None
    # feature map as "dataset has no usable features".
    feature_map_df = pd.DataFrame(
        {"Dataset Features": [text_column], "Model Input Features": ["text"]}
    )
    return column_mapping, feature_map_df


"""
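params:
    column_mapping: dict
    id2label_mapping: dict mapping each model label to its matched dataset
        label (None when unmatched)
    example:
    id2label_mapping: {
        'negative': 'negative',
        'neutral': 'neutral',
        'positive': 'positive'
        }
    id2label: dict mapping label ids to model labels (the model config's id2label)
    dataset_labels: list of dataset label names, or None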
"""


def infer_output_label_column(
    column_mapping, id2label_mapping, id2label, dataset_labels
):
    """Fill ``column_mapping["label"]`` and build a dataset/model label table.

    Args:
        column_mapping: dict, updated in place and returned. If it has a
            ``"data"`` entry that is a list, each item is unpacked as a
            ``(user_label, model_label)`` pair and written into
            ``id2label_mapping``.
        id2label_mapping: dict whose keys are model labels and whose values
            are the matched dataset labels (None when unmatched), as built by
            ``text_classification_map_model_and_dataset_labels``. Mutated in
            place by the user pairs above.
        id2label: dict of label id -> model label (``ppl.model.config.id2label``
            when called from ``text_classification_fix_column_mapping``).
        dataset_labels: list of dataset label names, or None.

    Returns:
        ``(column_mapping, id2label_df)``. When there is no ``"data"`` entry
        and some model label is still unmatched, ``column_mapping["label"]``
        is set to ``{label_id: None}`` for every id and ``id2label_df`` is
        None. Otherwise ``id2label_df`` is a DataFrame with columns
        ``"Dataset Labels"`` and ``"Model Prediction Labels"``.
    """
    # Check whether we need to infer the output label column
    if "data" in column_mapping.keys():
        if isinstance(column_mapping["data"], list):
            # Use the column mapping passed by user
            for user_label, model_label in column_mapping["data"]:
                id2label_mapping[model_label] = user_label
    elif None in id2label_mapping.values():
        # At least one model label has no dataset counterpart: leave every
        # label id unmapped and report that no table could be built.
        column_mapping["label"] = {i: None for i in id2label.keys()}
        return column_mapping, None

    if "data" not in column_mapping.keys():
        # Column mapping should contain original model labels
        # NOTE(review): this pairs label ids with dataset labels by position
        # and looks up *dataset* labels in id2label_mapping, whose keys are
        # model labels. A label matched only case-insensitively (e.g.
        # "positive" vs "POSITIVE") therefore raises KeyError here — confirm
        # the intended keys.
        column_mapping["label"] = {
            str(i): id2label_mapping[label]
            for i, label in zip(id2label.keys(), dataset_labels)
        }

    # NOTE(review): on the "data" path dataset_labels may still be None (no
    # compatible ClassLabel feature), in which case iterating it below raises
    # TypeError; the KeyError caveat above applies to this lookup as well.
    id2label_df = pd.DataFrame(
        {
            "Dataset Labels": dataset_labels,
            "Model Prediction Labels": [
                id2label_mapping[label] for label in dataset_labels
            ],
        }
    )

    return column_mapping, id2label_df


def check_dataset_features_validity(d_id, config, split):
    """Load one split of a Hub dataset and return ``(DataFrame, features)``.

    Returns ``(None, None)`` when the loaded object has no ``features``
    attribute. Loading errors are not caught: the dataset is assumed to be
    valid at this point.
    """
    # NOTE: trust_remote_code=True lets the dataset's own loading script run.
    dataset = datasets.load_dataset(d_id, config, split=split, trust_remote_code=True)
    try:
        features = dataset.features
    except AttributeError:
        # Dataset does not have features, need to provide everything
        return None, None
    return dataset.to_pandas(), features

def select_the_first_string_column(ds):
    """Return the name of the first column whose value in row 0 is a ``str``.

    Columns are checked in ``ds.features`` order. Returns None when the
    dataset has no columns, no rows, or no string value in its first row.
    """
    columns = list(ds.features.keys())
    if not columns:
        return None
    try:
        # Fetch the first row once instead of once per column.
        first_row = ds[0]
    except IndexError:
        # empty split: there is no row to inspect
        return None
    for column in columns:
        if isinstance(first_row[column], str):
            return column
    return None


def get_example_prediction(model_id, dataset_id, dataset_config, dataset_split, hf_token):
    """Run the first example of a dataset split through the Inference API.

    The input is the first row's ``"text"`` column when the dataset has one,
    otherwise its first string-valued column.

    Returns:
        ``(prediction_input, prediction_result)`` where ``prediction_result``
        is a ``{label: score}`` dict on success, a
        ``HuggingFaceInferenceAPIResponse`` with a user message when the API
        answers with an ``"error"`` object, or None when any step raised
        (the exception is logged). ``prediction_input`` is None if the input
        could not be read.
    """
    prediction_input = None
    try:
        ds = datasets.load_dataset(dataset_id, dataset_config, split=dataset_split, trust_remote_code=True)
        first_row = ds[0]
        if "text" in ds.features.keys():
            prediction_input = first_row["text"]
        else:
            # Dataset does not have text column: use the first string column
            prediction_input = first_row[select_the_first_string_column(ds)]

        api_response = hf_inference_api(
            model_id,
            hf_token,
            {"inputs": prediction_input, "options": {"use_cache": True}},
        )

        if isinstance(api_response, dict) and "error" in api_response.keys():
            if "estimated_time" in api_response.keys():
                # presumably the model is still being loaded by the API
                message = f"Estimated time: {int(api_response['estimated_time'])}s. Please try again later."
            else:
                message = f"Inference Error: {api_response['error']}."
            return prediction_input, HuggingFaceInferenceAPIResponse(message)

        # Unwrap nested lists (e.g. [[{...}, ...]]) down to the list of dicts
        while isinstance(api_response, list) and not isinstance(api_response[0], dict):
            api_response = api_response[0]
        prediction_result = {
            f'{item["label"]}': item["score"] for item in api_response
        }
    except Exception as e:
        # inference api prediction failed, show the error message
        logger.error(f"Get example prediction failed {e}")
        return prediction_input, None

    return prediction_input, prediction_result


def get_sample_prediction(ppl, df, column_mapping, id2label_mapping):
    """Run the first dataset row through a local pipeline and label the scores.

    Args:
        ppl: pipeline, called as ``ppl({"text": value}, top_k=None)``; its
            result is iterated as dicts with ``"label"`` and ``"score"``.
        df: DataFrame holding the dataset split.
        column_mapping: dict whose ``"text"`` entry names the input column.
        id2label_mapping: dict mapping model labels to dataset labels.

    Returns:
        ``(prediction_input, prediction_result)``. ``prediction_result`` maps
        ``"<label>(original) - <mapped>(mapped)"`` to the score, or is None
        if reading the input, running the pipeline, or mapping a predicted
        label failed. ``prediction_input`` is None if the input could not be
        read.
    """
    prediction_input = None
    try:
        # Use the first item to test prediction; positional access so the
        # result does not depend on the DataFrame's index labels.
        prediction_input = df[column_mapping["text"]].iloc[0]
        results = ppl({"text": prediction_input}, top_k=None)
        # Display results in original label and mapped label. Built inside
        # the try so a predicted label missing from id2label_mapping yields
        # None instead of an uncaught KeyError.
        prediction_result = {
            f'{result["label"]}(original) - {id2label_mapping[result["label"]]}(mapped)': result[
                "score"
            ]
            for result in results
        }
    except Exception:
        # Pipeline prediction failed, need to provide labels
        return prediction_input, None

    return prediction_input, prediction_result


def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, split):
    """Infer the text/label column mapping for a text-classification model.

    Loads the dataset split, resolves the text input column, matches model
    labels to dataset labels and tries one sample prediction.

    Returns:
        A 5-tuple ``(column_mapping, prediction_input, prediction_result,
        id2label_df, feature_map_df)``:

        - all None when ``infer_text_input_column`` returns no feature map;
        - ``(column_mapping, None, None, None, feature_map_df)`` when
          ``infer_output_label_column`` returns no label table;
        - otherwise populated, with ``prediction_result`` (and possibly
          ``prediction_input``) None if the sample prediction failed.
    """
    df, dataset_features = check_dataset_features_validity(d_id, config, split)

    column_mapping, feature_map_df = infer_text_input_column(column_mapping, dataset_features)
    if feature_map_df is None:
        # dataset does not have any features
        return None, None, None, None, None

    model_id2label = ppl.model.config.id2label
    id2label_mapping, dataset_labels = text_classification_map_model_and_dataset_labels(
        model_id2label, dataset_features
    )
    column_mapping, id2label_df = infer_output_label_column(
        column_mapping, id2label_mapping, model_id2label, dataset_labels
    )
    if id2label_df is None:
        # unable to infer the output label column
        return column_mapping, None, None, None, feature_map_df

    prediction_input, prediction_result = get_sample_prediction(
        ppl, df, column_mapping, id2label_mapping
    )
    # prediction_result is None when the sample prediction failed; the
    # returned tuple has the same shape either way.
    return column_mapping, prediction_input, prediction_result, id2label_df, feature_map_df

def strip_model_id_from_url(model_id):
    """Turn a ``https://huggingface.co/...`` model URL into a Hub model id.

    The id is taken as the first one or two non-empty path segments, so
    ``.../gpt2``, ``.../org/model``, ``.../org/model/`` and
    ``.../org/model/tree/main`` map to ``gpt2`` / ``org/model``. Inputs that
    do not start with the Hub URL prefix are returned unchanged.
    """
    prefix = "https://huggingface.co/"
    if not model_id.startswith(prefix):
        return model_id
    segments = [segment for segment in model_id[len(prefix):].split("/") if segment]
    return "/".join(segments[:2])

def check_hf_token_validity(hf_token, timeout=None):
    """Return True if the Hugging Face Hub accepts *hf_token*.

    Non-string and empty values are rejected without a network call.
    Otherwise the token is sent as a Bearer header to ``AUTH_CHECK_URL`` and
    is considered valid iff the Hub answers HTTP 200.

    Args:
        hf_token: candidate token.
        timeout: optional ``requests`` timeout (seconds, or a
            ``(connect, read)`` tuple); None (default) waits indefinitely.

    Raises:
        requests.RequestException: if the request itself fails
            (connection error, timeout).
    """
    # Type check first so the emptiness test never runs on non-strings.
    if not isinstance(hf_token, str) or not hf_token:
        return False
    # use huggingface api to check the token
    headers = {"Authorization": f"Bearer {hf_token}"}
    response = requests.get(AUTH_CHECK_URL, headers=headers, timeout=timeout)
    return response.status_code == 200