File size: 3,683 Bytes
04d4fc6
 
 
4be4c1c
04d4fc6
 
 
 
 
 
 
 
 
 
60e75a3
04d4fc6
 
c5edc44
3a965ad
 
 
 
 
 
c2829c2
3a965ad
7932599
04d4fc6
3a965ad
7932599
3a965ad
 
 
7932599
04d4fc6
 
3a965ad
7932599
04d4fc6
 
 
 
 
3a965ad
04d4fc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4be4c1c
 
3a965ad
 
 
04d4fc6
 
 
3a965ad
04d4fc6
3a965ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
04d4fc6
3a965ad
04d4fc6
4be4c1c
7fdb2a4
4be4c1c
04d4fc6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from annotated_text import annotated_text

# Hugging Face Hub id of the fine-tuned token-classification model.
model_directory = "norygano/causalBERT"


@st.cache_resource
def _load_model(name):
    """Load the tokenizer and model for *name*, once per server process.

    Streamlit re-executes this whole script on every widget interaction, so
    without ``st.cache_resource`` the weights would be re-read from disk on
    every click. The cached objects are shared by all sessions; this app only
    uses them for inference (under ``torch.no_grad``).
    """
    tok = AutoTokenizer.from_pretrained(name, add_prefix_space=True)
    mdl = AutoModelForTokenClassification.from_pretrained(name)
    # Evaluation mode: disables dropout and other training-only behaviour.
    mdl.eval()
    return tok, mdl


tokenizer, model = _load_model(model_directory)

# Label id -> BIO tag shown in the UI.
# NOTE(review): must match the label order the model was trained with;
# model.config.id2label may hold the authoritative mapping — confirm.
label_map = {0: "O", 1: "B-INDICATOR", 2: "I-INDICATOR", 3: "B-CAUSE", 4: "I-CAUSE", 5: "B-EFFECT", 6: "I-EFFECT"}

# --- Page header ------------------------------------------------------------
# Title "CAUSEN" followed by a "V" rotated by 270 degrees. It is raw HTML,
# which is why unsafe_allow_html is required.
_TITLE_HTML = """
    <div style="display: flex; align-items: center; justify-content: left; font-size: 60px; font-weight: bold;">
        <span>CAUSEN</span>
        <span style="transform: rotate(270deg); display: inline-block; margin-left: 5px;">V</span>
    </div>
    """
st.markdown(_TITLE_HTML, unsafe_allow_html=True)

# Link to the model card on the Hugging Face Hub.
st.markdown("[Model](https://huggingface.co/norygano/causalBERT)")

# One-line description of what the app does.
st.write("Tags indicators and causes of explicit attributions of causality. GER only (atm)")

# --- Input ------------------------------------------------------------------
# Text area pre-filled with German example sentences. The label is
# italicised via Markdown; the placeholder only appears once the area is
# cleared.
_EXAMPLE_SENTENCES = [
    "Autos stehen im Verdacht, Waldsterben zu verursachen.",
    "Fußball führt zu Waldschäden.",
    "Haustüren tragen zum Betonsterben bei.",
]
sentences_input = st.text_area(
    "*Sentences (one per line)*",
    "\n".join(_EXAMPLE_SENTENCES),
    placeholder="Your Sentences here.",
)

# One sentence per non-blank line, with surrounding whitespace removed.
sentences = []
for line in sentences_input.splitlines():
    stripped = line.strip()
    if stripped:
        sentences.append(stripped)

def _word_predictions(sentence):
    """Run the model on one sentence and return one prediction per word.

    Returns a list of ``(start, end, label)`` tuples in reading order, where
    ``sentence[start:end]`` is the word's original text and ``label`` is the
    tag predicted for the word's first sub-token.

    Sub-tokens are grouped into words with the tokenizer's ``word_ids()``
    alignment rather than by inspecting token strings. This works whatever
    sub-word marker the tokenizer uses, and it skips special tokens such as
    ``[CLS]``/``[SEP]`` (their word id is ``None``). Requires a "fast"
    (Rust-backed) tokenizer, which ``AutoTokenizer`` returns by default when
    one exists.
    """
    inputs = tokenizer(sentence, return_tensors="pt", truncation=True)

    with torch.no_grad():
        logits = model(**inputs).logits
    # argmax over the label axis, then take the single sequence in the batch.
    label_ids = torch.argmax(logits, dim=2)[0].tolist()

    words = []
    previous_word = None
    for token_index, word_index in enumerate(inputs.word_ids()):
        # None = special token; same index as before = continuation sub-token.
        if word_index is None or word_index == previous_word:
            continue
        span = inputs.word_to_chars(word_index)
        words.append((span.start, span.end, label_map[label_ids[token_index]]))
        previous_word = word_index
    return words


def _annotations_for(sentence, words):
    """Build ``annotated_text`` arguments from ``(start, end, label)`` words.

    Words tagged ``"O"`` become plain strings; all others become
    ``(text, label)`` tuples. The original text between consecutive words
    (usually a single space, nothing before punctuation) is copied verbatim,
    so the rendered sentence reads like the input.
    """
    parts = []
    cursor = 0
    for start, end, label in words:
        if start > cursor:
            parts.append(sentence[cursor:start])
        text = sentence[start:end]
        parts.append(text if label == "O" else (text, label))
        cursor = end
    return parts


# Run the model over every entered sentence when the button is pressed.
if st.button("Analyze"):
    for sentence in sentences:
        words = _word_predictions(sentence)

        # Display the raw sentence, then its tagged rendering.
        st.write(f"**Sentence:** {sentence}")
        annotated_text(*_annotations_for(sentence, words))
        st.write("---")