# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from dataclasses import asdict
from os.path import exists

import pandas as pd
from datasets import Dataset, get_dataset_infos, load_dataset, load_from_disk

# Module-level pandas configuration: treat +/-inf values as NaN as well.
# This is a process-wide side effect that runs when this module is imported.
# NOTE(review): "use_inf_as_na" appears to be deprecated in recent pandas
# releases -- confirm against the pandas version pinned for this project.
pd.set_option("use_inf_as_na", True)

## String names used in Hugging Face dataset configs.
HF_FEATURE_FIELD = "features"
HF_LABEL_FIELD = "label"
HF_DESC_FIELD = "description"

# Identifier for the cache directory setting.
# NOTE(review): not referenced in this part of the file -- presumably used as
# a key/argument name by callers; confirm before renaming.
CACHE_DIR = "cache_dir"
## String names we are using within this code.
# These do not come from the stored dataset or the HF config; they are
# identifiers used as keys in our own dicts and as column names in our
# dataframes.
OUR_TEXT_FIELD = "text"
OUR_LABEL_FIELD = "label"
TOKENIZED_FIELD = "tokenized_text"
EMBEDDING_FIELD = "embedding"
LENGTH_FIELD = "length"
VOCAB = "vocab"
WORD = "word"
CNT = "count"
PROP = "proportion"
TEXT_NAN_CNT = "text_nan_count"
TXT_LEN = "text lengths"
DEDUP_TOT = "dedup_total"
TOT_WORDS = "total words"
TOT_OPEN_WORDS = "total open words"

# Dataset ids iterated by `get_dataset_info_dicts` when it is called without
# an explicit `dataset_id`.
_DATASET_LIST = [
    "c4",
    "squad",
    "squad_v2",
    "hate_speech18",
    "hate_speech_offensive",
    "glue",
    "super_glue",
    "wikitext",
    "imdb",
    "HuggingFaceM4/OBELICS",
]

# Dataset ids that `load_truncated_dataset` loads in streaming mode
# (when `use_streaming` is enabled) instead of downloading the full split.
_STREAMABLE_DATASET_LIST = [
    "c4",
    "wikitext",
    "HuggingFaceM4/OBELICS",
]

# Default value of `num_rows` in `load_truncated_dataset`.
_MAX_ROWS = 100


def load_truncated_dataset(
    dataset_name,
    config_name,
    split_name,
    num_rows=_MAX_ROWS,
    cache_name=None,
    use_cache=True,
    use_streaming=True,
):
    """
    This function loads the first `num_rows` items of a dataset for a
    given `config_name` and `split_name`.
    If `cache_name` exists (and `use_cache` is set), the truncated dataset
    is loaded from `cache_name`.
    Otherwise, a new truncated dataset is created and immediately saved
    to `cache_name`.
    When the dataset is streamable, we iterate through the first
    `num_rows` examples in streaming mode, write them to a jsonl file,
    then create a new dataset from the json.
    This is the most direct way to make a Dataset from an IterableDataset
    as of datasets version 1.6.1.
    Otherwise, we download the full dataset and select the first
    `num_rows` items (or all of them if the split is shorter than that).
    Args:
        dataset_name (string):
            dataset id in the dataset library
        config_name (string):
            dataset configuration
        split_name (string):
            split name
        num_rows (int):
            number of rows to truncate the dataset to
        cache_name (string):
            name of the cache directory; defaults to
            "{dataset_name}_{config_name}_{split_name}_{num_rows}"
        use_cache (bool):
            whether to load from the cache if it exists
        use_streaming (bool):
            whether to use streaming when the dataset supports it
    Returns:
        Dataset: the truncated dataset as a Dataset object
    """
    if cache_name is None:
        cache_name = f"{dataset_name}_{config_name}_{split_name}_{num_rows}"
    # Honour `use_cache`: only short-circuit to the on-disk copy when asked to.
    if use_cache and exists(cache_name):
        return load_from_disk(cache_name)

    if use_streaming and dataset_name in _STREAMABLE_DATASET_LIST:
        iterable_dataset = load_dataset(
            dataset_name,
            name=config_name,
            split=split_name,
            streaming=True,
        ).take(num_rows)
        # NOTE: the intermediate file lives at a fixed path in the current
        # working directory, so concurrent calls would overwrite each other.
        # The context manager guarantees the handle is closed (and flushed)
        # even if streaming or serialization raises part-way through.
        with open("temp.jsonl", "w", encoding="utf-8") as jsonl_file:
            for row in iterable_dataset:
                jsonl_file.write(json.dumps(row) + "\n")
        dataset = Dataset.from_json(
            "temp.jsonl", features=iterable_dataset.features, split=split_name
        )
    else:
        full_dataset = load_dataset(
            dataset_name,
            name=config_name,
            split=split_name,
        )
        # Clamp so that splits shorter than `num_rows` do not raise on select.
        dataset = full_dataset.select(range(min(num_rows, len(full_dataset))))
    dataset.save_to_disk(cache_name)
    return dataset


def intersect_dfs(df_dict):
    """
    Inner-join the dataframes of `df_dict` on their index.

    Entries whose value is None are skipped. The rows of the result are
    those whose index label is present in every remaining dataframe.

    NOTE: for backward compatibility the original join sequence is kept:
    every ordered pair of distinct keys triggers one more inner join onto
    the running result, so frames are joined repeatedly and overlapping
    column names pick up the "1"/"2" suffixes (e.g. joining frames with
    columns "a" and "b" yields columns "a1", "b", "a2").

    Args:
        df_dict (dict):
            maps arbitrary keys to pandas DataFrames (or None)
    Returns:
        DataFrame: a new dataframe (copy) holding the joined result; when
        only one dataframe is not None, a copy of that dataframe
    Raises:
        ValueError: if `df_dict` holds no dataframe at all
    """
    frames = [(key, df) for key, df in df_dict.items() if df is not None]
    if not frames:
        raise ValueError("intersect_dfs needs at least one non-None dataframe")
    if len(frames) == 1:
        # Nothing to intersect with: the intersection is the frame itself.
        return frames[0][1].copy()

    new_df = None
    for key, df in frames:
        for key2, df2 in frames:
            if key == key2:
                continue
            # The very first join starts from the first frame; every later
            # join extends the running result.
            left = df if new_df is None else new_df
            new_df = left.join(df2, how="inner", lsuffix="1", rsuffix="2")
    return new_df.copy()


def get_typed_features(features, ftype="string", parents=None):
    """
    Recursively get a list of all features of a certain dtype
    :param features: nested dict describing the features (the dict form of
        a dataset's feature description, as passed by `dictionarize_info`)
    :param ftype: dtype string to look for, e.g. "string" or "int32"
    :param parents: list of the feature names leading to `features`;
        None (the default) means `features` is the top level
    :return: a list of tuples > e.g. ('A', 'B', 'C') for feature example['A']['B']['C']
    """
    if parents is None:
        parents = []
    typed_features = []
    for name, feat in features.items():
        if isinstance(feat, dict):
            path = parents + [name]
            inner = feat.get("feature")
            if feat.get("dtype") == ftype:
                # Plain value feature of the requested dtype.
                typed_features.append(tuple(path))
            elif isinstance(inner, dict):
                # Wrapper around an inner feature (has a "feature" entry):
                # either the inner feature has the requested dtype itself, or
                # it is a dict of sub-features that we descend into.
                if inner.get("dtype") == ftype:
                    typed_features.append(tuple(path))
                else:
                    typed_features += get_typed_features(inner, ftype, path)
            elif "feature" not in feat:
                # Plain nested dict: descend into every dict-valued entry.
                for k, v in feat.items():
                    if isinstance(v, dict):
                        typed_features += get_typed_features(
                            v, ftype, parents + [name, k]
                        )
            # A non-dict "feature" entry cannot be inspected and is skipped.
        elif name == "dtype" and feat == ftype:
            # `features` itself is a leaf description of the requested dtype.
            typed_features.append(tuple(parents))
    return typed_features


def get_label_features(features, parents=None):
    """
    Recursively get a list of all features that are ClassLabels
    (i.e. feature descriptions carrying a "names" entry)
    :param features: nested dict describing the features (the dict form of
        a dataset's feature description, as passed by `dictionarize_info`)
    :param parents: list of the feature names leading to `features`;
        None (the default) means `features` is the top level
    :return: pairs of tuples as above and the list of class names
    """
    if parents is None:
        parents = []
    label_features = []
    for name, feat in features.items():
        if isinstance(feat, dict):
            path = parents + [name]
            if "names" in feat:
                label_features.append((tuple(path), feat["names"]))
            elif "feature" in feat:
                # Wrapper around an inner feature: descend into it. An inner
                # label is picked up by the `name == "names"` branch of the
                # recursive call and reported under this feature's path.
                if isinstance(feat["feature"], dict):
                    label_features += get_label_features(feat["feature"], path)
            else:
                # Plain nested dict: descend into every dict-valued entry.
                for k, v in feat.items():
                    if isinstance(v, dict):
                        label_features += get_label_features(v, parents + [name, k])
        elif name == "names":
            # `features` itself is a label description.
            label_features.append((tuple(parents), feat))
    return label_features


# get the info we need for the app sidebar in dict format
def dictionarize_info(dset_info):
    """
    Condense a dataset info object into a plain dict.

    Args:
        dset_info:
            dataclass instance exposing `config_name`, `splits`, `features`
            and `description` (it is converted with `dataclasses.asdict`)
    Returns:
        dict: with the keys
            "config_name": the configuration name,
            "splits": maps each split name to a size,
            "features": maps "string" / "int32" / "float32" to the feature
                paths of that dtype and "label" to the label features,
            "description": the dataset description
    """
    info_dict = asdict(dset_info)
    config_name = info_dict["config_name"]

    # NOTE: every split is reported with a fixed size of 100 rather than its
    # real `num_examples` value.
    split_sizes = {spl: 100 for spl in info_dict["splits"].keys()}

    raw_features = info_dict["features"]
    features_by_kind = {
        dtype: get_typed_features(raw_features, dtype)
        for dtype in ("string", "int32", "float32")
    }
    features_by_kind["label"] = get_label_features(raw_features)

    return {
        "config_name": config_name,
        "splits": split_sizes,
        "features": features_by_kind,
        "description": dset_info.description,
    }


def get_dataset_info_dicts(dataset_id=None):
    """
    Creates a dict from dataset configs.
    Uses the datasets lib's get_dataset_infos
    :param dataset_id: id of a single dataset to describe; when None, every
        dataset in `_DATASET_LIST` is described
    :return: Dictionary mapping dataset names to their configurations
        (config name -> dict built by `dictionarize_info`)
    """
    # Identity check: only a literal None selects the full dataset list.
    dataset_ids = _DATASET_LIST if dataset_id is None else [dataset_id]
    return {
        ds_id: {
            config_name: dictionarize_info(config_info)
            for config_name, config_info in get_dataset_infos(ds_id).items()
        }
        for ds_id in dataset_ids
    }


# get all instances of a specific field in a dataset
def extract_field(examples, field_path, new_field_name=None):
    """
    Collect all values found under `field_path` into one flat list.

    Args:
        examples (dict):
            maps field names to lists of items (presumably a batch of
            examples -- confirm against callers)
        field_path (string or sequence of strings):
            name of the field, or the chain of names leading to a nested
            field, e.g. ("answers", "text") for item["answers"]["text"]
        new_field_name (string):
            key under which the values are returned; defaults to the
            elements of `field_path` joined with "_"
    Returns:
        dict: `{new_field_name: values}`; list values met along the path
        (and at its end) are flattened by one level
    """
    # Normalise a bare string first: callers such as the CLI pass a plain
    # field name, and joining a string with "_" would otherwise interleave
    # its characters ("text" -> "t_e_x_t").
    if isinstance(field_path, str):
        field_path = [field_path]
    if new_field_name is None:
        new_field_name = "_".join(field_path)

    item_list = examples[field_path[0]]
    for field_name in field_path[1:]:
        next_items = []
        for item in item_list:
            value = item[field_name]
            if isinstance(value, list):
                next_items.extend(value)
            else:
                next_items.append(value)
        item_list = next_items

    # Final pass: unpack list-valued items so the result is a flat list.
    field_list = []
    for item in item_list:
        if isinstance(item, list):
            field_list.extend(item)
        else:
            field_list.append(item)
    return {new_field_name: field_list}