import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from annotated_text import annotated_text
import pandas as pd
import plotly.express as px
from plot import indicator_chart, causes_chart, scatter_plot
import os
# Load the trained model and tokenizer
model_directory = "norygano/causalBERT"
tokenizer = AutoTokenizer.from_pretrained(model_directory, add_prefix_space=True)
model = AutoModelForTokenClassification.from_pretrained(model_directory)
model.eval()
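# Optional sketch: in recent Streamlit versions, st.cache_resource can wrap the load
# so the model is not re-initialised on every rerun. The helper below is illustrative
# only (hypothetical name) and is not called anywhere in this script.
@st.cache_resource
def _load_causalbert(directory=model_directory):
    tok = AutoTokenizer.from_pretrained(directory, add_prefix_space=True)
    mdl = AutoModelForTokenClassification.from_pretrained(directory)
    mdl.eval()
    return tok, mdl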
# Define the label map
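# BIO scheme: "B-" marks the first token of an INDICATOR/CAUSE/EFFECT span, "I-" a continuation, "O" any token outside a span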
label_map = {0: "O", 1: "B-INDICATOR", 2: "I-INDICATOR", 3: "B-CAUSE", 4: "I-CAUSE", 5: "B-EFFECT", 6: "I-EFFECT"}
# Main application
st.markdown(
    "<h1>CAUSENV</h1>",
    unsafe_allow_html=True
)
st.markdown("[Model](https://huggingface.co/norygano/causalBERT) | [Data](https://huggingface.co/datasets/norygano/causenv) | [Project](https://www.uni-trier.de/universitaet/fachbereiche-faecher/fachbereich-ii/faecher/germanistik/professurenfachteile/germanistische-linguistik/professoren/prof-dr-martin-wengeler/kontroverse-diskurse/individium-gesellschaft)")
st.write("Tags indicators and causes in explicit attributions of causality. GER only (atm)")
# Create tabs
tab1, tab2, tab3, tab4 = st.tabs(["Prompt", "Indicators", "Causes", "Scatter"])
# Prompt Tab
with tab1:
    sentences_input = st.text_area("*Sentences (one per line)*", "\n".join([
        "Autos stehen im Verdacht, Waldsterben zu verursachen.",
        "Fußball führt zu Waldschäden.",
        "Haustüren tragen zum Betonsterben bei.",
    ]), placeholder="Your Sentences here.")
    sentences = [sentence.strip() for sentence in sentences_input.splitlines() if sentence.strip()]
if st.button("Analyze"):
for sentence in sentences:
inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
predicted_label_ids = torch.argmax(logits, dim=2)
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
predicted_labels = [label_map[label_id.item()] for label_id in predicted_label_ids[0]]
annotations = []
current_word = ""
current_label = "O"
for token, label in zip(tokens, predicted_labels):
if token in ['[CLS]', '[SEP]']: # Exclude special tokens
continue
if token.startswith("##"):
current_word += token[2:]
else:
if current_word:
if current_label != "O":
annotations.append((current_word, current_label))
else:
annotations.append(current_word)
annotations.append(" ") # Add a space between words
current_word = token
current_label = label
if current_word:
if current_label != "O":
annotations.append((current_word, current_label))
else:
annotations.append(current_word)
st.write(f"**Sentence:** {sentence}")
annotated_text(*annotations)
st.write("---")
# Indicators Tab
with tab2:
st.write("## Indicators")
# Overall
st.subheader("Overall")
fig_overall = indicator_chart(chart_type='overall')
st.plotly_chart(fig_overall, use_container_width=True)
# Individual Indicators Chart
st.subheader("Individual")
fig_individual = indicator_chart(chart_type='individual')
st.plotly_chart(fig_individual, use_container_width=True)
with tab3:
st.write("## Causes")
fig_causes = causes_chart()
st.plotly_chart(fig_causes, use_container_width=True)
with tab4:
st.write("## Scatter")
fig_scatter = scatter_plot()
st.plotly_chart(fig_scatter, use_container_width=True)