File size: 3,649 Bytes
04d4fc6
 
 
4be4c1c
04d4fc6
 
 
 
 
 
 
 
 
 
 
 
 
3a965ad
 
 
 
 
 
 
 
 
04d4fc6
3a965ad
 
 
 
 
04d4fc6
 
 
3a965ad
 
 
04d4fc6
 
 
 
 
3a965ad
04d4fc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4be4c1c
 
3a965ad
 
 
04d4fc6
 
 
3a965ad
04d4fc6
3a965ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
04d4fc6
3a965ad
04d4fc6
4be4c1c
7fdb2a4
4be4c1c
04d4fc6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from annotated_text import annotated_text

# Hugging Face Hub id of the fine-tuned token-classification model.
model_directory = "norygano/causalBERT"


@st.cache_resource
def _load_model(directory):
    """Load the tokenizer and model for *directory* once per server process.

    Streamlit re-executes this whole script on every widget interaction; without
    caching, the weights would be re-read on each rerun. ``st.cache_resource``
    (Streamlit >= 1.18) keeps one shared instance keyed on *directory*.

    Returns:
        A ``(tokenizer, model)`` pair, with the model already in eval mode.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(directory, add_prefix_space=True)
    loaded_model = AutoModelForTokenClassification.from_pretrained(directory)
    # Inference only: switch off training-time behaviour such as dropout.
    loaded_model.eval()
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_model(model_directory)

# Map from the model's predicted class index to its BIO tag.
label_map = {0: "O", 1: "B-INDICATOR", 2: "I-INDICATOR", 3: "B-CAUSE", 4: "I-CAUSE"}

# Page title: "CAUSEN" followed by a "V" rotated by 270 degrees, rendered as raw HTML.
_TITLE_HTML = """
    <div style="display: flex; align-items: center; justify-content: left; font-size: 60px; font-weight: bold;">
        <span>CAUSEN</span>
        <span style="transform: rotate(270deg); display: inline-block; margin-left: 5px;">V</span>
    </div>
    """

# Example sentences pre-filled into the input box.
_EXAMPLE_SENTENCES = (
    "Laub könnte verantwortlich für den Klimawandel sein.",
    "Fußball führt zu Waldschäden.",
    "Haustüren tragen zum Betonsterben bei.",
)

# Streamlit App -- elements appear on the page in the order they are called.
st.markdown(_TITLE_HTML, unsafe_allow_html=True)
st.write("Tags indicators and causes in explicit attribution of causality. GER only (atm)")

# Multi-line input, one sentence per line. The asterisks in the label are Markdown
# emphasis; the placeholder is only visible once the pre-filled examples are cleared.
sentences_input = st.text_area(
    "*Sentences (one per line)*",
    "\n".join(_EXAMPLE_SENTENCES),
    placeholder="Your Sentences here.",
)
# Link to the model card, shown below the input box.
st.markdown("[Model](https://huggingface.co/norygano/causalBERT)")

# Keep every non-blank input line, with surrounding whitespace removed.
sentences = [line for line in map(str.strip, sentences_input.splitlines()) if line]

def _merge_subwords(tokens, labels, special_mask):
    """Rebuild words from sub-word tokens and build ``annotated_text`` arguments.

    Args:
        tokens: Token strings for one sentence, in order.
        labels: Predicted tag for each token (same length as *tokens*).
        special_mask: 1 where the tokenizer inserted a special token, else 0.

    Returns:
        A list alternating words and ``" "`` separators. A word tagged ``"O"``
        is a plain string; any other word is a ``(word, label)`` tuple. A word
        takes the label of its first sub-token; continuation pieces carry the
        ``"##"`` prefix and are glued onto the preceding word.
    """
    words = []  # [text, label] pairs; lists so continuation pieces can extend text
    for token, label, is_special in zip(tokens, labels, special_mask):
        if is_special:
            continue
        if token.startswith("##") and words:
            words[-1][0] += token[2:]
        else:
            words.append([token, label])

    annotations = []
    for index, (text, label) in enumerate(words):
        if index:
            annotations.append(" ")  # separator between consecutive words
        annotations.append(text if label == "O" else (text, label))
    return annotations


# Button to run the model
if st.button("Analyze"):
    for sentence in sentences:
        # Let the tokenizer report which positions it inserted itself (e.g. [CLS]/[SEP]
        # or <s>/</s>) instead of hard-coding one tokenizer family's spellings.
        inputs = tokenizer(
            sentence,
            return_tensors="pt",
            truncation=True,
            padding=True,
            return_special_tokens_mask=True,
        )
        # forward() does not accept this key, so remove it before inference.
        special_mask = inputs.pop("special_tokens_mask")[0].tolist()

        with torch.no_grad():
            logits = model(**inputs).logits

        # argmax over the last (label) dimension; [0] selects the single sentence.
        predicted_ids = torch.argmax(logits, dim=2)[0].tolist()
        predicted_labels = [label_map[label_id] for label_id in predicted_ids]
        tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())

        # Display annotated text
        st.write(f"**Sentence:** {sentence}")
        annotated_text(*_merge_subwords(tokens, predicted_labels, special_mask))
        st.write("---")