import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from annotated_text import annotated_text
from plot import Plot
from flow import FlowChart
import os

# Initial slider positions, kept together at the top of the script.
# The first pair seeds the sliders on the "Causes" and "Indicators" tabs;
# the "_sankey" pair seeds the two sliders on the "Sankey" tab.
default_cause_threshold, default_indicator_threshold = 25, 15
default_cause_threshold_sankey, default_indicator_threshold_sankey = 20, 15

# Initialize Plots
# Both helpers come from project modules (`plot` and `flow`). `plot` supplies
# the figures rendered in the chart tabs further down the script.
plot, flow_chart = Plot(), FlowChart()

# Load the trained model and tokenizer
model_directory = "norygano/causalBERT"


@st.cache_resource
def _load_model_and_tokenizer(model_name):
    """Load the tokenizer and token-classification model for *model_name*.

    Streamlit re-executes this script from top to bottom on every widget
    interaction. Wrapping the load in ``st.cache_resource`` means the
    (expensive) ``from_pretrained`` calls run once per server process and the
    same objects are handed back on later reruns.

    Args:
        model_name: Hugging Face model id or local path passed to
            ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple; the model has been switched to
        evaluation mode.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True)
    loaded_model = AutoModelForTokenClassification.from_pretrained(model_name)
    # Inference only: put the model in evaluation mode once, at load time.
    loaded_model.eval()
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_model_and_tokenizer(model_directory)

# Define the label map
# Class index -> BIO tag. A tag's position in the tuple is the class id that
# the argmax over the model's logits yields.
label_map = dict(enumerate((
    "O",
    "B-INDICATOR", "I-INDICATOR",
    "B-CAUSE", "I-CAUSE",
    "B-EFFECT", "I-EFFECT",
)))

# Main application
# Page banner: the word "CAUSE" followed by a "V" rotated by 270 degrees.
# Rendered as raw HTML, hence unsafe_allow_html below.
_banner_html = """
    <div style="display: flex; align-items: center; justify-content: left; font-size: 60px; font-weight: bold;">
        <span>CAUSE</span>
        <span style="transform: rotate(270deg); display: inline-block; margin-left: 5px;">V</span>
    </div>
    """
st.markdown(_banner_html, unsafe_allow_html=True)

# Links to the published weights and dataset, then a one-line description.
_resource_links = "[Weights](https://huggingface.co/norygano/causalBERT) | [Data](https://huggingface.co/datasets/norygano/causenv)"
st.markdown(_resource_links)
st.write("Indicators and causes in explicit attributions of causality.")

# Create tabs
_tab_titles = ["Prompt", "Indicators", "Causes", "Scatter", "Sankey"]
tab1, tab2, tab3, tab4, tab5 = st.tabs(_tab_titles)

# Prompt Tab
def _merge_wordpieces(tokens, labels):
    """Join sub-word tokens back into words and build ``annotated_text`` arguments.

    Tokens that start with ``"##"`` are treated as continuations and appended
    (without the marker) to the word being built. Every other token starts a
    new word, and the word takes the label predicted for that first token.
    ``[CLS]`` and ``[SEP]`` are dropped.

    Args:
        tokens: Token strings for one sentence.
        labels: Predicted label strings, aligned one-to-one with ``tokens``.

    Returns:
        A list ready to be splatted into ``annotated_text``: a plain string
        for a word labelled ``"O"``, a ``(word, label)`` tuple for any other
        label, with a ``" "`` string between consecutive words.
    """
    special_tokens = ("[CLS]", "[SEP]")

    # First pass: collapse sub-word pieces into (word, label) pairs.
    words = []
    current_word, current_label = "", "O"
    for token, label in zip(tokens, labels):
        if token in special_tokens:
            continue
        if token.startswith("##"):
            current_word += token[2:]
            continue
        if current_word:
            words.append((current_word, current_label))
        current_word, current_label = token, label
    if current_word:
        words.append((current_word, current_label))

    # Second pass: turn the pairs into display items separated by spaces.
    annotations = []
    for position, (word, label) in enumerate(words):
        if position:
            annotations.append(" ")  # Add a space between words
        annotations.append(word if label == "O" else (word, label))
    return annotations


with tab1:
    sentences_input = st.text_area("*Sentences (one per line)*", "\n".join([
        "Autos stehen im Verdacht, Waldsterben verursacht zu haben.",
        "Fußball führt zu Waldschäden.",
        "Haustüren tragen zum Betonsterben bei.",
    ]), placeholder="German only (currently)")

    # One sentence per non-blank line, with surrounding whitespace removed.
    sentences = [sentence.strip() for sentence in sentences_input.splitlines() if sentence.strip()]

    if st.button("Analyze"):
        if not sentences:
            # Without this the click would silently produce no output.
            st.warning("Please enter at least one sentence.")

        for sentence in sentences:
            # Each sentence is tokenized and classified on its own.
            inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
            with torch.no_grad():
                logits = model(**inputs).logits

            # Highest-scoring class per token; index 0 selects the only sequence in the batch.
            predicted_label_ids = torch.argmax(logits, dim=2)[0].tolist()
            tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
            predicted_labels = [label_map[label_id] for label_id in predicted_label_ids]

            st.write(f"**Sentence:** {sentence}")
            annotated_text(*_merge_wordpieces(tokens, predicted_labels))
            st.write("---")


# Indicator Tab
with tab2:
    selected_chart_type = st.radio(
        label="Type",
        label_visibility='collapsed',
        options=['Total', 'Year', 'Individual'],
        horizontal=True,
    )
    show_individual = selected_chart_type == 'Individual'

    # Arguments for the chart call; only the 'Individual' view takes a threshold.
    indicator_chart_kwargs = {"chart_type": selected_chart_type.lower()}
    if show_individual:
        # The slider is drawn after the chart, so its current value is read
        # from session state (falling back to the default when it is not set).
        indicator_chart_kwargs["individual_threshold"] = st.session_state.get(
            "individual_threshold", default_indicator_threshold
        )

    # Chart first ...
    with st.container():
        fig = plot.get_indicator_chart(**indicator_chart_kwargs)
        st.plotly_chart(fig, use_container_width=True)

    # ... then, for the 'Individual' view only, the slider underneath it.
    if show_individual:
        with st.container():
            individual_threshold = st.slider(
                "Indicator >=",
                min_value=1,
                max_value=95,
                value=default_indicator_threshold,
                key="individual_threshold",
            )

# Causes Tab
with tab3:
    # Chart and slider share one container, with the slider drawn underneath.
    with st.container():
        # The slider has not been drawn yet at this point, so its value is
        # read from session state (default when it is not set).
        current_cause_minimum = st.session_state.get(
            "cause_threshold_causes", default_cause_threshold
        )
        fig_causes = plot.get_causes_chart(min_value=current_cause_minimum)
        st.plotly_chart(fig_causes, use_container_width=True)

        # The widget key matches the session-state lookup above.
        cause_threshold_causes = st.slider(
            "Cause >=",
            min_value=1,
            max_value=75,
            value=default_cause_threshold,
            key="cause_threshold_causes",
        )

# Scatter Tab
# Renders the figure returned by plot.scatter() at full container width.
# This tab has no controls of its own.
with tab4:
    fig_scatter = plot.scatter()
    st.plotly_chart(fig_scatter, use_container_width=True)

# Sankey Tab
with tab5:
    # Chart container: the sliders are drawn afterwards, so both thresholds
    # are read from session state under the sliders' keys, with the
    # Sankey-specific defaults as fallback.
    with st.container():
        sankey_thresholds = {
            "cause_threshold": st.session_state.get(
                "cause_threshold_sankey", default_cause_threshold_sankey
            ),
            "indicator_threshold": st.session_state.get(
                "indicator_threshold_sankey", default_indicator_threshold_sankey
            ),
        }
        fig_sankey = plot.sankey(**sankey_thresholds)
        st.plotly_chart(fig_sankey, use_container_width=True)

    # Slider container, placed below the chart. The keys are specific to this tab.
    with st.container():
        cause_threshold_sankey = st.slider(
            "Cause >=",
            min_value=1,
            max_value=100,
            value=default_cause_threshold_sankey,
            key="cause_threshold_sankey",
        )
        indicator_threshold_sankey = st.slider(
            "Indicator >=",
            min_value=1,
            max_value=100,
            value=default_indicator_threshold_sankey,
            key="indicator_threshold_sankey",
        )