# causev / app.py
# norygano's picture
# Kolloquium
# adb4a34
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from annotated_text import annotated_text
from plot import Plot
from flow import FlowChart
import os
# Define initial threshold values at the top of the script
default_cause_threshold = 25
default_indicator_threshold = 15
default_cause_threshold_sankey = 20
default_indicator_threshold_sankey = 15
# Initialize Plots
# Chart helpers imported from the `plot` and `flow` modules. They are built at
# module level, i.e. every time this script is executed.
# `plot` supplies the figures for the Indicators, Causes, Scatter and Sankey tabs.
plot = Plot()
# NOTE(review): `flow_chart` is not referenced anywhere in the visible part of
# this script — confirm whether it is still needed.
flow_chart = FlowChart()
# Load the trained model and tokenizer
model_directory = "norygano/causalBERT"


@st.cache_resource
def _load_model(model_name):
    """Load the tokenizer and token-classification model once per server process.

    Streamlit re-executes this script on every widget interaction. Without
    caching, the tokenizer and model weights would be re-instantiated on each
    rerun; `st.cache_resource` keeps the returned objects alive and hands the
    same instances back on subsequent runs.

    Args:
        model_name: Hugging Face model id or local path passed to
            `from_pretrained`.

    Returns:
        A `(tokenizer, model)` tuple; the model has been switched to eval mode.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True)
    loaded_model = AutoModelForTokenClassification.from_pretrained(model_name)
    # Inference only: disable dropout and other train-time behaviour.
    loaded_model.eval()
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_model(model_directory)
# Define the label map
# Tag names listed in class-id order: the position in this tuple is the id
# that the Prompt tab looks up after taking the argmax over the model logits.
_LABEL_NAMES = (
    "O",
    "B-INDICATOR",
    "I-INDICATOR",
    "B-CAUSE",
    "I-CAUSE",
    "B-EFFECT",
    "I-EFFECT",
)
label_map = dict(enumerate(_LABEL_NAMES))
# Main application
# Page title: the word "CAUSE" followed by a "V" rotated by 270 degrees,
# rendered as raw HTML (hence unsafe_allow_html).
title_html = """
<div style="display: flex; align-items: center; justify-content: left; font-size: 60px; font-weight: bold;">
<span>CAUSE</span>
<span style="transform: rotate(270deg); display: inline-block; margin-left: 5px;">V</span>
</div>
"""
st.markdown(title_html, unsafe_allow_html=True)

# Links to the model weights and the dataset, then a one-line description.
links_markdown = "[Weights](https://huggingface.co/norygano/causalBERT) | [Data](https://huggingface.co/datasets/norygano/causenv)"
st.markdown(links_markdown)
st.write("Indicators and causes in explicit attributions of causality.")
# Create tabs
# One tab per view; the order of the titles fixes which `tabN` name maps to which view.
tab_titles = ["Prompt", "Indicators", "Causes", "Scatter", "Sankey"]
tab1, tab2, tab3, tab4, tab5 = st.tabs(tab_titles)
# Prompt Tab
def _merge_wordpieces(tokens, labels, special_tokens=("[CLS]", "[SEP]")):
    """Glue "##"-prefixed sub-tokens back into words and attach a label to each word.

    Args:
        tokens: Token strings for one sentence, in order.
        labels: One label string per token, aligned with `tokens`.
        special_tokens: Token strings that are skipped entirely.

    Returns:
        A list suitable for `annotated_text(*result)`: a plain string for a
        word labelled "O", a `(word, label)` tuple for any other label, and a
        single-space string between consecutive words.
    """
    annotations = []
    word, word_label = "", "O"

    for token, label in zip(tokens, labels):
        if token in special_tokens:
            continue
        if token.startswith("##"):
            # Continuation piece: extend the word in progress. The word keeps
            # the label of its first piece; labels of later pieces are ignored.
            word += token[2:]
            continue
        if word:
            # A new word starts, so emit the finished one plus a separator.
            annotations.append(word if word_label == "O" else (word, word_label))
            annotations.append(" ")
        word, word_label = token, label

    # Emit the last word; no trailing separator.
    if word:
        annotations.append(word if word_label == "O" else (word, word_label))
    return annotations


with tab1:
    sentences_input = st.text_area("*Sentences (one per line)*", "\n".join([
        "Autos stehen im Verdacht, Waldsterben verursacht zu haben.",
        "Fußball führt zu Waldschäden.",
        "Haustüren tragen zum Betonsterben bei.",
    ]), placeholder="German only (currently)")
    # One sentence per non-blank line, surrounding whitespace removed.
    sentences = [line.strip() for line in sentences_input.splitlines() if line.strip()]

    if st.button("Analyze"):
        if not sentences:
            # Previously an empty text area made the button a silent no-op.
            st.warning("Please enter at least one sentence.")

        for sentence in sentences:
            inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
            with torch.no_grad():
                logits = model(**inputs).logits

            # Highest-scoring class id per token, mapped to its tag name.
            predicted_label_ids = torch.argmax(logits, dim=2)
            tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
            predicted_labels = [label_map[label_id] for label_id in predicted_label_ids[0].tolist()]

            annotations = _merge_wordpieces(tokens, predicted_labels)

            st.write(f"**Sentence:** {sentence}")
            annotated_text(*annotations)
            st.write("---")
# Indicator Tab
with tab2:
    selected_chart_type = st.radio(
        "Type",
        options=['Total', 'Year', 'Individual'],
        horizontal=True,
        label_visibility='collapsed',
    )
    is_individual = selected_chart_type == 'Individual'

    # Chart first. Only the 'Individual' view takes a threshold; it is read
    # from session state because the slider that owns the key is created
    # further down, after the chart.
    with st.container():
        chart_kwargs = {"chart_type": selected_chart_type.lower()}
        if is_individual:
            individual_threshold = st.session_state.get("individual_threshold", default_indicator_threshold)
            chart_kwargs["individual_threshold"] = individual_threshold
        fig = plot.get_indicator_chart(**chart_kwargs)
        st.plotly_chart(fig, use_container_width=True)

    # Slider below the chart, shown for the 'Individual' view only.
    if is_individual:
        with st.container():
            individual_threshold = st.slider(
                "Indicator >=",
                min_value=1,
                max_value=95,
                value=default_indicator_threshold,
                key="individual_threshold",
            )
# Causes Tab
with tab3:
    with st.container():
        # The chart is drawn before the slider, so the threshold comes from
        # session state (falling back to the default when the key is absent).
        causes_min_value = st.session_state.get("cause_threshold_causes", default_cause_threshold)
        fig_causes = plot.get_causes_chart(min_value=causes_min_value)
        st.plotly_chart(fig_causes, use_container_width=True)

        # Slider rendered underneath the chart; its key is the one read above.
        cause_threshold_causes = st.slider(
            "Cause >=",
            min_value=1,
            max_value=75,
            value=default_cause_threshold,
            key="cause_threshold_causes",
        )
# Scatter Tab
with tab4:
    # No controls on this tab: build the scatter figure and render it.
    st.plotly_chart(plot.scatter(), use_container_width=True)
# Sankey Tab
with tab5:
    # Chart container: the two thresholds are read from session state because
    # the sliders owning those keys are created after the chart.
    with st.container():
        fig_sankey = plot.sankey(
            cause_threshold=st.session_state.get(
                "cause_threshold_sankey", default_cause_threshold_sankey
            ),
            indicator_threshold=st.session_state.get(
                "indicator_threshold_sankey", default_indicator_threshold_sankey
            ),
        )
        st.plotly_chart(fig_sankey, use_container_width=True)

    # Slider container below the chart; keys are specific to this tab.
    with st.container():
        cause_threshold_sankey = st.slider(
            "Cause >=",
            min_value=1,
            max_value=100,
            value=default_cause_threshold_sankey,
            key="cause_threshold_sankey",
        )
        indicator_threshold_sankey = st.slider(
            "Indicator >=",
            min_value=1,
            max_value=100,
            value=default_indicator_threshold_sankey,
            key="indicator_threshold_sankey",
        )