import streamlit as st
from annotated_text import annotated_text
from refined.inference.processor import Refined
import requests
import json
import spacy

# Page config (must be the first Streamlit call in the script)
st.set_page_config(
    page_title="Entity Linking by WordLift",
    menu_items={
        # TODO(review): the help link is empty in this file. Streamlit only
        # accepts an http(s)/mailto URL (or None) for "Get help", so None is
        # used until the real URL is restored.
        'Get Help': None,
        'About': "# This is a demo app for NEL/NED/NER and SEO"
    }
)

# Sidebar
# A list (not a set): iteration order of a set of strings can differ between
# interpreter runs (hash randomization), which made the index=0 default
# language unpredictable.
language_options = ["English", "English - spaCy", "German"]
selected_language = st.sidebar.selectbox("Select the Language", language_options, index=0)

# Based on selected language, configure model, entity set, and citation options
if selected_language in ("German", "English - spaCy"):
    # The spaCy pipelines take no ReFinED model / entity-set options.
    selected_model_name = None
    selected_entity_set = None

    # NOTE(review): the BibTeX entry header (e.g. "@misc{...,") and closing
    # brace are not present here; confirm against the upstream citation.
    entity_fishing_citation = """
    title = {entity-fishing},
    publisher = {GitHub},
    year = {2016--2023},
    archivePrefix = {swh},
    eprint = {1:dir:cb0ba3379413db12b0018b7c3af8d0d2d864139c}
    """

    with st.sidebar.expander('Citations'):
        st.code(entity_fishing_citation, language=None)
else:
    # ReFinED: let the user choose the pretrained model and the entity set.
    model_options = ["aida_model", "wikipedia_model_with_numbers"]
    entity_set_options = ["wikidata", "wikipedia"]
    selected_model_name = st.sidebar.selectbox("Select the Model", model_options)
    selected_entity_set = st.sidebar.selectbox("Select the Entity Set", entity_set_options)

    refined_citation = """
    title = "{R}e{F}in{ED}: An Efficient Zero-shot-capable Approach to End-to-End Entity Linking",
    author = "Tom Ayoola, Shubhi Tyagi, Joseph Fisher, Christos Christodoulopoulos, Andrea Pierleoni",
    booktitle = "NAACL",
    year = "2022"
    """

    with st.sidebar.expander('Citations'):
        st.code(refined_citation, language=None)

@st.cache_resource  # load each model once per argument combination, reuse across reruns
def load_model(selected_language, model_name=None, entity_set=None):
    """Return the NLP model matching the sidebar selection.

    Args:
        selected_language: "German" or "English - spaCy" select a spaCy
            pipeline; any other value (the sidebar's "English") selects ReFinED.
        model_name: ReFinED pretrained model name; ignored by the spaCy options.
        entity_set: ReFinED entity set; ignored by the spaCy options.

    Returns:
        A spaCy pipeline with the ``entityfishing`` component attached for the
        spaCy options, otherwise a ReFinED model.
    """
    if selected_language == "German":
        nlp_model = spacy.load("de_core_news_lg")
        language_code = "de"
    elif selected_language == "English - spaCy":
        nlp_model = spacy.load("en_core_web_sm")
        language_code = "en"
    else:
        # Every other selection uses the pretrained ReFinED linker. (Previously
        # this code sat after an unconditional return and was unreachable.)
        return Refined.from_pretrained(model_name=model_name, entity_set=entity_set)

    # The app reads ent._.kb_qid / ent._.url_wikidata, which come from the
    # spacyfishing "entityfishing" component, not from the stock pipelines.
    # The guard avoids a duplicate-component error if it is already present.
    if "entityfishing" not in nlp_model.pipe_names:
        nlp_model.add_pipe("entityfishing", config={"language": language_code})
    return nlp_model

# Resolve the model for the current sidebar selection; st.cache_resource on
# load_model means repeated reruns with the same selection reuse it.
model = load_model(
    selected_language,
    model_name=selected_model_name,
    entity_set=selected_entity_set,
)

# Helper functions
def get_wikidata_id(entity_string):
    """Pull the value after the first '=' out of an entity fragment.

    Given e.g. ``"Entity(wikidata_entity_id=Q21"``, the id is ``"Q21"``. With
    several '=' signs, the text between the first and second one is used; a
    string without '=' raises IndexError.

    Returns:
        dict with "id" (the extracted value) and "link". The link prefix is
        currently an empty string, so "link" equals "id".
    """
    pieces = entity_string.split("=")
    qid = str(pieces[1])
    return dict(id=qid, link="" + qid)
def get_entity_data(entity_link, timeout=10):
    """Fetch the JSON description of an entity; return None if anything fails.

    Before the request, every "http://" in the link is rewritten to "http/".
    NOTE(review): presumably so the URI can be embedded as a path segment of a
    lookup endpoint whose base URL is not present here — confirm.

    Args:
        entity_link: Link of the entity to look up.
        timeout: Seconds to wait for the server. Without a timeout, requests
            waits indefinitely and a stalled server would hang the app.

    Returns:
        The decoded JSON body of a successful (2xx) response, otherwise None.
    """
    try:
        formatted_link = entity_link.replace("http://", "http/")
        response = requests.get(formatted_link, timeout=timeout)
        # Treat HTTP error statuses as failures rather than returning the
        # server's error payload as if it were entity data.
        response.raise_for_status()
        return response.json()
    except Exception as e:
        # Deliberate best-effort: one failing lookup must not break the page.
        print(f"Exception when fetching data for entity: {entity_link}. Exception: {e}")
        return None
# Create the form (object notation): widgets added to a Streamlit form send
# their values to the script together when the submit button is pressed.
form = st.form(key='my_form')
text_input = form.text_area(label='Enter a sentence')
submit_button = form.form_submit_button(label='Analyze')

# Highlight colour per entity "@type"; anything else gets _DEFAULT_ENTITY_COLOR.
_DEFAULT_ENTITY_COLOR = "#8ef"
_ENTITY_TYPE_COLORS = {
    "Place": "#8AC7DB",
    "Organization": "#ADD8E6",
    "Person": "#67B7D1",
    "Product": "#2ea3f2",
    "CreativeWork": "#00BFFF",
    "Event": "#1E90FF",
}


def _entity_color(entity_type):
    """Return the highlight colour for an entity type (non-strings get the default)."""
    # The "@type" value comes from remote JSON and may be a list; only plain
    # strings are looked up (a list key would be unhashable).
    if not isinstance(entity_type, str):
        return _DEFAULT_ENTITY_COLOR
    return _ENTITY_TYPE_COLORS.get(entity_type, _DEFAULT_ENTITY_COLOR)


def _first_free_occurrence(text, needle, claimed):
    """Return the first index of *needle* in *text* not overlapping a claimed (start, end) range, else -1."""
    start = text.find(needle)
    while start != -1:
        end = start + len(needle)
        if all(end <= s or start >= e for s, e in claimed):
            return start
        start = text.find(needle, start + 1)
    return -1


def _build_annotated_segments(text, annotations):
    """Interleave plain-text slices of *text* with annotation tuples.

    *annotations* holds non-overlapping (start, end, annotation) triples; the
    result is suitable for ``annotated_text(*segments)``.
    """
    segments = []
    cursor = 0
    for start, end, annotation in sorted(annotations, key=lambda item: item[0]):
        if start > cursor:
            segments.append(text[cursor:start])
        segments.append(annotation)
        cursor = end
    if cursor < len(text):
        segments.append(text[cursor:])
    return segments


# Initialization
entities_map = {}
entities_data = {}

# Nothing is rendered unless the form was submitted, so only then run the model
# and the per-entity HTTP lookups (other widget reruns skip that work).
if text_input and submit_button:
    if selected_language in ["German", "English - spaCy"]:
        # spaCy + entity-fishing: each entity span exposes its Wikidata QID/URL.
        doc = model(text_input)
        for ent in doc.ents:
            wikidata_id, wikidata_url = ent._.kb_qid, ent._.url_wikidata
            if not wikidata_url:
                continue
            # NOTE(review): both replace() arguments are empty, so this is
            # currently a no-op; confirm which URL rewrite was intended.
            formatted_wikidata_url = wikidata_url.replace("", "")
            entities_map[ent.text] = {"id": wikidata_id, "link": formatted_wikidata_url}
            entity_data = get_entity_data(formatted_wikidata_url)
            if entity_data is not None:
                entities_data[ent.text] = entity_data
    else:
        # ReFinED: parse each entity's string form into [mention, "...=QID", ...].
        for entity in model.process_text(text_input):
            fields = str(entity).strip('][').replace("\'", "").split(', ')
            if len(fields) >= 2 and "wikidata" in fields[1]:
                mention = fields[0].strip()
                entities_map[mention] = get_wikidata_id(fields[1])
                entity_data = get_entity_data(entities_map[mention]["link"])
                if entity_data is not None:
                    entities_data[mention] = entity_data

    # mention -> [id/link info, fetched entity data or None]
    combined_entity_info_dictionary = {
        key: [entities_map[key], entities_data.get(key)] for key in entities_map
    }

    # JSON-LD data ("WebPage"/"mentions" are schema.org terms)
    json_ld_data = {
        "@context": "https://schema.org",
        "@type": "WebPage",
        "mentions": [],
    }

    annotations = []  # (start, end, (text, label, color)) in the original text
    claimed = []      # character ranges already annotated
    for entity_string, entity_info in entities_map.items():
        # Skip entities without a usable Wikidata link.
        if entity_info["link"] is None or entity_info["link"] == "None":
            continue
        entity_data = entities_data.get(entity_string)
        entity_type = entity_data.get("@type") if isinstance(entity_data, dict) else None
        entity_annotation = (entity_string, entity_info["id"], _entity_color(entity_type))

        # Annotate the first not-yet-annotated occurrence of the mention. Working on
        # spans of the untouched input avoids matching inside earlier annotations
        # and is safe for user text containing "{" or "}".
        if entity_string:
            start = _first_free_occurrence(text_input, entity_string, claimed)
            if start != -1:
                end = start + len(entity_string)
                claimed.append((start, end))
                annotations.append((start, end, entity_annotation))

        # Add the entity to JSON-LD data
        entity_json_ld = combined_entity_info_dictionary[entity_string][1]
        if isinstance(entity_json_ld, dict) and entity_json_ld and entity_json_ld.get("link") != "None":
            json_ld_data["mentions"].append(entity_json_ld)

    # Plain strings and (text, label, color) tuples, in text order
    final_text = _build_annotated_segments(text_input, annotations)

    with st.expander("See annotations"):
        annotated_text(*final_text)

    with st.expander("Here is the final JSON-LD"):
        st.json(json_ld_data)  # Output JSON-LD