|
import html

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
|
|
|
|
|
# Hugging Face identifier of the token-classification checkpoint to load.
model_directory = "norygano/causalBERT"


@st.cache_resource
def _load_tokenizer_and_model(model_name):
    """Load the tokenizer and model once and reuse them across script reruns.

    Streamlit re-executes this script on every widget interaction; without
    caching, the checkpoint would be reloaded on each button click.

    Args:
        model_name: Checkpoint name or path passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple, with the model in evaluation mode.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True)
    loaded_model = AutoModelForTokenClassification.from_pretrained(model_name)
    # Inference only: switch the model to evaluation mode.
    loaded_model.eval()
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_tokenizer_and_model(model_directory)

# Maps predicted class indices to the tag names shown in the output.
label_map = {0: "O", 1: "B-INDICATOR", 2: "I-INDICATOR", 3: "B-CAUSE", 4: "I-CAUSE"}
|
|
|
|
|
# Page header and a short description of what the app does.
st.title("Attribution of Causality")
st.write("Tags indicators and causes. GER only (for now)")

# Example sentences pre-filled in the input box, one per line.
_DEFAULT_SENTENCES = (
    "Laub könnte verantwortlich für den Klimawandel sein.",
    "Nach dem Verursachergrundsatz spielt das keine Rolle.",
    "Fußball führt zu Waldschäden.",
    "Haustüren tragen zum Betonsterben bei.",
)

sentences_input = st.text_area(
    "Sentences (one per line)",
    "\n".join(_DEFAULT_SENTENCES),
)

# Trim every input line and keep only the ones that still have content.
_stripped_lines = (raw_line.strip() for raw_line in sentences_input.splitlines())
sentences = [line for line in _stripped_lines if line]
|
|
|
|
|
def _predict_labels(sentence):
    """Run the model on one sentence and return its tokens and predicted tags.

    Args:
        sentence: The raw sentence text to tag.

    Returns:
        A ``(tokens, labels)`` pair of equal-length lists: the tokenizer's
        token strings (special tokens included) and the tag name from
        ``label_map`` predicted for each token.

    Raises:
        KeyError: If the model predicts a class index missing from ``label_map``.
    """
    inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)

    # No gradients are needed for inference.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Only one sentence was encoded, so take the first row of the batch.
    label_ids = torch.argmax(logits, dim=2)[0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    return tokens, [label_map[label_id] for label_id in label_ids]


def _merge_wordpieces(tokens, labels):
    """Join "##" continuation pieces back into whole words.

    The "[CLS]" and "[SEP]" markers are dropped, and each merged word keeps
    the label predicted for its first piece.

    Returns:
        A ``(words, word_labels)`` pair of equal-length lists.
    """
    words = []
    word_labels = []
    for token, label in zip(tokens, labels):
        if token in ("[CLS]", "[SEP]"):
            continue
        is_continuation = token.startswith("##")
        if is_continuation and words:
            words[-1] += token[2:]
        else:
            # A continuation piece with no preceding word starts a new word
            # instead of raising IndexError on an empty list.
            words.append(token[2:] if is_continuation else token)
            word_labels.append(label)
    return words, word_labels


def _format_labeled(words, word_labels):
    """Render words as one string, wrapping every non-"O" word in its tag.

    Words are HTML-escaped because the result is displayed with
    ``unsafe_allow_html=True`` and the text originates from user input.
    """
    parts = []
    for word, label in zip(words, word_labels):
        safe_word = html.escape(word, quote=False)
        if label != "O":
            parts.append(f"[{label}] <b>{safe_word}</b> [/{label}]")
        else:
            parts.append(safe_word)
    return " ".join(parts)


if st.button("Analyze Sentences"):
    if not sentences:
        # Give feedback instead of silently rendering nothing.
        st.warning("Please enter at least one sentence.")

    for sentence in sentences:
        tokens, predicted_labels = _predict_labels(sentence)
        words, word_labels = _merge_wordpieces(tokens, predicted_labels)
        output_sentence = _format_labeled(words, word_labels)

        st.write(f"**Original Sentence:** {sentence}")
        st.markdown(f"**Labeled Output:** {output_sentence}", unsafe_allow_html=True)
        st.write("---")
|
|