# coding=utf-8
# Copyright 2023 The GlotLID Authors.
# Lint as: python3


# This Space is based on the AMR-KELEG/ALDi Space.
# GlotLID Space

import string
import constants
import pandas as pd
import streamlit as st
from huggingface_hub import hf_hub_download
from GlotScript import get_script_predictor
import matplotlib.pyplot as plt
import fasttext
import altair as alt
from altair import X, Y, Scale
import base64
import json
import os
import re

@st.cache_resource
def load_sp():
    """Create the GlotScript script predictor.

    Decorated with `st.cache_resource`, so the predictor object is built once
    and reused by later reruns instead of being recreated each time.
    """
    return get_script_predictor()


sp = load_sp()

def get_script(text):
    """Detect the writing systems used in `text` with the GlotScript predictor.

    Args:
        text: The text to analyse.

    Returns:
        A tuple `(main_script, all_scripts)`:
        - main_script: the predictor's main script code, or 'Zyyy' when it
          reports none.
        - all_scripts: a list of every script code listed in the predictor's
          details, or ['Zyyy'] when there are none.
    """
    res = sp(text)
    # res[0] is used as the main script; res[2]['details'] is a mapping whose
    # keys are the detected script codes (only the keys are used here).
    main_script = res[0] or 'Zyyy'
    details = res[2]['details']
    # Always return a list. Callers test `script in all_scripts`, which on a
    # bare string would be a substring check instead of a membership test.
    all_scripts = list(details) if details else ['Zyyy']
    return main_script, all_scripts


def preprocess_text(text):
    """Normalize text before it is passed to the language identifier.

    Steps:
    1. Newlines, ASCII punctuation and ASCII digits become spaces, because they
       carry no language signal.
    2. Runs of whitespace are collapsed to a single space.
    3. The result is stripped.

    Args:
        text: The text to be preprocessed.

    Returns:
        The preprocessed text.
    """
    noisy_chars = string.punctuation + string.digits
    to_space = str.maketrans(noisy_chars, " " * len(noisy_chars))
    flattened = text.replace('\n', ' ').translate(to_space)
    return re.sub(r'\s+', ' ', flattened).strip()


@st.cache_data
def language_names(json_path):
    """Load the JSON file that maps language codes to language names.

    Args:
        json_path: Path to a UTF-8 encoded JSON file.

    Returns:
        The parsed JSON object. It is used below as a dict indexed by the
        language part of a label (presumably ISO 639-3 codes).
    """
    # JSON text is UTF-8. Be explicit instead of relying on the platform's
    # default encoding, which can fail on non-ASCII language names.
    with open(json_path, 'r', encoding='utf-8') as json_file:
        return json.load(json_file)


label2name = language_names("assets/language_names.json")

def get_name(label):
    """Return the display name of the language encoded in a GlotLID label.

    The part of `label` before the first underscore (e.g. "eng" in "eng_Latn")
    is looked up in `label2name`.

    Codes missing from that mapping fall back to the code itself instead of
    raising KeyError. One example is the "und" placeholder that `compute`
    produces when the script check fails, if the names file does not list it.
    """
    iso_3 = label.split('_')[0]
    return label2name.get(iso_3, iso_3)


@st.cache_data
def render_svg(svg):
    """Render an SVG string as a centered image at 40% width.

    The SVG is embedded as a base64 data URI inside raw HTML, which is why the
    write needs `unsafe_allow_html=True`.

    Args:
        svg: The SVG document as a string.
    """
    b64 = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
    # HTML attributes are separated by whitespace only; a comma between them
    # is a parse error.
    html = f'<p align="center"> <img src="data:image/svg+xml;base64,{b64}" width="40%"/> </p>'
    c = st.container()
    c.write(html, unsafe_allow_html=True)


@st.cache_data
def render_metadata():
    """Show the centered row of project badges.

    The badges link to the model, source code, license, stars and paper.
    """
    badges_html = r"""<p align="center">
        <a href="https://huggingface.co/cis-lmu/glotlid"><img alt="HuggingFace Model" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-8A2BE2"></a>
        <a href="https://github.com/cisnlp/GlotLID"><img alt="GitHub" src="https://img.shields.io/badge/%F0%9F%93%A6%20GitHub-orange"></a>
        <a href="https://github.com/cisnlp/GlotLID/blob/main/LICENSE"><img alt="GitHub license" src="https://img.shields.io/github/license/cisnlp/GlotLID?logoColor=blue"></a>
        <a href="https://github.com/cisnlp/GlotLID"><img alt="GitHub stars" src="https://img.shields.io/github/stars/cisnlp/GlotLID"></a>
        <a href="https://arxiv.org/abs/2310.16248"><img alt="arXiv" src="https://img.shields.io/badge/arXiv-2310.16248-b31b1b.svg"></a>
        </p>"""
    st.container().write(badges_html, unsafe_allow_html=True)

@st.cache_data
def citation():
    """Render the BibTeX entry for citing GlotLID as a code block."""
    import textwrap

    # Raw string: BibTeX accent commands such as {\c{c}} and {\"u} must keep
    # their backslashes. In a normal string, "\"" collapses to '"' and "\c"
    # is an invalid escape.
    bibtex = textwrap.dedent(r"""
        @inproceedings{
          kargaran2023glotlid,
          title={GlotLID: Language Identification for Low-Resource Languages},
          author={Kargaran, Amir Hossein and Imani, Ayyoob and Yvon, Fran{\c{c}}ois and Sch{\"u}tze, Hinrich},
          booktitle={The 2023 Conference on Empirical Methods in Natural Language Processing},
          year={2023},
          url={https://openreview.net/forum?id=dl4e3EBz5j}
        }""").strip()
    st.code(bibtex, language="python", line_numbers=False)


@st.cache_data
def convert_df(df):
    """Serialize `df` to UTF-8 encoded CSV bytes, without the index column.

    Cached so the conversion is not recomputed on every Streamlit rerun.
    """
    # `index` is documented as a bool; pass False rather than None.
    return df.to_csv(index=False).encode("utf-8")


def _download_fasttext_model(model_name, file_name):
    """Download `file_name` from the Hugging Face repo `model_name` and load it with fastText."""
    model_path = hf_hub_download(repo_id=model_name, filename=file_name)
    return fasttext.load_model(model_path)


@st.cache_resource
def load_GlotLID_v1(model_name, file_name):
    """Load a GlotLID fastText model (used below for "model_v1.bin").

    Cached with `st.cache_resource`, so each model is downloaded and loaded
    only once.
    """
    return _download_fasttext_model(model_name, file_name)


@st.cache_resource
def load_GlotLID_v2(model_name, file_name):
    """Load a GlotLID fastText model (used below for "model_v2.bin").

    Cached with `st.cache_resource`, so each model is downloaded and loaded
    only once.
    """
    return _download_fasttext_model(model_name, file_name)


model_1 = load_GlotLID_v1(constants.MODEL_NAME, "model_v1.bin")
model_2 = load_GlotLID_v2(constants.MODEL_NAME, "model_v2.bin")

def plot(label, prob):
    """Show a horizontal confidence bar for a single prediction.

    Args:
        label: Predicted label (e.g. "eng_Latn"). Its language part is shown
            together with the name returned by `get_name`.
        prob: Confidence value. The x-axis is fixed to [0, 1].
    """
    ORANGE_COLOR = "#FF8000"
    BLACK_COLOR = "#31333F"
    fig, ax = plt.subplots(figsize=(8, 1))
    # Transparent backgrounds so the chart blends into the Streamlit theme.
    fig.patch.set_facecolor("none")
    ax.set_facecolor("none")

    ax.spines["left"].set_color(BLACK_COLOR)
    ax.spines["bottom"].set_color(BLACK_COLOR)
    ax.tick_params(axis="x", colors=BLACK_COLOR)

    ax.spines[["right", "top"]].set_visible(False)

    ax.barh(y=[0], width=[prob], color=ORANGE_COLOR)
    ax.set_xlim(0, 1)
    ax.set_ylim(-1, 1)
    ax.set_title(f"Label: {label}, Language: {get_name(label)}", color=BLACK_COLOR)
    ax.get_yaxis().set_visible(False)
    ax.set_xlabel("Confidence", color=BLACK_COLOR)
    st.pyplot(fig)
    # pyplot keeps every figure alive until it is closed. Without this, each
    # rerun would leak another figure.
    plt.close(fig)

def compute(sentences, version = 'v2'):
    """Compute language labels and confidences for the given sentences.

    Each sentence is normalized with `preprocess_text` and classified by the
    selected model.

    With version 'v2', a script check is applied to every prediction whose
    language part is not 'zxx'. If the predicted script is not among the
    scripts `get_script` finds in the sentence, the prediction is replaced by
    `und_<main script>` with confidence 0.

    Args:
        sentences: A list of sentences.
        version: 'v2' selects `model_2`; any other value selects `model_1`.

    Returns:
        A tuple `(probs, labels)` of two lists aligned with `sentences`.
    """
    progress_text = "Computing Language..."
    model_choice = model_2 if version == 'v2' else model_1
    my_bar = st.progress(0, text=progress_text)

    probs = []
    labels = []

    sentences = [preprocess_text(sent) for sent in sentences]
    total = len(sentences)

    for done, sent in enumerate(sentences, start=1):
        output = model_choice.predict(sent)

        # The returned label is presumably of the form "__label__<lang>_<script>".
        output_label = output[0][0].split('__')[-1]
        # Clamp the reported confidence into [0, 1].
        output_prob = max(min(output[1][0], 1), 0)
        output_label_language = output_label.split('_')[0]

        # Script control: reject predictions whose script does not occur in
        # the sentence.
        if version == 'v2' and output_label_language != 'zxx':
            main_script, all_scripts = get_script(sent)
            output_label_script = output_label.split('_')[1]

            if output_label_script not in all_scripts:
                output_label = f"und_{main_script}"
                output_prob = 0

        labels.append(output_label)
        probs.append(output_prob)

        # `done` counts finished sentences, so the bar reaches 1.0 after the
        # last one.
        my_bar.progress(done / total, text=progress_text)

    my_bar.empty()
    return probs, labels

st.markdown("[![Duplicate Space](https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14)](https://huggingface.co/spaces/cis-lmu/glotlid-space?duplicate=true)")

# Use a context manager so the logo's file handle is closed right away.
with open("assets/glotlid_logo.svg", encoding="utf-8") as logo_file:
    render_svg(logo_file.read())

render_metadata()

st.markdown("**GlotLID** is an open-source language identification model with support for more than **1600 languages**.")


tab1, tab2 = st.tabs(["Input a Sentence", "Upload a File"])

with tab1:

    # choice = st.radio(
    #     "Set granularity level",
    #     ["default", "merge", "individual"],
    #     captions=["enable both macrolanguage and its varieties (default)", "merge macrolanguage and its varieties into one label", "remove macrolanguages - only shows individual langauges"],
    # )

    version = st.radio(
        "Choose model",
        ["v1", "v2"],
        captions=["GlotLID version 1", "GlotLID version 2 (more data and languages)"],
        index = 1,
        key = 'version_tab1',
        horizontal = True
    )

    sent = st.text_input(
        "Sentence:", placeholder="Enter a sentence.", on_change=None
    )

    # The button's return value is not used: prediction is driven by `sent`.
    # The button only gives users an explicit submit control.
    st.button("Submit")

    if sent:
        probs, labels = compute([sent], version=version)
        prob = probs[0]
        label = labels[0]

        log_line = f"{sent}, {label}: {prob}"
        print(log_line)
        # Append mode creates logs.txt if it does not exist. Write UTF-8
        # explicitly because user input may be in any script.
        with open("logs.txt", "a", encoding="utf-8") as f:
            f.write(log_line + "\n")

        plot(label, prob)

with tab2:

    version = st.radio(
        "Choose model",
        ["v1", "v2"],
        captions=["GlotLID version 1", "GlotLID version 2 (more data and languages)"],
        index = 1,
        key = 'version_tab2',
        horizontal = True
    )

    file = st.file_uploader("Upload a file", type=["txt"])
    if file is not None:
        # One sentence per line. Read the lines directly instead of using a CSV
        # parser, which would:
        #   - treat '"' as a quote character,
        #   - convert an all-digit column to numbers,
        #   - raise on an empty file.
        # "utf-8-sig" also drops a leading BOM; blank lines are skipped.
        text = file.getvalue().decode("utf-8-sig")
        sentences = [line for line in re.split(r"\r\n?|\n", text) if line.strip()]
        df = pd.DataFrame({"Sentence": sentences})

        df['Prob'], df["Label"] = compute(df["Sentence"].tolist(), version=version)
        df['Language'] = df["Label"].apply(get_name)

        # A horizontal rule
        st.markdown("""---""")

        chart = (
            alt.Chart(df.reset_index())
            .mark_area(color="darkorange", opacity=0.5)
            .encode(
                x=X(field="index", title="Sentence Index"),
                y=Y("Prob", scale=Scale(domain=[0, 1])),
            )
        )
        st.altair_chart(chart.interactive(), use_container_width=True)

        col1, col2 = st.columns([4, 1])

        with col1:
            # Display the output
            st.table(
                df,
            )

        with col2:
            # Add a download button
            csv = convert_df(df)
            st.download_button(
                label=":file_folder: Download predictions as CSV",
                data=csv,
                file_name="GlotLID.csv",
                mime="text/csv",
            )



# citation()