File size: 6,495 Bytes
04d4fc6 4be4c1c adb4a34 45d0933 04d4fc6 9525dec adb4a34 6b22889 7e996cc adb4a34 9525dec 04d4fc6 60e75a3 04d4fc6 45d0933 3a965ad 2e68043 3a965ad 45d0933 3a965ad adb4a34 7e996cc 04d4fc6 45d0933 632a488 3a965ad 45d0933 7e996cc 45d0933 6b22889 04d4fc6 45d0933 04d4fc6 45d0933 04d4fc6 45d0933 04d4fc6 7e996cc 45d0933 7e996cc 632a488 7e996cc 04d4fc6 7e996cc 45d0933 7e996cc 04d4fc6 7e996cc 45d0933 7e996cc 9525dec 7e996cc 9525dec 7e996cc 9525dec 7e996cc 9525dec 7e996cc adb4a34 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from annotated_text import annotated_text
from plot import Plot
from flow import FlowChart
import os
# Default slider thresholds, defined before any widget is built so every tab can use
# them as fallbacks. The Sankey tab keeps its own pair so its sliders start (and move)
# independently of the Indicators and Causes tabs.
default_cause_threshold, default_indicator_threshold = 25, 15
default_cause_threshold_sankey, default_indicator_threshold_sankey = 20, 15
# Chart builders from the local `plot` and `flow` modules.
# NOTE(review): both are constructed on every Streamlit rerun; if their constructors
# load data, consider caching them. `flow_chart` is not referenced elsewhere in the
# visible script — confirm before removing it.
plot, flow_chart = Plot(), FlowChart()
# Load the trained model and tokenizer
model_directory = "norygano/causalBERT"


@st.cache_resource
def _load_model_and_tokenizer(directory):
    """Load the token-classification model and its tokenizer once per server process.

    Streamlit re-executes this whole script on every widget interaction; without
    caching, the tokenizer and model would be rebuilt on each rerun (e.g. every
    time a slider in another tab moves). ``st.cache_resource`` keeps one shared
    instance per ``directory`` for all sessions, so callers must treat the
    returned objects as read-only.

    Args:
        directory: Hugging Face Hub id or local path of the checkpoint.

    Returns:
        A ``(tokenizer, model)`` pair, with the model already in evaluation mode.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(directory, add_prefix_space=True)
    loaded_model = AutoModelForTokenClassification.from_pretrained(directory)
    loaded_model.eval()  # inference only: disable dropout
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_model_and_tokenizer(model_directory)
# Model output class index -> BIO tag. "O" marks tokens outside any span; B-/I- mark
# the beginning/inside of an INDICATOR, CAUSE or EFFECT span.
label_map = dict(enumerate([
    "O",
    "B-INDICATOR", "I-INDICATOR",
    "B-CAUSE", "I-CAUSE",
    "B-EFFECT", "I-EFFECT",
]))
# Main application: page title, resource links and a one-line description.
# The title is raw HTML ("CAUSE" followed by a rotated "V"), hence unsafe_allow_html.
_title_html = """
<div style="display: flex; align-items: center; justify-content: left; font-size: 60px; font-weight: bold;">
<span>CAUSE</span>
<span style="transform: rotate(270deg); display: inline-block; margin-left: 5px;">V</span>
</div>
"""
_links_markdown = "[Weights](https://huggingface.co/norygano/causalBERT) | [Data](https://huggingface.co/datasets/norygano/causenv)"

st.markdown(_title_html, unsafe_allow_html=True)
st.markdown(_links_markdown)
st.write("Indicators and causes in explicit attributions of causality.")
# One tab per view; the tab objects are unpacked in display order.
tab1, tab2, tab3, tab4, tab5 = st.tabs(
    ["Prompt", "Indicators", "Causes", "Scatter", "Sankey"]
)
def _annotate_sentence(sentence):
    """Tag one sentence with the token classifier and build ``annotated_text`` segments.

    Each word of ``sentence`` becomes a plain string when its predicted label is
    "O", otherwise a ``(word, label)`` tuple. A word's label is the prediction
    for its first sub-word token. The text between words is copied verbatim
    from ``sentence``, so the rendered output matches the input exactly: no
    space is forced in front of punctuation, and words the vocabulary maps to
    the unknown token are shown as typed rather than as the unknown marker.

    Word boundaries come from ``BatchEncoding.word_ids`` / ``word_to_chars``,
    which works for any fast (Rust-backed) tokenizer regardless of its special
    tokens or sub-word prefix convention. Both methods are only available on
    fast tokenizers; ``AutoTokenizer`` loads the fast variant by default when
    one exists.

    Args:
        sentence: A single, already-stripped input sentence.

    Returns:
        A list of ``str`` and ``(str, str)`` items for ``annotated_text(*items)``.
    """
    encoding = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = model(**encoding).logits
    predicted_ids = torch.argmax(logits, dim=2)[0].tolist()

    segments = []
    cursor = 0  # index in `sentence` up to which text has already been emitted
    previous_word = None
    for token_index, word_index in enumerate(encoding.word_ids(0)):
        # None marks special tokens (e.g. [CLS]/[SEP]/padding); a repeated word
        # index marks a continuation sub-word of the word already emitted.
        if word_index is None or word_index == previous_word:
            continue
        previous_word = word_index
        start, end = encoding.word_to_chars(0, word_index)
        if start > cursor:
            segments.append(sentence[cursor:start])  # original gap between words
        word = sentence[start:end]
        label = label_map[predicted_ids[token_index]]
        segments.append(word if label == "O" else (word, label))
        cursor = end
    return segments


# Prompt Tab
with tab1:
    sentences_input = st.text_area("*Sentences (one per line)*", "\n".join([
        "Autos stehen im Verdacht, Waldsterben verursacht zu haben.",
        "Fußball führt zu Waldschäden.",
        "Haustüren tragen zum Betonsterben bei.",
    ]), placeholder="German only (currently)")
    # One sentence per non-blank line, surrounding whitespace removed.
    sentences = [sentence.strip() for sentence in sentences_input.splitlines() if sentence.strip()]
    if st.button("Analyze"):
        for sentence in sentences:
            annotations = _annotate_sentence(sentence)
            st.write(f"**Sentence:** {sentence}")
            annotated_text(*annotations)
            st.write("---")
# Indicator Tab: a chart-type selector, the indicator chart, and, for the
# 'Individual' view only, a threshold slider rendered underneath the chart.
with tab2:
    selected_chart_type = st.radio(
        label="Type",
        label_visibility='collapsed',
        options=['Total', 'Year', 'Individual'],
        horizontal=True,
    )
    show_individual = selected_chart_type == 'Individual'

    chart_kwargs = {"chart_type": selected_chart_type.lower()}
    if show_individual:
        # The slider is drawn after the chart, so its current value is read back
        # from session state; before the slider has ever been shown, use the default.
        chart_kwargs["individual_threshold"] = st.session_state.get(
            "individual_threshold", default_indicator_threshold
        )

    with st.container():
        fig = plot.get_indicator_chart(**chart_kwargs)
        st.plotly_chart(fig, use_container_width=True)

    if show_individual:
        with st.container():
            individual_threshold = st.slider(
                "Indicator >=",
                min_value=1,
                max_value=95,
                value=default_indicator_threshold,
                key="individual_threshold",
            )
# Causes Tab: causes chart with its minimum-count slider placed underneath.
with tab3, st.container():
    # The chart is drawn before the slider, so it reads the slider's last value
    # from session state (the default until the slider has been rendered once).
    causes_min_value = st.session_state.get("cause_threshold_causes", default_cause_threshold)
    fig_causes = plot.get_causes_chart(min_value=causes_min_value)
    st.plotly_chart(fig_causes, use_container_width=True)
    # Key is unique to this tab so it does not collide with the Sankey "Cause >=" slider.
    cause_threshold_causes = st.slider(
        "Cause >=",
        min_value=1,
        max_value=75,
        value=default_cause_threshold,
        key="cause_threshold_causes",
    )
# Scatter Tab: a single figure from `plot`, with no controls.
with tab4:
    fig_scatter = plot.scatter()
    st.plotly_chart(
        fig_scatter,
        use_container_width=True,
    )
# Sankey Tab: Sankey diagram with cause/indicator threshold sliders underneath.
# The sliders use Sankey-specific session-state keys, so they are independent of
# the Causes and Indicators tabs.
with tab5:
    chart_area = st.container()
    with chart_area:
        # Sliders are rendered further down, so their last values come from session
        # state, falling back to the Sankey defaults before they have been shown.
        cause_threshold_sankey = st.session_state.get(
            "cause_threshold_sankey", default_cause_threshold_sankey
        )
        indicator_threshold_sankey = st.session_state.get(
            "indicator_threshold_sankey", default_indicator_threshold_sankey
        )
        fig_sankey = plot.sankey(
            cause_threshold=cause_threshold_sankey,
            indicator_threshold=indicator_threshold_sankey,
        )
        st.plotly_chart(fig_sankey, use_container_width=True)

    with st.container():
        cause_threshold_sankey = st.slider(
            "Cause >=",
            min_value=1,
            max_value=100,
            value=default_cause_threshold_sankey,
            key="cause_threshold_sankey",
        )
        indicator_threshold_sankey = st.slider(
            "Indicator >=",
            min_value=1,
            max_value=100,
            value=default_indicator_threshold_sankey,
            key="indicator_threshold_sankey",
        )