import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from annotated_text import annotated_text
# Load the trained model and tokenizer
model_directory = "norygano/causalBERT"
tokenizer = AutoTokenizer.from_pretrained(model_directory, add_prefix_space=True)
model = AutoModelForTokenClassification.from_pretrained(model_directory)
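# Both are fetched from the Hugging Face Hub on first run and cached locally thereafter.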
# Set model to evaluation mode
model.eval()
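# (Evaluation mode disables dropout, so repeated runs on the same input give identical predictions.)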
# Define the label map
label_map = {0: "O", 1: "B-INDICATOR", 2: "I-INDICATOR", 3: "B-CAUSE", 4: "I-CAUSE", 5: "B-EFFECT", 6: "I-EFFECT"}
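# Labels follow the BIO scheme: B- marks the first token of a span, I- a continuation token inside it, O a token outside any span.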
# Streamlit App
st.markdown(
"""
<div style="display: flex; align-items: center; justify-content: left; font-size: 60px; font-weight: bold;">
<span>CAUSEN</span>
<span style="transform: rotate(270deg); display: inline-block; margin-left: 5px;">V</span>
</div>
""",
unsafe_allow_html=True
)
st.markdown("[Model](https://huggingface.co/norygano/causalBERT)")
# Short description of what the app does
st.write("Tags indicators and causes of explicit attributions of causality. German only (for now).")
# Text area for input sentences (italic label, one sentence per line)
sentences_input = st.text_area("*Sentences (one per line)*", "\n".join([
    "Autos stehen im Verdacht, Waldsterben zu verursachen.",
    "Fußball führt zu Waldschäden.",
    "Haustüren tragen zum Betonsterben bei.",
]), placeholder="Your Sentences here.")
# Split the input text into individual sentences
sentences = [sentence.strip() for sentence in sentences_input.splitlines() if sentence.strip()]
# Button to run the model
if st.button("Analyze"):
    for sentence in sentences:
        # Tokenize the sentence
        inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)

        # Run inference
        with torch.no_grad():
            outputs = model(**inputs)

        # Get the logits and predicted label IDs
        logits = outputs.logits
        predicted_label_ids = torch.argmax(logits, dim=2)

        # Convert token IDs back to tokens
        tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

        # Map label IDs to human-readable labels
        predicted_labels = [label_map[label_id.item()] for label_id in predicted_label_ids[0]]

        # Reconstruct words from subwords and prepare for annotated_text
        annotations = []
        current_word = ""
        current_label = "O"
        for token, label in zip(tokens, predicted_labels):
            if token in ['[CLS]', '[SEP]']:  # Exclude special tokens
                continue
            if token.startswith("##"):
                # Append subword without "##" prefix to the current word
                current_word += token[2:]
            else:
                # If we have accumulated a word, add it to annotations with a space
                if current_word:
                    if current_label != "O":
                        annotations.append((current_word, current_label))
                    else:
                        annotations.append(current_word)
                    annotations.append(" ")  # Add a space between words
                # Start a new word
                current_word = token
                current_label = label

        # Add the last accumulated word
        if current_word:
            if current_label != "O":
                annotations.append((current_word, current_label))
            else:
                annotations.append(current_word)

        # Display annotated text
        st.write(f"**Sentence:** {sentence}")
        annotated_text(*annotations)
        st.write("---")