# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import statistics

import json
import pandas as pd
import seaborn as sns
import streamlit as st
#from st_aggrid import AgGrid, GridOptionsBuilder

from .dataset_utils import HF_DESC_FIELD, HF_FEATURE_FIELD, HF_LABEL_FIELD
# Silence Streamlit's "global pyplot" deprecation warning for this app.
# NOTE(review): newer Streamlit releases appear to have dropped this option, and
# set_option may then reject the unknown key — confirm against the pinned version.
st.set_option('deprecation.showPyplotGlobalUse', False)
# Index of the cached results the UI is allowed to offer. Judging from how it is
# read in sidebar_selection, the nesting is:
#   dataset name -> config name -> '-'-joined text-field path -> split name -> ...
# The relative path is resolved against the current working directory, and the
# file is read once at import time.
json_file_path = "cache_dir/has_cache.json"
with open(json_file_path, "r", encoding="utf-8") as j:
    _HAS_CACHE = json.loads(j.read())

def sidebar_header():
    """Write the introductory description of the demo into the Streamlit sidebar."""
    intro_markdown = """
    This demo showcases the [dataset measures as we develop them](https://huggingface.co/blog/data-measurements-tool).
    Right now this has a few pre-loaded datasets for which you can:
    - view some general statistics about the text vocabulary, lengths, labels
    - explore some distributional statistics to assess properties of the language
    - view some comparison statistics and overview of the text distribution

    The tool is in development, and will keep growing in utility and functionality 🤗🚧
    """
    st.sidebar.markdown(intro_markdown, unsafe_allow_html=True)


def sidebar_selection(ds_name_to_dict, column_id):
    """Render the sidebar widgets that choose what to analyze and return the choice.

    Every option offered (dataset, config, text feature, split) is taken from the
    module-level ``_HAS_CACHE`` index, so only combinations with cached results
    can be selected.

    :param ds_name_to_dict: nested mapping ``dataset name -> config name -> info``;
        used here only to look up the label fields under
        ``info[HF_FEATURE_FIELD][HF_LABEL_FIELD]``.
    :param column_id: suffix appended to widget labels so that side-by-side
        selections get distinct labels.
    :return: dict with keys ``dset_name``, ``dset_config``, ``split_name``,
        ``text_field``, ``label_field`` and ``label_names``.
    """
    ds_names = list(_HAS_CACHE.keys())
    with st.sidebar.expander(f"Choose dataset and field {column_id}", expanded=True):
        # choose a dataset to analyze; preselect hate_speech18 when it is cached,
        # otherwise fall back to the first entry instead of raising ValueError.
        default_ds_name = "hate_speech18"
        ds_name = st.selectbox(
            f"Choose dataset to explore{column_id}:",
            ds_names,
            index=ds_names.index(default_ds_name) if default_ds_name in ds_names else 0,
        )
        # choose a config to analyze (only configs with cached results)
        config_names = list(_HAS_CACHE[ds_name].keys())
        config_name = st.selectbox(
            f"Choose configuration{column_id}:",
            config_names,
            index=0,
        )
        # TODO: Handling for multiple text features
        # Cached text fields are stored as '-'-joined feature paths; split them
        # back into tuples for display and selection.
        text_features = [
            tuple(text_field.split("-"))
            for text_field in _HAS_CACHE[ds_name][config_name]
        ]
        # TODO @yacine: Explain what this is doing and why eg tp[0] could = "id"
        text_field = st.selectbox(
            f"Which text feature from the{column_id} dataset would you like to analyze?",
            [("text",)]
            if ds_name == "c4"
            else [tp for tp in text_features if tp[0] != "id"],
        )
        # Choose a split; the cache is keyed by the '-'-joined feature path.
        avail_splits = list(_HAS_CACHE[ds_name][config_name]["-".join(text_field)].keys())
        # 12.Nov note: Removing "test" because those should not be examined
        # without discussion of pros and cons, which we haven't done yet.
        if "test" in avail_splits:
            avail_splits.remove("test")
        split = st.selectbox(
            f"Which split from the{column_id} dataset would you like to analyze?",
            avail_splits,
            index=0,
        )
        # Only the first label field is used; unlabeled datasets get empty placeholders.
        label_fields = ds_name_to_dict[ds_name][config_name][HF_FEATURE_FIELD][HF_LABEL_FIELD]
        label_field, label_names = label_fields[0] if len(label_fields) > 0 else ((), [])
        return {
            "dset_name": ds_name,
            "dset_config": config_name,
            "split_name": split,
            "text_field": text_field,
            "label_field": label_field,
            "label_names": label_names,
        }


def expander_header(dstats, ds_name_to_dict, column_id):
    """Show the dataset's description and a peek at its rows inside an expander."""
    with st.expander(f"Dataset Description{column_id}"):
        config_info = ds_name_to_dict[dstats.dset_name][dstats.dset_config]
        st.markdown(config_info[HF_DESC_FIELD])
        st.dataframe(dstats.dset_peek)


def expander_general_stats(dstats, column_id):
    """Summarize word counts, top open-class words, missing values and duplicates.

    :param dstats: statistics object; reads ``total_words``, ``total_open_words``,
        ``sorted_top_vocab_df``, ``text_nan_count`` and ``dedup_total``.
    :param column_id: suffix appended to the expander label.
    """
    with st.expander(f"General Text Statistics{column_id}"):
        st.caption(
            "Use this widget to check whether the terms you see most represented"
            " in the dataset make sense for the goals of the dataset."
        )
        # A zero word count is treated as "results not available yet".
        if dstats.total_words == 0:
            st.markdown("Eh oh...not finding the file I need. 😭  Probably it will be there soon. 🤞 Check back later!")
            return
        st.markdown(f"There are {dstats.total_words!s} total words")
        st.markdown(
            f"There are {dstats.total_open_words!s} words after removing closed class words"
        )
        st.markdown(
            "The most common "
            "[open class words](https://dictionary.apa.org/open-class-words) "
            "and their counts are: "
        )
        st.dataframe(dstats.sorted_top_vocab_df)
        st.markdown(f"There are {dstats.text_nan_count!s} missing values in the dataset.")
        n_dups = dstats.dedup_total
        if n_dups > 0:
            dup_message = (
                f"There are {n_dups!s} duplicate items in the dataset. "
                "For more information about the duplicates, "
                "click the 'Duplicates' tab below."
            )
        else:
            dup_message = "There are 0 duplicate items in the dataset. "
        st.markdown(dup_message)


### Show the label distribution from the datasets
def expander_label_distribution(fig_labels, column_id):
    """Plot the label distribution, or say that the dataset has no labels.

    :param fig_labels: plotly figure of label counts, or ``None`` when unlabeled.
    :param column_id: suffix appended to the expander label.
    """
    with st.expander(f"Label Distribution{column_id}", expanded=False):
        st.caption("Use this widget to see how balanced the labels in your dataset are.")
        if fig_labels is None:
            st.markdown("No labels were found in the dataset")
            return
        st.plotly_chart(fig_labels, use_container_width=True)


def expander_text_lengths(dstats, column_id):
    """Show how text lengths are distributed and let users browse examples by length.

    :param dstats: statistics object; reads ``fig_tok_length_png``,
        ``fig_tok_length``, ``avg_length``, ``std_length`` and ``length_df``
        (which may be ``None``).
    :param column_id: suffix appended to the expander label and widget key.
    """
    _TEXT_LENGTH_CAPTION = (
        "Use this widget to identify outliers, particularly suspiciously long outliers."
    )
    with st.expander(f"Text Lengths{column_id}", expanded=False):
        st.caption(_TEXT_LENGTH_CAPTION)
        st.markdown(
            "Below, you can see how the lengths of the text instances in your dataset are distributed."
        )
        st.markdown(
            "Any unexpected peaks or valleys in the distribution may help to identify instances you want to remove or augment."
        )
        st.markdown(
            "### Here is the relative frequency of different text lengths in your dataset:"
        )
        # Prefer the image version of the plot; if it is missing or cannot be
        # shown, draw the figure object instead. `except Exception` (not a bare
        # `except:`) so KeyboardInterrupt/SystemExit and other BaseException
        # control-flow signals still propagate.
        try:
            st.image(dstats.fig_tok_length_png)
        except Exception:
            st.pyplot(dstats.fig_tok_length, use_container_width=True)
        st.markdown(
            "The average length of text instances is **"
            + str(dstats.avg_length)
            + " words**, with a standard deviation of **"
            + str(dstats.std_length)
            + "**."
        )
        # This is quite a large file and is breaking our ability to navigate the app development.
        # Just passing if it's not already there for launch v0
        if dstats.length_df is not None:
            start_id_show_lengths = st.selectbox(
                "Show examples of length:",
                sorted(dstats.length_df["length"].unique().tolist()),
                key=f"select_show_length_{column_id}",
            )
            st.table(
                dstats.length_df[
                    dstats.length_df["length"] == start_id_show_lengths
                ].set_index("length")
            )


### Third, use a sentence embedding model
def _show_nearest_clusters(compare_example, text_dset, node_list, embeddings, text_field):
    """Write the clusters closest to ``compare_example``; return the best leaf id, or None."""
    # Only run the (expensive) beam search on a cache miss. Passing it as the
    # default of dict.get() evaluated it on every call, even on cache hits.
    paths_to_leaves = embeddings.cached_clusters.get(compare_example)
    if paths_to_leaves is None:
        paths_to_leaves = embeddings.find_cluster_beam(compare_example, beam_size=50)
    if len(paths_to_leaves) == 0:
        st.markdown("**Warning: no close clusters found.**")
        return None
    clusters_intro = ""
    # Entries are (path, score) pairs; the first one is treated as the best match.
    if paths_to_leaves[0][1] < 0.3:
        clusters_intro += "**Warning: no close clusters found (best score <0.3).** "
    clusters_intro += "The closest clusters to the text entered above are:"
    st.markdown(clusters_intro)
    for path, score in paths_to_leaves[:5]:
        leaf_id = path[-1]
        line = f"Cluster {leaf_id:5d} | Score: {score:.3f}"
        # The source text may be unavailable; still list the clusters and scores.
        if text_dset is not None:
            example = text_dset[
                node_list[leaf_id]["sorted_examples_centroid"][0][0]
            ][text_field][:256]
            line += f"  \n Example: {example}"
        st.write(line)
    return paths_to_leaves[0][0][-1]


def expander_text_embeddings(
    text_dset, fig_tree, node_list, embeddings, text_field, column_id
):
    """Show the hierarchical clustering of text embeddings and let users explore it.

    :param text_dset: indexable dataset (``text_dset[i][text_field]``), or ``None``
        when the source text is not available.
    :param fig_tree: plotly figure of the cluster tree.
    :param node_list: list of cluster nodes; each node's
        ``"sorted_examples_centroid"`` holds ``(example_id, similarity)`` pairs.
    :param embeddings: object exposing a ``cached_clusters`` mapping and
        ``find_cluster_beam(text, beam_size=...)``.
    :param text_field: key of the text column in each example.
    :param column_id: suffix appended to widget labels/keys to keep them unique.
    """
    with st.expander(f"Text Embedding Clusters{column_id}", expanded=False):
        _EMBEDDINGS_CAPTION = """
        ### Hierarchical Clustering of Text Fields
        Taking in the diversity of text represented in a dataset can be
        challenging when it is made up of hundreds to thousands of sentences.
        Grouping these text items based on a measure of similarity can help
        users gain some insights into their distribution.
        The following figure shows a hierarchical clustering of the text fields
        in the dataset based on a
        [Sentence-Transformer](https://hf.co/sentence-transformers/all-mpnet-base-v2)
        model. Clusters are merged if any of the embeddings in cluster A has a
        dot product with any of the embeddings or with the centroid of cluster B
        higher than a threshold (one threshold per level, from 0.5 to 0.95).
        To explore the clusters, you can:
        - hover over a node to see the 5 most representative examples (deduplicated)
        - enter an example in the text box below to see which clusters it is most similar to
        - select a cluster by ID to show all of its examples
        """
        st.markdown(_EMBEDDINGS_CAPTION)
        st.plotly_chart(fig_tree, use_container_width=True)
        st.markdown("---\n")
        show_node_default = len(node_list) // 2
        if st.checkbox(
            label="Enter text to see nearest clusters",
            key=f"search_clusters_{column_id}",
        ):
            compare_example = st.text_area(
                label="Enter some text here to see which of the clusters in the dataset it is closest to",
                key=f"search_cluster_input_{column_id}",
            )
            if compare_example != "":
                closest_leaf = _show_nearest_clusters(
                    compare_example, text_dset, node_list, embeddings, text_field
                )
                if closest_leaf is not None:
                    show_node_default = closest_leaf
        st.markdown("---\n")
        if text_dset is None:
            st.markdown("Missing source text to show, check back later!")
            return
        show_node = st.selectbox(
            f"Choose a leaf node to explore in the{column_id} dataset:",
            range(len(node_list)),
            index=show_node_default,
        )
        node_examples = node_list[show_node]["sorted_examples_centroid"]
        # Page through examples 5 at a time. With 5 or fewer examples there is
        # nothing to page through, and `len - 5` would be an invalid slider max.
        max_start = len(node_examples) - 5
        if max_start > 0:
            start_id = st.slider(
                f"Show closest sentences in cluster to the centroid{column_id} starting at index:",
                0,
                max_start,
                value=0,
                step=5,
            )
        else:
            start_id = 0
        for sid, sim in node_examples[start_id : start_id + 5]:
            # only show the first 4 lines and the first 10000 characters
            show_text = text_dset[sid][text_field][:10000]
            show_text = "\n".join(show_text.split("\n")[:4])
            st.text(f"{sim:.3f} \t {show_text}")


### Then, show duplicates
def expander_text_duplicates(dstats, column_id):
    """List every duplicated text item with its count, or celebrate if there are none.

    :param dstats: statistics object; reads ``dup_counts_df`` (may be ``None``).
    :param column_id: suffix appended to the expander label.
    """
    # TODO: Saving/loading figure
    with st.expander(f"Text Duplicates{column_id}", expanded=False):
        st.caption("Use this widget to identify text strings that appear more than once.")
        st.markdown(
            "A model's training and testing may be negatively affected by unwarranted duplicates ([Lee et al., 2021](https://arxiv.org/abs/2107.06499))."
        )
        st.markdown("------")
        st.write("### Here is the list of all the duplicated items and their counts in your dataset:")
        dup_df = dstats.dup_counts_df
        has_duplicates = dup_df is not None and not dup_df.empty
        if has_duplicates:
            st.dataframe(dup_df.reset_index(drop=True))
        else:
            st.write("There are no duplicates in this dataset! 🥳")


def expander_npmi_description(min_vocab):
    """Write the explanatory text shown above the nPMI word-association results.

    :param min_vocab: occurrence threshold quoted in the text as the minimum
        number of times a selectable identity term appears in the dataset.
    """
    _NPMI_CAPTION = (
        "Use this widget to identify problematic biases and stereotypes in your data."
    )
    _NPMI_CAPTION1 = """
    nPMI scores for a word help to identify potentially
    problematic associations, ranked by how close the association is."""
    _NPMI_CAPTION2 = """
    nPMI bias scores for paired words help to identify how word
    associations are skewed between the selected words
    ([Aka et al., 2021](https://arxiv.org/abs/2103.03417)).
    """

    st.caption(_NPMI_CAPTION)
    st.markdown(_NPMI_CAPTION1)
    st.markdown(_NPMI_CAPTION2)
    st.markdown("  ")
    st.markdown(
        "You can select from gender and sexual orientation "
        "identity terms that appear in the dataset at least %s "
        "times." % min_vocab
    )
    st.markdown(
        "The resulting ranked words are those that co-occur with both "
        "identity terms.  "
    )
    # NOTE(review): npmi_widget requests results for the alphabetically sorted
    # term pair, so "first"/"second" here may not follow the order in which the
    # user picked the terms — confirm which term the bias sign refers to.
    st.markdown(
        "The more *positive* the score, the more associated the word is with the first identity term.  "
        "The more *negative* the score, the more associated the word is with the second identity term."
    )


### Finally, show Zipf stuff
def expander_zipf(z, zipf_fig, column_id):
    """Explain and display the Zipf's-law fit of the dataset's vocabulary.

    :param z: fitted Zipf statistics object; reads ``alpha``, ``distance`` and ``xmin``.
    :param zipf_fig: plotly figure of observed vs. expected word counts.
    :param column_id: suffix used for the expander label and results-table index name.
    """
    with st.expander(
        f"Vocabulary Distribution{column_id}: Zipf's Law Fit", expanded=False
    ):
        try:
            _ZIPF_CAPTION = """This shows how close the observed language is to an ideal
            natural language distribution following [Zipf's law](https://en.wikipedia.org/wiki/Zipf%27s_law),
            calculated by minimizing the [Kolmogorov-Smirnov (KS) statistic](https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test)."""

            zipf_summary = (
                "The optimal alpha based on this dataset is: **"
                + str(round(z.alpha, 2))
                + "**, with a KS distance of: **"
                + str(round(z.distance, 2))
                + "**.  This was fit with a minimum rank value of: **"
                + str(int(z.xmin))
                + "**, which is the optimal rank *beyond which* the scaling regime of the power law fits best."
            )

            alpha_warning = "Your alpha value is a bit on the high side, which means that the distribution over words in this dataset is a bit unnatural. This could be due to non-language items throughout the dataset."
            xmin_warning = "The minimum rank for this fit is a bit on the high side, which means that the frequencies of your most common words aren't distributed as would be expected by Zipf's law."
            fit_results_table = pd.DataFrame.from_dict(
                {
                    "Alpha:": ["%.2f" % z.alpha],
                    "KS distance:": ["%.2f" % z.distance],
                    "Min rank:": [str(int(z.xmin))],
                },
                columns=["Results"],
                orient="index",
            )
            fit_results_table.index.name = column_id
            st.caption(
                "Use this widget for the counts of different words in your dataset, measuring the difference between the observed count and the expected count under Zipf's law."
            )
            st.markdown(_ZIPF_CAPTION)
            # Raw string: "\p" is not a valid escape sequence (SyntaxWarning on
            # Python 3.12+); the rendered text is unchanged.
            st.write(
                r"""
            A Zipfian distribution follows the power law: $p(x) \propto x^{-α}$
    with an ideal α value of 1."""
            )
            st.markdown(
                "In general, an alpha greater than 2 or a minimum rank greater than 10 (take with a grain of salt) means that your distribution is relatively _unnatural_ for natural language. This can be a sign of mixed artefacts in the dataset, such as HTML markup."
            )
            st.markdown(
                "Below, you can see the counts of each word in your dataset vs. the expected number of counts following a Zipfian distribution."
            )
            st.markdown("-----")
            st.write("### Here is your dataset's Zipf results:")
            st.dataframe(fit_results_table)
            st.write(zipf_summary)
            # TODO: Nice UI version of the KS-test p-value / goodness-of-fit summary.
            st.plotly_chart(zipf_fig, use_container_width=True)
            if z.alpha > 2:
                st.markdown(alpha_warning)
            # NOTE(review): the explanatory text above says "minimum rank greater
            # than 10", but the warning fires above 5 — confirm the intended threshold.
            if z.xmin > 5:
                st.markdown(xmin_warning)
        except Exception:
            # Best-effort: missing/partial Zipf results should not break the page,
            # but don't swallow KeyboardInterrupt/SystemExit with a bare except.
            st.write("Under construction! 😱 🚧")


### Finally finally finally, show nPMI stuff.
def npmi_widget(npmi_stats, min_vocab, column_id):
    """
    Part of the main app, but uses a user interaction so pulled out as its own f'n.

    Lets the user choose two identity terms and shows the joint nPMI results
    for that pair.

    :param npmi_stats: object exposing ``available_terms`` and
        ``load_or_prepare_joint_npmi(subgroup_pair)``.
    :param min_vocab: occurrence threshold quoted in the description text.
    :param column_id: suffix appended to the expander and widget labels.
    :return: None
    """
    with st.expander(f"Word Association{column_id}: nPMI", expanded=False):
        try:
            if len(npmi_stats.available_terms) == 0:
                # Nothing can be selected, so there is no pair to compare.
                st.markdown(
                    f"No identity terms appear in the dataset at least {min_vocab} times, "
                    "so there are no word associations to show."
                )
                return
            expander_npmi_description(min_vocab)
            st.markdown("-----")
            term1 = st.selectbox(
                f"What is the first term you want to select?{column_id}",
                npmi_stats.available_terms,
            )
            # Reversed so the second box defaults to a different term than the first.
            term2 = st.selectbox(
                f"What is the second term you want to select?{column_id}",
                reversed(npmi_stats.available_terms),
            )
            # We calculate/grab nPMI data based on a canonical (alphabetic)
            # subgroup ordering.
            subgroup_pair = sorted([term1, term2])
            try:
                joint_npmi_df = npmi_stats.load_or_prepare_joint_npmi(subgroup_pair)
            except KeyError:
                st.markdown(
                    "**WARNING!** The nPMI for these terms has not been pre-computed, please re-run caching."
                )
            else:
                npmi_show(joint_npmi_df)
        except Exception:
            # Best-effort UI: show a placeholder rather than crashing the page,
            # without trapping KeyboardInterrupt/SystemExit like a bare except.
            st.write("Under construction! 😱 🚧")


def npmi_show(paired_results):
    """Render the styled nPMI table for a pair of identity terms.

    :param paired_results: DataFrame indexed by word, with an ``"npmi-bias"``
        column plus other columns whose names contain ``"npmi"`` or ``"count"``.
    """
    if paired_results.empty:
        st.markdown("No words that co-occur enough times for results!  Or there's a 🐛.  Or we're still computing this one. 🤷")
        return
    # rename_axis returns a new frame, so the caller's index is never renamed in place.
    s = paired_results.sort_values(by="npmi-bias", ascending=True).rename_axis("word")
    npmi_cols = s.filter(like="npmi").columns
    count_cols = s.filter(like="count").columns
    if s.shape[0] > 10000:
        bias = s["npmi-bias"]
        # The index holds words, so positions must go through .iloc; plain
        # integer keys relied on pandas' deprecated positional fallback.
        bias_thres = max(abs(bias.iloc[5000]), abs(bias.iloc[-5000]))
        print(f"filtering with bias threshold: {bias_thres}")
        s_filtered = s[bias.abs() > bias_thres]
    else:
        s_filtered = s
    # TODO: This is very different look than the duplicates table above. Should probably standardize.
    # sns.palplot() only draws a palette and returns None, which left cmap=None
    # (pandas' default colormap). Build the intended diverging colormap directly.
    cm = sns.diverging_palette(270, 36, s=99, l=48, n=16, as_cmap=True)
    out_df = (
        s_filtered.style.background_gradient(subset=npmi_cols, cmap=cm)
        .format(subset=npmi_cols, formatter="{:,.3f}")
        .format(subset=count_cols, formatter=int)
        .set_properties(
            subset=count_cols, **{"width": "10em", "text-align": "center"}
        )
        .set_properties(**{"align": "center"})
        .set_caption(
            "nPMI scores and co-occurence counts between the selected identity terms and the words they both co-occur with"
        )
    )
    st.write("### Here is your dataset's nPMI results:")
    st.dataframe(out_df)


### Dumping unused functions here for now
### Second, show the distribution of text perplexities
def expander_text_perplexities(text_label_df, sorted_sents_loss, fig_loss):
    """Plot text perplexities and page through the highest-perplexity sentences.

    :param text_label_df: DataFrame whose row count bounds the paging slider.
    :param sorted_sents_loss: sequence of ``(loss, sentence)`` pairs, shown in order.
    :param fig_loss: plotly figure of the perplexity distribution.
    """
    with st.expander("Show text perplexities A", expanded=False):
        st.markdown("### Text perplexities A")
        st.plotly_chart(fig_loss, use_container_width=True)
        # Page 5 sentences at a time; with 5 or fewer rows there is nothing to
        # page through and `rows - 5` would be an invalid slider maximum.
        max_start = text_label_df.shape[0] - 5
        if max_start > 0:
            start_id_show_loss = st.slider(
                "Show highest perplexity sentences in A starting at index:",
                0,
                max_start,
                value=0,
                step=5,
            )
        else:
            start_id_show_loss = 0
        for lss, sent in sorted_sents_loss[start_id_show_loss : start_id_show_loss + 5]:
            st.text(f"{lss:.3f} {sent}")