import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
# Load the trained model and tokenizer
model_directory = "norygano/causalBERT"
tokenizer = AutoTokenizer.from_pretrained(model_directory, add_prefix_space=True)
model = AutoModelForTokenClassification.from_pretrained(model_directory)
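# Model weights and tokenizer files are fetched from the Hugging Face Hub on first run and cached locally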
# Set model to evaluation mode
model.eval()
# Define the label map
label_map = {0: "O", 1: "B-INDICATOR", 2: "I-INDICATOR", 3: "B-CAUSE", 4: "I-CAUSE"}
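# BIO scheme: B-/I- prefixes mark the beginning/inside of INDICATOR (causal indicator) and CAUSE spans; O marks tokens outside any span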
# Streamlit App
st.title("Attribution of Causality")
st.write("Tags indicators and causes. GER only (for now)")
# Text input for sentences
sentences_input = st.text_area("Sentences (one per line)", "\n".join([
    "Laub könnte verantwortlich für den Klimawandel sein.",
    "Nach dem Verursachergrundsatz spielt das keine Rolle.",
    #"Backenzähne verursachen Artensterben.",
    "Fußball führt zu Waldschäden.",
    #"Das hängt mit vielen Faktoren zusammen.",
    "Haustüren tragen zum Betonsterben bei.",
    #"Autos stehen im verdacht, Bienensterben auszulösen.",
    #"Lösen Straßen Waldsterben aus?"
]))
# Split the input text into individual sentences
sentences = [sentence.strip() for sentence in sentences_input.splitlines() if sentence.strip()]
# Button to run the model
if st.button("Analyze Sentences"):
    for sentence in sentences:
        # Tokenize the sentence
        inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
        # Run inference
        with torch.no_grad():
            outputs = model(**inputs)
        # Get the logits and predicted label IDs
        logits = outputs.logits
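        # logits have shape (batch_size, sequence_length, num_labels); argmax over the label dimension gives one label ID per token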
        predicted_label_ids = torch.argmax(logits, dim=2)
        # Convert token IDs back to tokens
        tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
        # Map label IDs to human-readable labels
        predicted_labels = [label_map[label_id.item()] for label_id in predicted_label_ids[0]]
        # Reconstruct words from subwords
        reconstructed_tokens = []
        reconstructed_labels = []
        for token, label in zip(tokens, predicted_labels):
            if token in ['[CLS]', '[SEP]']:  # Exclude special tokens
                continue
            if token.startswith("##"):
                reconstructed_tokens[-1] += token[2:]  # Append subword to the previous token
            else:
                reconstructed_tokens.append(token)
                reconstructed_labels.append(label)
        # Format output with square brackets
        formatted_output = []
        for token, label in zip(reconstructed_tokens, reconstructed_labels):
            if label != "O":
                # Use square brackets around label names
                formatted_output.append(f"[{label}] <b>{token}</b> [/{label}]")
            else:
                formatted_output.append(token)
        # Join tokens for display
        output_sentence = " ".join(formatted_output)
        # Display the original and labeled sentence with Streamlit
        st.write(f"**Original Sentence:** {sentence}")
        st.markdown(f"**Labeled Output:** {output_sentence}", unsafe_allow_html=True)
        st.write("---")
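# To try the app locally (assuming Streamlit, PyTorch, and transformers are installed):
#   streamlit run causenv.py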