import html
import os
import re
import warnings

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import streamlit as st
from charset_normalizer import detect
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    logging,
    pipeline,
)

# Silence Python warnings and limit transformers' own logging to errors so the
# app output stays clean.
warnings.simplefilter("ignore", Warning)
logging.set_verbosity_error()

# Kept ahead of every other st.* call in this script.
st.set_page_config(layout="wide", page_title="Legal NER", page_icon="⚖️")

# Global stylesheet, injected on every script run. The .tooltip/.tooltiptext
# classes style the hover labels emitted by color_substrings() and
# anonymize_text(), .anonymized styles the anonymization placeholders, and .tip
# styles the hint box shown under the annotated text.
# NOTE(review): .header, .container, .btn-primary, .sec and #language-container
# are not referenced by any HTML emitted in this file; they may be leftovers or
# target markup elsewhere — confirm before removing.
st.markdown(
    """
    <style>
        body {
            font-family: 'Poppins', sans-serif;
            background-color: #f4f4f8;
        }
        .header {
            background-color: rgba(220, 219, 219, 0.25);
            color: #000;
            padding: 5px 0;
            text-align: center;
            border-radius: 7px;
            margin-bottom: 13px;
            border-bottom: 2px solid #333;
        }
        .container {
            background-color: #fff;
            padding: 30px;
            border-radius: 10px;
            box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
            width: 100%;
            max-width: 1000px;
            margin: 0 auto;
            position: absolute;
            top: 50%;
            left: 50%;
            transform: translate(-50%, -50%);
        }
        .btn-primary {
            background-color: #5477d1;
            border: none;
            transition: background-color 0.3s, transform 0.2s;
            border-radius: 25px;
            box-shadow: 0 1px 3px rgba(0, 0, 0, 0.08);
        }
        .btn-primary:hover {
            background-color: #4c6cbe;
            transform: translateY(-1px);
        }
        h2 {
            font-weight: 600;
            font-size: 24px;
            margin-bottom: 20px;
        }
        label {
            font-weight: 500;
        }
        .tip {
            background-color: rgba(180, 47, 109, 0.25);
            padding: 7px;
            border-radius: 7px;
            display: inline-block;
            margin-top: 15px;
            margin-bottom: 15px;
        }
        .sec {
            background-color: rgba(220, 219, 219, 0.10);
            padding: 7px;
            border-radius: 5px;
            display: inline-block;
            margin-top: 15px;
            margin-bottom: 15px;
        }
        .tooltip {
            position: relative;
            display: inline-block;
            cursor: pointer;
        }
        .tooltip .tooltiptext {
            visibility: hidden;
            width: 120px;
            background-color: #6c757d;
            color: #fff;
            text-align: center;
            border-radius: 3px;
            padding: 3px;
            position: absolute;
            z-index: 1;
            bottom: 125%; 
            left: 50%; 
            margin-left: -60px;
            opacity: 0;
            transition: opacity 0.3s;
        }
        .tooltip:hover .tooltiptext {
            visibility: visible;
            opacity: 1;
        }
        .anonymized {
            background-color: #ffcccb;
            color: #000;
            font-weight: bold;
            border-radius: 3px;
            padding: 2px 4px;
        }
        #language-container {
            position: fixed;
            top: 10px;
            right: 10px;
            z-index: 1000;
        }
    </style>
""",
    unsafe_allow_html=True,
)

# User-facing strings for each supported UI language ("EN", "DE"). Both entries
# must define the same keys: the app looks up ui_text[lang][key] for whichever
# language the selector returns.
ui_text = {
    "EN": {
        "title": "Legal NER",
        "upload": "Upload a .txt file",
        "anonymize": "Anonymize",
        "select_entities": "Entity types to anonymize:",
        "download": "Download Anonymized Text",
        "tip": "Tip: Hover over the colored words to see their class.",
        "error": "An error occurred while processing the file: ",
    },
    "DE": {
        "title": "Juristische NER",
        "upload": "Lade eine .txt-Datei hoch",
        "anonymize": "Anonymisieren",
        "select_entities": "Entitätstypen zur Anonymisierung:",
        "download": "Anonymisierten Text herunterladen",
        "tip": "Tipp: Fahre mit der Maus über die farbigen Wörter, um deren Klasse zu sehen.",
        "error": "Beim Verarbeiten der Datei ist ein Fehler aufgetreten: ",
    },
}

# Title on the left, a compact EN/DE switch on the right. The chosen code picks
# the ui_text entry used for every label in the app.
col1, col2 = st.columns([4, 1])
lang = col2.radio(
    "Language:",
    options=["EN", "DE"],
    key="language_selector",
    horizontal=True,
    label_visibility="hidden",
)
col1.title(ui_text[lang]["title"])

# Initialization for the German legal NER model (harshildarji/JuraNER).
# Streamlit re-executes this script on every widget interaction, so without
# caching the tokenizer and model are reloaded on each rerun. st.cache_resource
# keeps one loaded copy per token value. On Streamlit builds without
# st.cache_resource the fallback is a pass-through, i.e. the previous per-run
# loading behaviour.
_cache_resource = getattr(st, "cache_resource", None) or (lambda func: func)


@_cache_resource
def _load_ner_components(auth_token):
    """Load the JuraNER tokenizer, model and token-classification pipeline.

    Args:
        auth_token: Value passed as ``use_auth_token`` to ``from_pretrained``
            (read from the ``tkn`` environment variable; may be ``None``).

    Returns:
        A ``(tokenizer, model, ner_pipeline)`` tuple.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(
        "harshildarji/JuraNER", use_auth_token=auth_token
    )
    loaded_model = AutoModelForTokenClassification.from_pretrained(
        "harshildarji/JuraNER", use_auth_token=auth_token
    )
    return (
        loaded_tokenizer,
        loaded_model,
        pipeline("ner", model=loaded_model, tokenizer=loaded_tokenizer),
    )


tkn = os.getenv("tkn")
tokenizer, model, ner = _load_ner_components(tkn)

# Map each JuraNER tag to the human-readable class name shown in tooltips,
# anonymization placeholders and the entity selector. Insertion order matters:
# ner_labels inherits it, and highlight colours are assigned by position in
# that list.
classes = dict(
    AN="Lawyer",
    EUN="European legal norm",
    GRT="Court",
    GS="Law",
    INN="Institution",
    LD="Country",
    LDS="Landscape",
    LIT="Legal literature",
    MRK="Brand",
    ORG="Organization",
    PER="Person",
    RR="Judge",
    RS="Court decision",
    ST="City",
    STR="Street",
    UN="Company",
    VO="Ordinance",
    VS="Regulation",
    VT="Contract",
)
ner_labels = [*classes]


def generate_colors(num_colors):
    """Build a palette of hex colour strings for entity highlighting.

    Colours are sampled from matplotlib's ``tab20`` colormap at the evenly
    spaced positions 0, 1/n, 2/n, ... where n is ``num_colors``.

    Args:
        num_colors: Number of colours to return (0 yields an empty list).

    Returns:
        A list of ``num_colors`` strings of the form ``"#rrggbb"``.
    """
    colormap = plt.get_cmap("tab20")
    positions = (index / num_colors for index in range(num_colors))
    return list(map(mcolors.rgb2hex, map(colormap, positions)))


def color_substrings(input_string, model_output):
    """Render one line as HTML with every detected entity coloured by label.

    Each entity becomes a ``span.tooltip`` whose hover text is the class name
    from ``classes``. All text taken from ``input_string`` is HTML-escaped
    (``&``, ``<``, ``>``). The line comes from the uploaded file and the result
    is rendered with ``unsafe_allow_html``, so unescaped markup would be
    interpreted by the browser or swallowed as tags.

    Args:
        input_string: One line of the uploaded document.
        model_output: Entity dicts with ``start``/``end`` character offsets
            into ``input_string`` and a ``label`` tag.

    Returns:
        The HTML string for the line.
    """
    colors = generate_colors(len(ner_labels))
    label_to_color = {
        label: colors[i % len(colors)] for i, label in enumerate(ner_labels)
    }

    pieces = []
    last_end = 0
    for entity in sorted(model_output, key=lambda x: x["start"]):
        start, end, label = entity["start"], entity["end"], entity["label"]
        pieces.append(html.escape(input_string[last_end:start], quote=False))
        tooltip = classes.get(label, "")
        entity_text = html.escape(input_string[start:end], quote=False)
        pieces.append(
            f'<span class="tooltip" style="color: {label_to_color.get(label)}; font-weight: bold;">'
            f'{entity_text}<span class="tooltiptext">{tooltip}</span></span>'
        )
        last_end = end

    pieces.append(html.escape(input_string[last_end:], quote=False))
    return "".join(pieces)


def anonymize_text(input_string, model_output, selected_entities=None):
    """Render one line as HTML with the chosen entity types anonymized.

    Consecutive entities that share a label and are separated only by
    whitespace are merged first. For example, a first and last name tagged as
    two ``PER`` spans become a single ``[Person]`` placeholder.

    Args:
        input_string: One line of the uploaded document.
        model_output: Entity dicts with ``start``/``end`` character offsets
            into ``input_string`` and a ``label`` tag. They are copied, never
            modified.
        selected_entities: Labels to replace with a ``[<class name>]``
            placeholder. ``None`` anonymizes every label. Entities whose label
            is not selected are shown coloured with a hover tooltip instead.

    Returns:
        The HTML string for the line. Text from ``input_string`` is emitted
        unescaped.
        NOTE(review): callers may derive plain text by stripping tags from
        this output, so escaping is left out here on purpose — revisit
        together with any such caller.
    """
    merged_model_output = []
    for entity in sorted(model_output, key=lambda x: x["start"]):
        previous = merged_model_output[-1] if merged_model_output else None
        if (
            previous is not None
            and entity["label"] == previous["label"]
            and input_string[previous["end"] : entity["start"]].strip() == ""
        ):
            # max() so a span contained in the previous one cannot shrink it.
            previous["end"] = max(previous["end"], entity["end"])
            previous["word"] = input_string[previous["start"] : previous["end"]]
        else:
            # Copy so the caller's entity dicts are left untouched.
            merged_model_output.append(dict(entity))

    colors = generate_colors(len(ner_labels))
    label_to_color = {
        label: colors[i % len(colors)] for i, label in enumerate(ner_labels)
    }

    pieces = []
    last_end = 0
    for entity in merged_model_output:
        start, end, label = entity["start"], entity["end"], entity["label"]
        pieces.append(input_string[last_end:start])
        if selected_entities is None or label in selected_entities:
            pieces.append(
                f'<span class="anonymized">[{classes.get(label, label)}]</span>'
            )
        else:
            tooltip = classes.get(label, "")
            pieces.append(
                f'<span class="tooltip" style="color: {label_to_color.get(label)}; font-weight: bold;">'
                f'{input_string[start:end]}<span class="tooltiptext">{tooltip}</span></span>'
            )
        last_end = end

    pieces.append(input_string[last_end:])
    return "".join(pieces)


def merge_entities(ner_results):
    """Collapse token-level NER predictions into entity spans.

    Each token dict carries ``entity`` (a tag such as ``B-PER``, ``I-PER`` or a
    bare ``PER``), ``start``/``end`` character offsets and the token text in
    ``word``. A token extends the current span only when all of these hold:

    - it has the same label as the span;
    - it is not a ``B-`` tag;
    - it is either an ``I-`` tag, or a bare tag that starts exactly where the
      span ends.

    Any other token closes the current span and opens a new one.

    The span's ``word`` is rebuilt from its tokens. The ``##`` subword
    continuation prefix is dropped, and a single space is inserted wherever the
    character offsets leave a gap between consecutive tokens.

    Args:
        ner_results: Token dicts in input order (presumably the raw output of
            a transformers ``"ner"`` pipeline without aggregation).

    Returns:
        A list of dicts with ``start``, ``end``, ``label`` and ``word`` keys.
    """
    merged_entities = []
    current_entity = None

    for token in ner_results:
        tag = token["entity"]
        entity_type = tag.split("-")[-1] if "-" in tag else tag
        token_start, token_end = token["start"], token["end"]
        raw_word = token["word"]
        # Only the leading "##" marks a subword continuation.
        token_word = raw_word[2:] if raw_word.startswith("##") else raw_word

        continues_current = (
            current_entity is not None
            and current_entity["label"] == entity_type
            and not tag.startswith("B-")
            and (tag.startswith("I-") or token_start == current_entity["end"])
        )

        if continues_current:
            if token_start > current_entity["end"]:
                # Separate whole-word tokens instead of gluing them together.
                current_entity["word"] += " "
            current_entity["word"] += token_word
            current_entity["end"] = token_end
        else:
            if current_entity is not None:
                merged_entities.append(current_entity)
            current_entity = {
                "start": token_start,
                "end": token_end,
                "label": entity_type,
                "word": token_word,
            }

    if current_entity is not None:
        merged_entities.append(current_entity)
    return merged_entities


def _anonymize_plain(line, entities, selected_entities=None):
    """Return ``line`` as plain text with the selected entity types replaced.

    Uses the same merge rule as ``anonymize_text``: consecutive entities with
    the same label that are separated only by whitespace become one
    placeholder. The text is assembled directly rather than stripped from the
    rendered HTML. Stripping tags from the HTML would keep the tooltip class
    names of non-anonymized entities ("Berlin" + "City" -> "BerlinCity") and
    would delete any literal ``<...>`` in the source text.

    Args:
        line: One line of the uploaded document.
        entities: Entity dicts with ``start``/``end``/``label`` keys for that
            line. They are not modified.
        selected_entities: Labels to replace. ``None`` replaces every label.

    Returns:
        The anonymized line with placeholders such as ``[Person]``.
    """
    merged = []
    for entity in sorted(entities, key=lambda e: e["start"]):
        previous = merged[-1] if merged else None
        if (
            previous is not None
            and entity["label"] == previous["label"]
            and line[previous["end"] : entity["start"]].strip() == ""
        ):
            previous["end"] = max(previous["end"], entity["end"])
        else:
            merged.append(
                {
                    "start": entity["start"],
                    "end": entity["end"],
                    "label": entity["label"],
                }
            )

    pieces = []
    last_end = 0
    for entity in merged:
        start, end, label = entity["start"], entity["end"], entity["label"]
        pieces.append(line[last_end:start])
        if selected_entities is None or label in selected_entities:
            pieces.append(f"[{classes.get(label, label)}]")
        else:
            pieces.append(line[start:end])
        last_end = end
    pieces.append(line[last_end:])
    return "".join(pieces)


uploaded_file = st.file_uploader(ui_text[lang]["upload"], type="txt")

if uploaded_file is not None:
    try:
        # getvalue() returns the whole buffer regardless of the stream
        # position, whereas read() yields b"" once the stream was consumed.
        raw_content = uploaded_file.getvalue()
        encoding = detect(raw_content)["encoding"]
        if encoding is None:
            raise ValueError("Unable to detect file encoding.")

        lines = raw_content.decode(encoding).splitlines()

        # One entity list per line (empty for blank lines) so line_results
        # stays index-aligned with lines.
        line_results = [
            merge_entities(ner(line)) if line.strip() else [] for line in lines
        ]

        anonymize_mode = st.checkbox(ui_text[lang]["anonymize"])

        selected_entities = None
        if anonymize_mode:
            detected_entity_tags = {
                entity["label"]
                for merged_results in line_results
                for entity in merged_results
            }
            # Display name -> tag. Tags missing from `classes` fall back to
            # the raw tag instead of raising KeyError.
            option_to_tag = {classes.get(tag, tag): tag for tag in detected_entity_tags}
            detected_options = sorted(option_to_tag)
            selected_options = st.multiselect(
                ui_text[lang]["select_entities"],
                options=detected_options,
                default=detected_options,
            )
            selected_entities = [option_to_tag[option] for option in selected_options]

        st.markdown(
            "<hr style='margin-top: 10px; margin-bottom: 20px;'>",
            unsafe_allow_html=True,
        )

        anonymized_lines = []
        displayed_lines = []

        for line, merged_results in zip(lines, line_results):
            if not line.strip():
                anonymized_lines.append("")
                continue
            if anonymize_mode:
                # Build the plain text first so it never sees entity dicts
                # that rendering might touch.
                anonymized_lines.append(
                    _anonymize_plain(line, merged_results, selected_entities).strip()
                )
                displayed_lines.append(
                    anonymize_text(
                        line, merged_results, selected_entities=selected_entities
                    )
                )
            else:
                st.markdown(
                    color_substrings(line, merged_results), unsafe_allow_html=True
                )

        if anonymize_mode:
            for displayed_line in displayed_lines:
                st.markdown(displayed_line, unsafe_allow_html=True)
            st.markdown("<hr>", unsafe_allow_html=True)
            st.download_button(
                label=ui_text[lang]["download"],
                data="\n".join(anonymized_lines),
                file_name=f"Anon_{uploaded_file.name}",
                mime="text/plain",
            )
        else:
            st.markdown("<hr>", unsafe_allow_html=True)
            st.markdown(
                f'<div class="tip"><strong>{ui_text[lang]["tip"]}</strong></div>',
                unsafe_allow_html=True,
            )

    except Exception as e:
        # UI boundary: report any processing failure to the user in the app.
        st.error(f"{ui_text[lang]['error']}{e}")