File size: 5,133 Bytes
04d4fc6
 
 
4be4c1c
45d0933
 
9525dec
45d0933
04d4fc6
9525dec
 
6b22889
9525dec
04d4fc6
 
 
 
 
 
 
60e75a3
04d4fc6
45d0933
 
3a965ad
 
 
 
 
 
45d0933
3a965ad
45d0933
6b22889
04d4fc6
45d0933
9525dec
3a965ad
45d0933
 
 
 
 
 
6b22889
04d4fc6
45d0933
04d4fc6
45d0933
 
 
 
 
 
 
 
 
04d4fc6
45d0933
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
04d4fc6
45d0933
 
 
 
 
 
 
 
 
 
 
04d4fc6
45d0933
 
 
04d4fc6
45d0933
9525dec
 
 
 
 
 
 
 
 
 
 
 
 
 
6b22889
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from annotated_text import annotated_text
import pandas as pd
import plotly.express as px
from plot import indicator_chart, causes_chart, scatter, sankey
import os

# Initial values for the Sankey threshold sliders
default_cause_threshold = 20
default_indicator_threshold = 15

# Hugging Face hub id of the token-classification model
model_directory = "norygano/causalBERT"


@st.cache_resource
def _load_model(model_name):
    """Load the tokenizer and token-classification model for *model_name*.

    Streamlit re-executes the whole script on every user interaction, so the
    loader is wrapped in ``st.cache_resource``: the objects are created once
    and reused on later reruns instead of being reloaded each time.

    Returns:
        A ``(tokenizer, model)`` tuple; the model has been switched to eval mode.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True)
    loaded_model = AutoModelForTokenClassification.from_pretrained(model_name)
    loaded_model.eval()
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_model(model_directory)

# Maps the model's predicted class ids to BIO tag names
label_map = {0: "O", 1: "B-INDICATOR", 2: "I-INDICATOR", 3: "B-CAUSE", 4: "I-CAUSE", 5: "B-EFFECT", 6: "I-EFFECT"}

# Page header: the title is raw HTML ("CAUSEN" followed by a rotated "V"),
# then a row of project links and a one-line description.
title_html = """
    <div style="display: flex; align-items: center; justify-content: left; font-size: 60px; font-weight: bold;">
        <span>CAUSEN</span>
        <span style="transform: rotate(270deg); display: inline-block; margin-left: 5px;">V</span>
    </div>
    """
st.markdown(title_html, unsafe_allow_html=True)

links_markdown = (
    "[Model](https://huggingface.co/norygano/causalBERT)"
    " | [Data](https://huggingface.co/datasets/norygano/causenv)"
    " | [Project](https://www.uni-trier.de/universitaet/fachbereiche-faecher/fachbereich-ii/faecher/germanistik/professurenfachteile/germanistische-linguistik/professoren/prof-dr-martin-wengeler/kontroverse-diskurse/individium-gesellschaft)"
)
st.markdown(links_markdown)
st.write("Tags indicators and causes in explicit attributions of causality.")

# One tab per view; the tab objects are filled in further down the script
tab_titles = ["Prompt", "Indicators", "Causes", "Scatter", "Sankey"]
tab1, tab2, tab3, tab4, tab5 = st.tabs(tab_titles)

# Prompt Tab
def _merge_wordpieces(tokens, labels):
    """Join sub-word tokens into words and build the ``annotated_text`` arguments.

    Tokens starting with ``"##"`` are treated as continuations of the previous
    word. ``[CLS]`` and ``[SEP]`` are dropped. A word takes the label of its
    first token; continuation labels are ignored.

    Args:
        tokens: Token strings for one sentence.
        labels: Predicted tag name for each token, aligned with *tokens*.

    Returns:
        A list where words tagged ``"O"`` are plain strings, tagged words are
        ``(word, label)`` tuples, and a ``" "`` string separates adjacent words.
    """
    words = []  # mutable [text, label] pairs, one per reconstructed word
    for token, label in zip(tokens, labels):
        if token in ("[CLS]", "[SEP]"):  # Exclude special tokens
            continue
        if token.startswith("##"):
            if words:
                words[-1][0] += token[2:]
            else:
                # A continuation with no word before it starts an untagged word
                words.append([token[2:], "O"])
        else:
            words.append([token, label])

    annotations = []
    for text, label in words:
        if not text:
            continue
        if annotations:
            annotations.append(" ")  # Add a space between words
        annotations.append(text if label == "O" else (text, label))
    return annotations


with tab1:
    sentences_input = st.text_area("*Sentences (one per line)*", "\n".join([
        "Autos stehen im Verdacht, Waldsterben zu verursachen.",
        "Fußball führt zu Waldschäden.",
        "Haustüren tragen zum Betonsterben bei.",
    ]), placeholder="German only (currently)")

    # One sentence per non-blank line
    sentences = [sentence.strip() for sentence in sentences_input.splitlines() if sentence.strip()]

    if st.button("Analyze"):
        if not sentences:
            # Give feedback instead of silently rendering nothing
            st.warning("Please enter at least one sentence.")

        for sentence in sentences:
            inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
            with torch.no_grad():
                logits = model(**inputs).logits

            # Highest-scoring class per token for the single sentence in the batch
            predicted_label_ids = torch.argmax(logits, dim=2)[0].tolist()
            tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
            predicted_labels = [label_map[label_id] for label_id in predicted_label_ids]

            st.write(f"**Sentence:** {sentence}")
            annotated_text(*_merge_wordpieces(tokens, predicted_labels))
            st.write("---")

# Indicators tab: the indicator chart in two variants, each under its own heading
with tab2:
    for heading, variant in (("Overall", "overall"), ("Individual", "individual")):
        st.subheader(heading)
        st.plotly_chart(indicator_chart(chart_type=variant), use_container_width=True)

# Causes tab
with tab3:
    st.plotly_chart(causes_chart(), use_container_width=True)

# Scatter tab
with tab4:
    st.plotly_chart(scatter(), use_container_width=True)

# Sankey tab: diagram on top, threshold sliders underneath
with tab5:
    # Containers appear on the page in creation order, so reserving the chart
    # slot first keeps it above the sliders even though the sliders are
    # rendered first in code.
    chart_area = st.container()
    slider_area = st.container()

    # Render the sliders first so the chart is built from their return values
    with slider_area:
        cause_threshold = st.slider("Cause >", min_value=1, max_value=100, value=default_cause_threshold, key="cause_threshold")
        indicator_threshold = st.slider("Indicator >", min_value=1, max_value=100, value=default_indicator_threshold, key="indicator_threshold")

    with chart_area:
        fig_sankey = sankey(cause_threshold=cause_threshold, indicator_threshold=indicator_threshold)
        st.plotly_chart(fig_sankey, use_container_width=True)