import os
import re
import time
from pathlib import Path

import requests
import streamlit as st
from spacy import displacy
from streamlit_extras.badges import badge
from streamlit_extras.stylable_container import stylable_container

# RELIK = os.getenv("RELIK", "localhost:8000/api/entities")

import random

from relik.inference.annotator import Relik


def get_random_color(ents):
    """Assign each label in *ents* a pastel colour picked at random.

    The palette comes from ``generate_pastel_colors(len(ents))`` and colours
    are drawn from it without replacement.

    Args:
        ents: A sized iterable of labels; its length sizes the palette.

    Returns:
        A dict mapping each label to a ``#rrggbb`` colour string.
    """
    palette = generate_pastel_colors(len(ents))
    assignment = {}
    for label in ents:
        # Pick a random remaining colour and remove it from the palette.
        index = random.randint(0, len(palette) - 1)
        assignment[label] = palette.pop(index)
    return assignment


def floatrange(start, stop, steps):
    """Return *steps* evenly spaced floats from *start* to *stop*, both inclusive.

    Args:
        start: First value of the range.
        stop: Last value of the range.
        steps: Number of values to produce (an int; passed to ``range``).

    Returns:
        A list of ``steps`` values. When ``steps`` is 1 the result is
        ``[stop]``; when it is 0 or negative the result is empty.
    """
    if int(steps) != 1:
        span = stop - start
        denom = float(steps) - 1
        return [start + float(idx) * span / denom for idx in range(steps)]
    return [stop]


def hsl_to_rgb(h, s, l):
    """Convert an HSL colour to an 8-bit ``(r, g, b)`` tuple.

    Args:
        h: Hue as a fraction of a full turn (0 = red, 1/3 = green,
            2/3 = blue); values outside ``[0, 1]`` are wrapped back into
            that range.
        s: Saturation, expected in ``[0, 1]``. Not validated.
        l: Lightness, expected in ``[0, 1]``. Not validated.

    Returns:
        ``(r, g, b)`` as ints, each channel scaled by 255 and rounded. The
        values lie in ``0..255`` when ``s`` and ``l`` are within range.
    """

    def hue_2_rgb(v1, v2, v_h):
        # Wrap the hue offset into [0, 1] (1.0 itself is kept as-is).
        while v_h < 0.0:
            v_h += 1.0
        while v_h > 1.0:
            v_h -= 1.0
        # Piecewise-linear channel profile around the hue circle.
        if 6 * v_h < 1.0:
            return v1 + (v2 - v1) * 6.0 * v_h
        if 2 * v_h < 1.0:
            return v2
        if 3 * v_h < 2.0:
            return v1 + (v2 - v1) * ((2.0 / 3.0) - v_h) * 6.0
        return v1

    # Zero saturation is a grey: every channel equals the lightness.
    r = g = b = l * 255
    if s != 0.0:
        if l < 0.5:
            var_2 = l * (1.0 + s)
        else:
            var_2 = (l + s) - (s * l)
        var_1 = 2.0 * l - var_2
        # Red and blue sample the hue one third of a turn either side of green.
        r = 255 * hue_2_rgb(var_1, var_2, h + (1.0 / 3.0))
        g = 255 * hue_2_rgb(var_1, var_2, h)
        b = 255 * hue_2_rgb(var_1, var_2, h - (1.0 / 3.0))

    return int(round(r)), int(round(g)), int(round(b))


def generate_pastel_colors(n, *, saturation=1.0, lightness=0.9, start_hue=0.0):
    """Return *n* pastel colours evenly spaced around the hue circle.

    Colours are generated in the HSL colour space
    (see https://en.wikipedia.org/wiki/HSL_and_HSV) and converted with
    ``hsl_to_rgb``.

    Args:
        n: Number of colours to return; ``n <= 0`` yields an empty list.
        saturation: HSL saturation for every colour (default 1.0).
        lightness: HSL lightness for every colour; high values give pastel
            tones (default 0.9).
        start_hue: Hue of the first colour as a fraction of a turn
            (0 = red, 1/3 = green, 2/3 = blue; default 0.0).

    Returns:
        A list of ``#rrggbb`` colour strings.

    Example:
        >>> generate_pastel_colors(5)
        ['#ffcccc', '#f5ffcc', '#ccffe0', '#cce0ff', '#f5ccff']
    """
    if n <= 0:
        return []

    # Sample n + 1 hues over one full turn and drop the last one: hue
    # start_hue + 1 is the same colour as start_hue.
    hues = floatrange(start_hue, start_hue + 1, n + 1)[:-1]
    return ["#%02x%02x%02x" % hsl_to_rgb(hue, saturation, lightness) for hue in hues]


def set_sidebar(css):
    """Render the sidebar: custom CSS, the Sapienza NLP logo and two link lists.

    Args:
        css: Stylesheet text injected into the page inside a ``<style>`` tag.
    """
    font_awesome = "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'>"
    pad = " " * 16

    def anchor(href, icon_class, caption):
        # Each anchor is prefixed with the Font Awesome stylesheet <link>.
        return f"{font_awesome}<a href='{href}'><i class='{icon_class}'></i>&nbsp; {caption}</a>"

    def bullets(*entries):
        # Markdown bullet list: a leading newline, one "- item" line per entry
        # indented by 16 spaces, and a trailing 16-space pad.
        body = "".join(f"{pad}- {anchor(*entry)}\n" for entry in entries)
        return "\n" + body + pad

    relik_links = bullets(
        ("#", "fa-solid fa-file", "Paper"),
        ("https://github.com/SapienzaNLP/relik", "fa-brands fa-github", "GitHub"),
        (
            "https://hub.docker.com/repository/docker/sapienzanlp/relik",
            "fa-brands fa-docker",
            "Docker Hub",
        ),
    )
    lab_links = bullets(
        ("https://nlp.uniroma1.it", "fa-solid fa-globe", "Webpage"),
        ("https://github.com/SapienzaNLP", "fa-brands fa-github", "GitHub"),
        ("https://twitter.com/SapienzaNLP", "fa-brands fa-twitter", "Twitter"),
        (
            "https://www.linkedin.com/company/79434450",
            "fa-brands fa-linkedin",
            "LinkedIn",
        ),
    )

    with st.sidebar:
        st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
        st.image(
            "http://nlp.uniroma1.it/static/website/sapienza-nlp-logo-wh.svg",
            use_column_width=True,
        )
        st.markdown("## ReLiK")
        st.write(relik_links, unsafe_allow_html=True)
        st.markdown("## Sapienza NLP")
        st.write(lab_links, unsafe_allow_html=True)


def get_el_annotations(response):
    """Convert a ReLiK entity-linking result into displaCy manual-mode input.

    Each linked span is labelled with an HTML link to the Wikipedia page of
    the predicted entity title.

    Args:
        response: The object returned by calling the ``Relik`` model. This
            function reads ``response.text`` and, for each item of
            ``response.labels``, its ``start``, ``end`` and ``label``
            attributes.

    Returns:
        ``(dict_of_ents, options)``. ``dict_of_ents`` is
        ``{"text": ..., "ents": [...]}`` with spans sorted by position, for
        ``displacy.render(..., manual=True)``. ``options`` holds the ``ents``
        filter (sorted list of rendered labels) and the ``colors`` mapping.
    """
    from html import escape
    from urllib.parse import quote

    el_link_wrapper = "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'><a href='https://en.wikipedia.org/wiki/{}' style='color: #414141'><i class='fa-brands fa-wikipedia-w fa-xs'></i> <span style='font-size: 1.0em; font-family: monospace'> {}</span></a>"

    def render_label(title):
        # The label is emitted as raw HTML: percent-encode the URL path (titles
        # may contain quotes, '#', '?', '&') and HTML-escape the visible text
        # so such titles cannot break the markup.
        page = quote(title.replace(" ", "_"), safe="/(),:")
        return el_link_wrapper.format(page, escape(title))

    # Spans are sorted by position because displaCy's manual mode walks them
    # in order through the text.
    ents = sorted(
        (
            {"start": span.start, "end": span.end, "label": render_label(span.label)}
            for span in response.labels
        ),
        key=lambda ent: (ent["start"], ent["end"]),
    )
    dict_of_ents = {"text": response.text, "ents": ents}
    # A sorted list (not a set) gives the displaCy filter a deterministic order.
    label_in_text = sorted({ent["label"] for ent in ents})
    options = {"ents": label_in_text, "colors": get_random_color(label_in_text)}
    return dict_of_ents, options


@st.cache_resource()
def load_model():
    """Build the ReLiK retriever + reader pipeline.

    The result is cached by ``st.cache_resource``. Checkpoints are read from
    the directory named by the ``RELIK_MODELS_DIR`` environment variable,
    defaulting to ``/home/user/app/models``. The default paths are unchanged.

    Returns:
        A configured ``Relik`` annotator.
    """
    models_dir = Path(os.getenv("RELIK_MODELS_DIR", "/home/user/app/models"))
    retriever_dir = models_dir / "relik-retriever-small-aida-blink-pretrain-omniencoder"
    return Relik(
        question_encoder=str(retriever_dir / "question_encoder"),
        document_index=str(retriever_dir / "document_index_filtered"),
        reader=str(models_dir / "relik-reader-aida-deberta-small"),
        top_k=100,
        window_size=32,
        window_stride=16,
        candidates_preprocessing_fn="relik.inference.preprocessing.wikipedia_title_and_openings_preprocessing",
    )


def set_intro(css):
    """Render the page header: title, paper subtitle and project badges.

    Args:
        css: Currently unused.
    """
    headings = (
        "# ReLik",
        "### Retrieve, Read and LinK: Fast and Accurate Entity Linking and Relation Extraction on an Academic Budget",
    )
    for heading in headings:
        st.markdown(heading)

    for badge_type, package in (("github", "sapienzanlp/relik"), ("pypi", "relik")):
        badge(type=badge_type, name=package)


def _load_css():
    """Return the contents of ``style.css`` located next to this script."""
    # Explicit UTF-8 so the stylesheet decodes the same way on every platform
    # instead of depending on the locale's default encoding.
    return (Path(__file__).parent / "style.css").read_text(encoding="utf-8")


def _render_entity_linking(relik_model, text):
    """Run *relik_model* on *text* and render the linked entities with displaCy.

    Args:
        relik_model: The annotator returned by ``load_model()``.
        text: Non-empty, already-stripped input text.
    """
    st.markdown("####")
    st.markdown("#### Entity Linking")
    with st.spinner(text="In progress"):
        response = relik_model(text)
        dict_of_ents, options = get_el_annotations(response=response)
        display = displacy.render(
            dict_of_ents, manual=True, style="ent", options=options
        )
        display = display.replace("\n", " ")
        # Heuristic: add nowrap to the inline style containing
        # "border-radius: 0.35em;" (presumably displaCy's per-entity style) so
        # an annotation's decorations are not split across lines.
        display = display.replace(
            "border-radius: 0.35em;", "border-radius: 0.35em; white-space: nowrap;"
        )
        with st.container():
            st.write(display, unsafe_allow_html=True)


def run_client():
    """Build the ReLiK demo page and annotate the submitted text on demand."""
    css = _load_css()

    st.set_page_config(
        page_title="ReLik",
        page_icon="🦮",
        layout="wide",
    )
    set_sidebar(css)
    set_intro(css)

    # text input
    text = st.text_area(
        "Enter Text Below:",
        value="Michael Jordan was one of the best players in the NBA.",
        height=200,
        max_chars=1500,
    )

    with stylable_container(
        key="annotate_button",
        css_styles="""
            button {
                background-color: #802433;
                color: white;
                border-radius: 25px;
            }
            """,
    ):
        submit = st.button("Annotate")

    # load_model() is itself cached by st.cache_resource; the session entry
    # keeps a handle to the same instance for this session.
    if "relik_model" not in st.session_state:
        st.session_state["relik_model"] = load_model()
    relik_model = st.session_state["relik_model"]

    if not submit:
        return
    text = text.strip()
    if not text:
        st.error("Please enter some text.")
        return
    _render_entity_linking(relik_model, text)


# Script entry point: build and render the demo page via run_client().
if __name__ == "__main__":
    run_client()