import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from annotated_text import annotated_text
from plot import indicator_chart, causes_chart, scatter_plot
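# `annotated_text` is provided by the st-annotated-text package; `plot` is a local
# module in this repository that defines the three chart helpers used in the tabs.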

# Load the trained model and tokenizer
model_directory = "norygano/causalBERT"
tokenizer = AutoTokenizer.from_pretrained(model_directory, add_prefix_space=True)
model = AutoModelForTokenClassification.from_pretrained(model_directory)
model.eval()
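
# Optional sketch (not part of the original app): Streamlit reruns this script on every
# interaction, so the load above is repeated unless the objects are cached. Wrapping it
# in st.cache_resource would keep the tokenizer and model in memory across reruns; the
# function name below is illustrative only.
#
# @st.cache_resource
# def load_causal_bert(directory: str = model_directory):
#     tok = AutoTokenizer.from_pretrained(directory, add_prefix_space=True)
#     mdl = AutoModelForTokenClassification.from_pretrained(directory)
#     mdl.eval()
#     return tok, mdl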

# Define the label map
label_map = {0: "O", 1: "B-INDICATOR", 2: "I-INDICATOR", 3: "B-CAUSE", 4: "I-CAUSE", 5: "B-EFFECT", 6: "I-EFFECT"}
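# BIO tagging scheme: "B-"/"I-" mark the beginning/inside of INDICATOR, CAUSE, and
# EFFECT spans, while "O" marks tokens outside any annotated span.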

# Main application
st.markdown(
    """
    <div style="display: flex; align-items: center; justify-content: left; font-size: 60px; font-weight: bold;">
        <span>CAUSEN</span>
        <span style="transform: rotate(270deg); display: inline-block; margin-left: 5px;">V</span>
    </div>
    """,
    unsafe_allow_html=True
)
st.markdown("[Model](https://huggingface.co/norygano/causalBERT) | [Data](https://huggingface.co/datasets/norygano/causenv) | [Project](https://www.uni-trier.de/universitaet/fachbereiche-faecher/fachbereich-ii/faecher/germanistik/professurenfachteile/germanistische-linguistik/professoren/prof-dr-martin-wengeler/kontroverse-diskurse/individium-gesellschaft)")
st.write("Tags indicators and causes in explicit attributions of causality. GER only (atm)")

# Create tabs
tab1, tab2, tab3, tab4 = st.tabs(["Prompt", "Indicators", "Causes", "Scatter"])

# Prompt Tab
with tab1:
    sentences_input = st.text_area("*Sentences (one per line)*", "\n".join([
        "Autos stehen im Verdacht, Waldsterben zu verursachen.",
        "Fußball führt zu Waldschäden.",
        "Haustüren tragen zum Betonsterben bei.",
    ]), placeholder="Your Sentences here.")

    sentences = [sentence.strip() for sentence in sentences_input.splitlines() if sentence.strip()]

    if st.button("Analyze"):
        for sentence in sentences:
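            # Tokenize the sentence and run a single forward pass; torch.no_grad()
            # disables gradient tracking, which is unnecessary for inference.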
            inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
            with torch.no_grad():
                outputs = model(**inputs)
            logits = outputs.logits
            predicted_label_ids = torch.argmax(logits, dim=2)
            tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
            predicted_labels = [label_map[label_id.item()] for label_id in predicted_label_ids[0]]

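            # Re-assemble WordPiece subtokens ("##" continuations) into whole words,
            # keeping the label predicted for the first subtoken of each word.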
            annotations = []
            current_word = ""
            current_label = "O"
            for token, label in zip(tokens, predicted_labels):
                if token in ['[CLS]', '[SEP]']:  # Exclude special tokens
                    continue
                if token.startswith("##"):
                    current_word += token[2:]
                else:
                    if current_word:
                        if current_label != "O":
                            annotations.append((current_word, current_label))
                        else:
                            annotations.append(current_word)
                        annotations.append(" ")  # Add a space between words
                    current_word = token
                    current_label = label
            if current_word:
                if current_label != "O":
                    annotations.append((current_word, current_label))
                else:
                    annotations.append(current_word)
            st.write(f"**Sentence:** {sentence}")
            annotated_text(*annotations)
            st.write("---")

# Indicators Tab
with tab2:
    st.write("## Indicators")

    # Overall
    st.subheader("Overall")
    fig_overall = indicator_chart(chart_type='overall')
    st.plotly_chart(fig_overall, use_container_width=True)
    
    # Individual Indicators Chart
    st.subheader("Individual")
    fig_individual = indicator_chart(chart_type='individual')
    st.plotly_chart(fig_individual, use_container_width=True)

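# Causes Tab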
with tab3:
    st.write("## Causes")
    fig_causes = causes_chart()
    st.plotly_chart(fig_causes, use_container_width=True)

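# Scatter Tab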
with tab4:
    st.write("## Scatter")
    fig_scatter = scatter_plot()
    st.plotly_chart(fig_scatter, use_container_width=True)
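
# To run locally (script name assumed; Streamlit Spaces typically use app.py):
#   streamlit run app.py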