from datasets import load_dataset
from disaggregators import Disaggregator, DisaggregationModuleLabels, CustomDisaggregator
from disaggregators.disaggregation_modules.age import Age, AgeLabels, AgeConfig

import matplotlib
# Select the TkAgg backend immediately after importing matplotlib.
# NOTE(review): nothing in the visible source draws a plot; presumably a
# dependency relies on this backend being set -- confirm it is still needed.
matplotlib.use('TKAgg')

import joblib
import os

# Path of the joblib cache file: read here when it exists and rewritten by the
# joblib.dump call at the end of the script.
cache_file = "cached_data.pkl"
cache_dict = {}

# Start from the previously saved cache when one exists, so entries that this
# run does not regenerate are carried over.
# NOTE: joblib.load unpickles the file, which can execute arbitrary code; only
# point cache_file at a file this script produced itself.
if os.path.exists(cache_file):
    # Use cache_file (not a repeated literal) so the existence check and the
    # load can never refer to different paths.
    cache_dict = joblib.load(cache_file)

class MeSHAgeLabels(AgeLabels):
    """Age-group labels handed to the ``Age`` module configuration.

    The member values match the suffixes of the ``"age.<value>"`` keys in
    ``configs["medmcqa"]["feature_names"]``.

    NOTE(review): the names look like the MeSH (Medical Subject Headings) age
    groups -- confirm against the MeSH vocabulary.
    """

    INFANT = "infant"
    CHILD_PRESCHOOL = "child_preschool"
    CHILD = "child"
    ADOLESCENT = "adolescent"
    ADULT = "adult"
    MIDDLE_AGED = "middle_aged"
    AGED = "aged"
    AGED_80_OVER = "aged_80_over"


# Age disaggregation module built with the custom label set above and bound to
# the "question" column (the same column configs["medmcqa"] uses).
age = Age(
    config=AgeConfig(
        labels=MeSHAgeLabels,
        # NOTE(review): this is a one-element list wrapping the list of all
        # labels -- confirm that nesting is the shape AgeConfig expects.
        ages=[list(MeSHAgeLabels)],
        # Eight values, the same count as the labels; presumably age boundaries
        # in years for each group -- confirm against AgeConfig.
        breakpoints=[0, 2, 5, 12, 18, 44, 64, 79]
    ),
    column="question"
)


class TabsSpacesLabels(DisaggregationModuleLabels):
    """Labels emitted by the ``TabsSpaces`` disaggregator."""

    TABS = "tabs"
    SPACES = "spaces"


class TabsSpaces(CustomDisaggregator):
    """Custom disaggregator that tags a row by whether its text contains a tab.

    A row whose ``self.column`` value contains at least one tab character is
    labelled ``TABS``; every other row is labelled ``SPACES``. Exactly one of
    the two labels is ``True`` for any row.
    """

    module_id = "tabs_spaces"
    labels = TabsSpacesLabels

    def __call__(self, row, *args, **kwargs):
        """Return a ``{label: bool}`` mapping for ``row``."""
        has_tab = "\t" in row[self.column]
        # The two labels are mutually exclusive by construction.
        return {self.labels.TABS: has_tab, self.labels.SPACES: not has_tab}


class ReactComponentLabels(DisaggregationModuleLabels):
    """Labels emitted by the ``ReactComponent`` disaggregator."""

    CLASS = "class"
    FUNCTION = "function"


class ReactComponent(CustomDisaggregator):
    """Custom disaggregator that tags a row as a class or function component.

    A row is labelled ``CLASS`` when its ``self.column`` value contains
    ``"extends React.Component"`` or ``"extends Component"``; every other row
    is labelled ``FUNCTION``. Exactly one of the two labels is ``True``.
    """

    module_id = "react_component"
    labels = ReactComponentLabels

    def __call__(self, row, *args, **kwargs):
        """Return a ``{label: bool}`` mapping for ``row``."""
        source = row[self.column]
        class_markers = ("extends React.Component", "extends Component")
        is_class = any(marker in source for marker in class_markers)
        # Anything that is not recognised as a class component counts as a
        # function component.
        return {self.labels.CLASS: is_class, self.labels.FUNCTION: not is_class}


# Per-dataset settings. Each entry is unpacked as keyword arguments into
# generate_cached_data(), so its keys must match that function's parameter
# names: disaggregation_modules, dataset_name, column, feature_names.
#
# "feature_names" maps "<module>.<label>" keys, plus parent-level module ids,
# to human-readable names. The entries with two modules also carry a "Both"
# key; presumably it names the combined view -- the consumer is not in this
# file, so confirm there.
configs = {
    "laion": {
        # Module referenced by its string id.
        "disaggregation_modules": ["continent"],
        "dataset_name": "society-ethics/laion2B-en_continents",
        "column": "TEXT",
        "feature_names": {
            "continent.africa": "Africa",
            "continent.americas": "Americas",
            "continent.asia": "Asia",
            "continent.europe": "Europe",
            "continent.oceania": "Oceania",

            # Parent level
            "continent": "Continent",
        }
    },
    "medmcqa": {
        # Mixes the pre-built `age` module instance with a string module id.
        "disaggregation_modules": [age, "gender"],
        "dataset_name": "society-ethics/medmcqa_age_gender_custom",
        "column": "question",
        "feature_names": {
            "age.infant": "Infant",
            "age.child_preschool": "Preschool",
            "age.child": "Child",
            "age.adolescent": "Adolescent",
            "age.adult": "Adult",
            "age.middle_aged": "Middle Aged",
            "age.aged": "Aged",
            "age.aged_80_over": "Aged 80+",
            "gender.male": "Male",
            "gender.female": "Female",

            # Parent level
            "gender": "Gender",
            "age": "Age",
            "Both": "Age + Gender",
        }
    },
    "stack": {
        # The custom disaggregators are passed as classes, not instances.
        "disaggregation_modules": [TabsSpaces, ReactComponent],
        "dataset_name": "society-ethics/the-stack-tabs_spaces",
        "column": "content",
        "feature_names": {
            "react_component.class": "Class",
            "react_component.function": "Function",
            "tabs_spaces.tabs": "Tabs",
            "tabs_spaces.spaces": "Spaces",

            # Parent level
            "react_component": "React Component Syntax",
            "tabs_spaces": "Tabs vs. Spaces",
            "Both": "React Component Syntax + Tabs vs. Spaces",
        }
    }
}


def generate_cached_data(disaggregation_modules, dataset_name, column, feature_names):
    """Build the cacheable summary for one dataset.

    Args:
        disaggregation_modules: Modules handed to ``Disaggregator`` (string
            ids, module instances, or custom disaggregator classes, as used
            in ``configs``).
        dataset_name: Name passed to ``load_dataset``; the ``"train"`` split
            is loaded.
        column: Column name forwarded to ``Disaggregator`` and echoed back in
            the result.
        feature_names: Mapping echoed back unchanged in the result.

    Returns:
        A dict with the keys ``"fields"`` (the disaggregator's fields plus the
        string ``"None"``), ``"data_fields"``, ``"distributions"`` (the result
        of ``value_counts()`` over the field columns in sorted order),
        ``"disaggregators"`` (the ``name`` of each module), ``"column"`` and
        ``"feature_names"``.
    """
    disaggregator = Disaggregator(disaggregation_modules, column=column)
    frame = load_dataset(dataset_name, split="train").to_pandas()

    data_fields = disaggregator.fields
    # Sort the column names so the selected column order is deterministic.
    ordered_columns = sorted(data_fields)
    distributions = frame[ordered_columns].value_counts()
    module_names = [module.name for module in disaggregator.modules]

    return {
        "fields": {*data_fields, "None"},
        "data_fields": data_fields,
        "distributions": distributions,
        "disaggregators": module_names,
        "column": column,
        "feature_names": feature_names,
    }


# Recompute the summary for each configured dataset and merge the results into
# the (possibly pre-loaded) cache, replacing any entries with the same key.
# The comprehension builds every entry before update() runs, so the cache is
# only modified once all datasets have been processed successfully.
cache_dict.update({
    dataset_key: generate_cached_data(**configs[dataset_key])
    for dataset_key in ("laion", "medmcqa", "stack")
})

# Persist the merged cache for later runs.
joblib.dump(cache_dict, cache_file)