|
import streamlit as st |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForTokenClassification |
|
from annotated_text import annotated_text |
|
from plot import Plot |
|
from flow import FlowChart |
|
import os |
|
|
|
|
|
# Initial slider values for the "Cause >=" / "Indicator >=" controls used by
# the Causes and Indicators ("Individual" view) tabs.
default_cause_threshold, default_indicator_threshold = 25, 15

# Initial slider values for the same two controls on the Sankey tab.
default_cause_threshold_sankey, default_indicator_threshold_sankey = 20, 15
|
|
|
|
|
# Chart builders from the local `plot` and `flow` modules.
# NOTE(review): Streamlit re-executes this script on every interaction, so
# both objects are re-created on each rerun; if their constructors do
# expensive work (e.g. loading data) consider st.cache_resource -- confirm
# they hold no per-session mutable state first.
plot, flow_chart = Plot(), FlowChart()
|
|
|
|
|
# Hugging Face Hub repository holding the fine-tuned token-classification model.
model_directory = "norygano/causalBERT"


@st.cache_resource
def _load_model(directory):
    """Load the tokenizer and token-classification model for *directory*.

    Streamlit re-runs this entire script on every widget interaction. Without
    caching, the weights would be reloaded on every slider move or button
    click. ``st.cache_resource`` keeps one shared instance per server process,
    keyed on *directory*.

    Returns:
        A ``(tokenizer, model)`` tuple, with the model already in eval mode.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(directory, add_prefix_space=True)
    loaded_model = AutoModelForTokenClassification.from_pretrained(directory)
    # Inference only: disable dropout etc.
    loaded_model.eval()
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_model(model_directory)
|
|
|
|
|
# Model output class index -> BIO tag ("O" = untagged; B-/I- = begin/inside
# of an INDICATOR, CAUSE or EFFECT span). List order defines the index.
label_map = dict(enumerate([
    "O",
    "B-INDICATOR", "I-INDICATOR",
    "B-CAUSE", "I-CAUSE",
    "B-EFFECT", "I-EFFECT",
]))
|
|
|
|
|
# Page title: the word "CAUSE" followed by a "V" rotated by 270 degrees,
# rendered as raw HTML (which is why unsafe_allow_html=True is required).
st.markdown(
    """
<div style="display: flex; align-items: center; justify-content: left; font-size: 60px; font-weight: bold;">
<span>CAUSE</span>
<span style="transform: rotate(270deg); display: inline-block; margin-left: 5px;">V</span>
</div>
""",
    unsafe_allow_html=True
)
# Links to the model weights and the dataset on the Hugging Face Hub.
st.markdown("[Weights](https://huggingface.co/norygano/causalBERT) | [Data](https://huggingface.co/datasets/norygano/causenv)")
# One-line description of what the app annotates.
st.write("Indicators and causes in explicit attributions of causality.")
|
|
|
|
|
# One tab per view; list order determines which handle gets which title.
tab_titles = ["Prompt", "Indicators", "Causes", "Scatter", "Sankey"]
tab1, tab2, tab3, tab4, tab5 = st.tabs(tab_titles)
|
|
|
|
|
def _group_subwords(word_ids, offsets, labels):
    """Merge sub-word tokens into whole-word character spans.

    Args:
        word_ids: Per-token word index, as returned by a fast tokenizer's
            ``BatchEncoding.word_ids()``; ``None`` marks special tokens.
        offsets: Per-token ``(start, end)`` character offsets into the
            tokenized sentence.
        labels: Per-token predicted label strings.

    Returns:
        A list of ``[start, end, label]`` entries, one per word, in order.
        Each word takes the label predicted for its first sub-word token.
    """
    words = []
    previous_word_id = None
    for word_id, (start, end), label in zip(word_ids, offsets, labels):
        if word_id is None:
            # Special tokens ([CLS], [SEP], padding, ...) carry no text.
            continue
        if word_id == previous_word_id:
            # Continuation sub-word: widen the current word's span.
            words[-1][1] = end
        else:
            words.append([start, end, label])
            previous_word_id = word_id
    return words


def _build_annotations(sentence, words):
    """Convert word spans into positional arguments for ``annotated_text``.

    Words labelled "O" become plain strings; every other word becomes a
    ``(text, label)`` tuple. The text between two words (a space, or nothing
    before punctuation) is copied verbatim from *sentence*. This keeps the
    original spacing and characters, e.g. no "[UNK]" and no detached periods.
    """
    annotations = []
    cursor = 0
    for start, end, label in words:
        if start > cursor:
            annotations.append(sentence[cursor:start])
        text = sentence[start:end]
        annotations.append(text if label == "O" else (text, label))
        cursor = end
    return annotations


with tab1:
    sentences_input = st.text_area("*Sentences (one per line)*", "\n".join([
        "Autos stehen im Verdacht, Waldsterben verursacht zu haben.",
        "Fußball führt zu Waldschäden.",
        "Haustüren tragen zum Betonsterben bei.",
    ]), placeholder="German only (currently)")

    # One sentence per non-blank line.
    sentences = [sentence.strip() for sentence in sentences_input.splitlines() if sentence.strip()]

    if st.button("Analyze"):
        for sentence in sentences:
            # Offsets and word ids (both require a fast tokenizer) let us map
            # predictions back onto the original sentence text instead of
            # re-joining vocabulary pieces.
            inputs = tokenizer(sentence, return_tensors="pt", truncation=True, return_offsets_mapping=True)
            # The model's forward() does not accept offset_mapping, so take it out.
            offsets = inputs.pop("offset_mapping")[0].tolist()
            word_ids = inputs.word_ids(0)

            with torch.no_grad():
                logits = model(**inputs).logits

            # Most likely class per token -> BIO tag string.
            predicted_label_ids = torch.argmax(logits, dim=2)[0].tolist()
            predicted_labels = [label_map[label_id] for label_id in predicted_label_ids]

            words = _group_subwords(word_ids, offsets, predicted_labels)
            annotations = _build_annotations(sentence, words)

            st.write(f"**Sentence:** {sentence}")
            annotated_text(*annotations)
            st.write("---")
|
|
|
|
|
|
|
with tab2:
    # Chart granularity picker; its label is hidden visually.
    selected_chart_type = st.radio(
        label="Type",
        label_visibility='collapsed',
        options=['Total', 'Year', 'Individual'],
        horizontal=True,
    )
    is_individual = selected_chart_type == 'Individual'

    with st.container():
        chart_kwargs = {"chart_type": selected_chart_type.lower()}
        if is_individual:
            # The threshold slider is drawn *below* the chart, so read its
            # last value from session state (default on first render).
            chart_kwargs["individual_threshold"] = st.session_state.get(
                "individual_threshold", default_indicator_threshold
            )
        fig = plot.get_indicator_chart(**chart_kwargs)
        st.plotly_chart(fig, use_container_width=True)

    if is_individual:
        with st.container():
            individual_threshold = st.slider(
                "Indicator >=",
                min_value=1,
                max_value=95,
                value=default_indicator_threshold,
                key="individual_threshold",
            )
|
|
|
|
|
with tab3:
    # The slider's key is also the session-state slot the chart reads, because
    # the chart is rendered before (above) the slider widget.
    causes_key = "cause_threshold_causes"

    with st.container():
        current_min_cause = st.session_state.get(causes_key, default_cause_threshold)
        fig_causes = plot.get_causes_chart(min_value=current_min_cause)
        st.plotly_chart(fig_causes, use_container_width=True)

    cause_threshold_causes = st.slider(
        "Cause >=",
        min_value=1,
        max_value=75,
        value=default_cause_threshold,
        key=causes_key,
    )
|
|
|
|
|
with tab4:
    # No controls on this tab: build the scatter figure and render it directly.
    st.plotly_chart(plot.scatter(), use_container_width=True)
|
|
|
|
|
with tab5:
    with st.container():
        # Both sliders are drawn below the diagram, so take their last values
        # from session state, falling back to the defaults on first render.
        sankey_thresholds = {
            "cause_threshold": st.session_state.get(
                "cause_threshold_sankey", default_cause_threshold_sankey
            ),
            "indicator_threshold": st.session_state.get(
                "indicator_threshold_sankey", default_indicator_threshold_sankey
            ),
        }
        fig_sankey = plot.sankey(**sankey_thresholds)
        st.plotly_chart(fig_sankey, use_container_width=True)

    with st.container():
        cause_threshold_sankey = st.slider(
            "Cause >=",
            min_value=1,
            max_value=100,
            value=default_cause_threshold_sankey,
            key="cause_threshold_sankey",
        )
        indicator_threshold_sankey = st.slider(
            "Indicator >=",
            min_value=1,
            max_value=100,
            value=default_indicator_threshold_sankey,
            key="indicator_threshold_sankey",
        )