File size: 6,495 Bytes
04d4fc6
 
 
4be4c1c
adb4a34
 
45d0933
04d4fc6
9525dec
adb4a34
6b22889
7e996cc
 
 
 
 
adb4a34
9525dec
04d4fc6
 
 
 
 
 
 
60e75a3
04d4fc6
45d0933
 
3a965ad
 
2e68043
3a965ad
 
 
45d0933
3a965ad
adb4a34
7e996cc
04d4fc6
45d0933
632a488
3a965ad
45d0933
 
 
7e996cc
45d0933
 
6b22889
04d4fc6
45d0933
04d4fc6
45d0933
 
 
 
 
 
 
 
 
04d4fc6
45d0933
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
04d4fc6
7e996cc
 
45d0933
7e996cc
632a488
 
7e996cc
 
 
 
 
 
 
 
 
 
 
 
 
04d4fc6
7e996cc
 
 
 
 
 
 
 
 
 
 
 
45d0933
7e996cc
 
 
 
 
 
 
 
 
 
04d4fc6
7e996cc
45d0933
7e996cc
9525dec
 
7e996cc
9525dec
 
7e996cc
 
 
 
 
 
9525dec
7e996cc
9525dec
7e996cc
 
 
 
 
adb4a34
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from annotated_text import annotated_text
from plot import Plot
from flow import FlowChart
import os

# Default slider thresholds, grouped by the chart they feed.
# Bar-chart tabs: cause threshold (Causes tab) and indicator threshold
# (Indicators tab, 'Individual' view).
default_cause_threshold, default_indicator_threshold = 25, 15
# Sankey tab keeps its own pair so it can be tuned independently.
default_cause_threshold_sankey, default_indicator_threshold_sankey = 20, 15

# Initialize Plots
plot = Plot()
# NOTE(review): flow_chart is not referenced in the visible part of this
# script; kept so the module-level name stays available.
flow_chart = FlowChart()

# Load the trained model and tokenizer
model_directory = "norygano/causalBERT"


@st.cache_resource
def load_model_and_tokenizer(model_name):
    """Load the token-classification model and its tokenizer once.

    Streamlit re-executes this script on every widget interaction; without
    caching, both ``from_pretrained`` calls would be repeated on each rerun.
    ``st.cache_resource`` keeps the returned objects alive across reruns.

    Args:
        model_name: Hugging Face model id (or local path) to load from.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is switched to eval mode.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True)
    loaded_model = AutoModelForTokenClassification.from_pretrained(model_name)
    loaded_model.eval()  # inference only
    return loaded_tokenizer, loaded_model


tokenizer, model = load_model_and_tokenizer(model_directory)

# Map predicted class ids to BIO tags; the position in the tuple is the id.
label_map = dict(enumerate((
    "O",
    "B-INDICATOR",
    "I-INDICATOR",
    "B-CAUSE",
    "I-CAUSE",
    "B-EFFECT",
    "I-EFFECT",
)))

# Page header: large "CAUSE" title followed by a rotated "V" glyph.
title_html = """
    <div style="display: flex; align-items: center; justify-content: left; font-size: 60px; font-weight: bold;">
        <span>CAUSE</span>
        <span style="transform: rotate(270deg); display: inline-block; margin-left: 5px;">V</span>
    </div>
    """
st.markdown(title_html, unsafe_allow_html=True)

# Links to the model weights and the dataset, rendered as "A | B".
header_links = (
    "[Weights](https://huggingface.co/norygano/causalBERT)",
    "[Data](https://huggingface.co/datasets/norygano/causenv)",
)
st.markdown(" | ".join(header_links))
st.write("Indicators and causes in explicit attributions of causality.")

# One tab per view; the order here fixes which tabN variable is which.
tab_titles = ["Prompt", "Indicators", "Causes", "Scatter", "Sankey"]
tab1, tab2, tab3, tab4, tab5 = st.tabs(tab_titles)

# Prompt Tab
def _merge_wordpieces(tokens, labels, special_tokens=("[CLS]", "[SEP]")):
    """Merge "##" sub-tokens into words and build ``annotated_text`` arguments.

    A token starting with "##" is appended (without the prefix) to the word
    before it. Each word carries the label of its first piece; labels of the
    continuation pieces are ignored.

    Args:
        tokens: Token strings for one sentence.
        labels: One label string per token, aligned with ``tokens``.
        special_tokens: Tokens that are dropped from the output.

    Returns:
        A list alternating words and ``" "`` separators. A word labelled
        "O" is a plain string; any other word is a ``(word, label)`` tuple.
    """
    words = []  # mutable [text, label] pairs, one per reconstructed word
    for token, label in zip(tokens, labels):
        if token in special_tokens:
            continue
        if token.startswith("##"):
            if words:
                words[-1][0] += token[2:]
            else:
                # Continuation piece with nothing before it: start an
                # unlabelled word, as the original state machine did.
                words.append([token[2:], "O"])
        else:
            words.append([token, label])

    annotations = []
    for text, label in words:
        if not text:
            continue  # never emit empty words
        if annotations:
            annotations.append(" ")  # single space between words
        annotations.append(text if label == "O" else (text, label))
    return annotations


with tab1:
    sentences_input = st.text_area(
        "*Sentences (one per line)*",
        "\n".join([
            "Autos stehen im Verdacht, Waldsterben verursacht zu haben.",
            "Fußball führt zu Waldschäden.",
            "Haustüren tragen zum Betonsterben bei.",
        ]),
        placeholder="German only (currently)",
    )

    # One sentence per non-blank line, surrounding whitespace removed.
    stripped_lines = (line.strip() for line in sentences_input.splitlines())
    sentences = [line for line in stripped_lines if line]

    if st.button("Analyze"):
        if not sentences:
            # Without this the button press would appear to do nothing.
            st.warning("Please enter at least one sentence.")

        for sentence in sentences:
            inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
            with torch.no_grad():
                outputs = model(**inputs)
            # Highest-scoring class per token; index 0 selects the single
            # sentence in the batch.
            predicted_label_ids = torch.argmax(outputs.logits, dim=2)
            tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
            predicted_labels = [label_map[label_id.item()] for label_id in predicted_label_ids[0]]

            annotations = _merge_wordpieces(tokens, predicted_labels)
            st.write(f"**Sentence:** {sentence}")
            annotated_text(*annotations)
            st.write("---")


# Indicator Tab
with tab2:
    selected_chart_type = st.radio(
        "Type",
        options=['Total', 'Year', 'Individual'],
        horizontal=True,
        label_visibility='collapsed',
    )
    show_individual = selected_chart_type == 'Individual'

    # Only the 'Individual' view takes a threshold. The slider is drawn
    # below the chart, so its value is read from session state under the
    # slider's key, falling back to the default when the key is absent.
    chart_kwargs = {"chart_type": selected_chart_type.lower()}
    if show_individual:
        individual_threshold = st.session_state.get("individual_threshold", default_indicator_threshold)
        chart_kwargs["individual_threshold"] = individual_threshold

    # Chart container
    with st.container():
        fig = plot.get_indicator_chart(**chart_kwargs)
        st.plotly_chart(fig, use_container_width=True)

    # Slider container, rendered underneath the chart ('Individual' only)
    if show_individual:
        with st.container():
            individual_threshold = st.slider(
                "Indicator >=",
                key="individual_threshold",
                min_value=1,
                max_value=95,
                value=default_indicator_threshold,
            )

# Causes Tab
with tab3:
    # Chart first, slider underneath, both in one container.
    with st.container():
        # The slider is created after the chart, so read its value from
        # session state under the slider's key (default when absent).
        causes_min_value = st.session_state.get("cause_threshold_causes", default_cause_threshold)
        fig_causes = plot.get_causes_chart(min_value=causes_min_value)
        st.plotly_chart(fig_causes, use_container_width=True)

        cause_threshold_causes = st.slider(
            "Cause >=",
            key="cause_threshold_causes",
            min_value=1,
            max_value=75,
            value=default_cause_threshold,
        )

# Scatter Tab
with tab4:
    # No controls on this tab: just render the scatter figure.
    fig_scatter = plot.scatter()
    st.plotly_chart(fig_scatter, use_container_width=True)

# Sankey Tab
with tab5:
    # Chart container: thresholds are read from session state under the
    # Sankey-specific slider keys, with the Sankey defaults as fallback.
    with st.container():
        cause_threshold_sankey = st.session_state.get("cause_threshold_sankey", default_cause_threshold_sankey)
        indicator_threshold_sankey = st.session_state.get("indicator_threshold_sankey", default_indicator_threshold_sankey)
        sankey_kwargs = {
            "cause_threshold": cause_threshold_sankey,
            "indicator_threshold": indicator_threshold_sankey,
        }

        fig_sankey = plot.sankey(**sankey_kwargs)
        st.plotly_chart(fig_sankey, use_container_width=True)

    # Slider container below the chart; keys are unique to this tab.
    with st.container():
        cause_threshold_sankey = st.slider(
            "Cause >=",
            key="cause_threshold_sankey",
            min_value=1,
            max_value=100,
            value=default_cause_threshold_sankey,
        )
        indicator_threshold_sankey = st.slider(
            "Indicator >=",
            key="indicator_threshold_sankey",
            min_value=1,
            max_value=100,
            value=default_indicator_threshold_sankey,
        )