import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification

# Load the trained model and tokenizer
model_directory = "norygano/causalBERT"
tokenizer = AutoTokenizer.from_pretrained(model_directory, add_prefix_space=True)
model = AutoModelForTokenClassification.from_pretrained(model_directory)

# Set model to evaluation mode
model.eval()

# Define the label map
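# BIO tagging scheme: B-/I- mark the beginning/inside of INDICATOR and CAUSE spans, O is outside any span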
label_map = {0: "O", 1: "B-INDICATOR", 2: "I-INDICATOR", 3: "B-CAUSE", 4: "I-CAUSE"}

# Streamlit App
st.title("Attribution of Causality")
st.write("Tags indicators and causes. GER only (for now)")

# Text input for sentences
sentences_input = st.text_area("Sentences (one per line)", "\n".join([
    "Laub könnte verantwortlich für den Klimawandel sein.",
    "Nach dem Verursachergrundsatz spielt das keine Rolle.",
    #"Backenzähne verursachen Artensterben.",
    "Fußball führt zu Waldschäden.",
    #"Das hängt mit vielen Faktoren zusammen.",
    "Haustüren tragen zum Betonsterben bei.",
    #"Autos stehen im verdacht, Bienensterben auszulösen.",
    #"Lösen Straßen Waldsterben aus?"
]))

# Split the input text into individual sentences
sentences = [sentence.strip() for sentence in sentences_input.splitlines() if sentence.strip()]

# Button to run the model
if st.button("Analyze Sentences"):
    for sentence in sentences:
        # Tokenize the sentence
        inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)

        # Run inference
        with torch.no_grad():
            outputs = model(**inputs)
        
        # Get the logits and predicted label IDs
        logits = outputs.logits
        predicted_label_ids = torch.argmax(logits, dim=2)

        # Convert token IDs back to tokens
        tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

        # Map label IDs to human-readable labels
        predicted_labels = [label_map[label_id.item()] for label_id in predicted_label_ids[0]]

        # Reconstruct words from subwords
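        # WordPiece continuations (tokens starting with "##") are merged into the preceding token;
        # the label of the first subword is kept for the whole reconstructed word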
        reconstructed_tokens = []
        reconstructed_labels = []
        for token, label in zip(tokens, predicted_labels):
            if token in ['[CLS]', '[SEP]']:  # Exclude special tokens
                continue
            if token.startswith("##"):
                reconstructed_tokens[-1] += token[2:]  # Append subword
            else:
                reconstructed_tokens.append(token)
                reconstructed_labels.append(label)

        # Format output with square brackets
        formatted_output = []
        for token, label in zip(reconstructed_tokens, reconstructed_labels):
            if label != "O":
                # Use square brackets around label names
                formatted_output.append(f"[{label}] <b>{token}</b> [/{label}]")
            else:
                formatted_output.append(token)

        # Join tokens for display
        output_sentence = " ".join(formatted_output)

        # Display formatted sentence with Streamlit
        st.write(f"**Original Sentence:** {sentence}")
        st.markdown(f"**Labeled Output:** {output_sentence}", unsafe_allow_html=True)
        st.write("---")