File size: 5,133 Bytes
04d4fc6 4be4c1c 45d0933 9525dec 45d0933 04d4fc6 9525dec 6b22889 9525dec 04d4fc6 60e75a3 04d4fc6 45d0933 3a965ad 45d0933 3a965ad 45d0933 6b22889 04d4fc6 45d0933 9525dec 3a965ad 45d0933 6b22889 04d4fc6 45d0933 04d4fc6 45d0933 04d4fc6 45d0933 04d4fc6 45d0933 04d4fc6 45d0933 04d4fc6 45d0933 9525dec 6b22889 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from annotated_text import annotated_text
import pandas as pd
import plotly.express as px
from plot import indicator_chart, causes_chart, scatter, sankey
import os
# Default values for the Sankey tab's "Cause >" / "Indicator >" sliders; they
# are also the fallback before the sliders have ever been touched.
default_cause_threshold = 20
default_indicator_threshold = 15

# Hugging Face Hub id of the fine-tuned token-classification model.
model_directory = "norygano/causalBERT"


@st.cache_resource
def load_model(directory: str):
    """Load and return ``(tokenizer, model)`` for *directory*, once per process.

    Streamlit re-executes this whole script on every widget interaction;
    without ``st.cache_resource`` the tokenizer and model would be rebuilt on
    every button click or slider move. The cached objects are shared across
    sessions, which is acceptable here because the model is only used for
    inference (eval mode, under ``torch.no_grad()``).
    """
    tok = AutoTokenizer.from_pretrained(directory, add_prefix_space=True)
    mdl = AutoModelForTokenClassification.from_pretrained(directory)
    mdl.eval()  # disable dropout etc. for deterministic inference
    return tok, mdl


tokenizer, model = load_model(model_directory)

# Model output class index -> BIO tag (INDICATOR / CAUSE / EFFECT spans).
label_map = {0: "O", 1: "B-INDICATOR", 2: "I-INDICATOR", 3: "B-CAUSE", 4: "I-CAUSE", 5: "B-EFFECT", 6: "I-EFFECT"}
# --- Page header --------------------------------------------------------
# Wordmark "CAUSENV" where the final "V" is rotated 270deg with inline CSS;
# raw HTML requires unsafe_allow_html.
_TITLE_HTML = """
<div style="display: flex; align-items: center; justify-content: left; font-size: 60px; font-weight: bold;">
<span>CAUSEN</span>
<span style="transform: rotate(270deg); display: inline-block; margin-left: 5px;">V</span>
</div>
"""
# Links to the model, the dataset and the research project.
_LINKS_MD = (
    "[Model](https://huggingface.co/norygano/causalBERT) | "
    "[Data](https://huggingface.co/datasets/norygano/causenv) | "
    "[Project](https://www.uni-trier.de/universitaet/fachbereiche-faecher/fachbereich-ii/faecher/germanistik/professurenfachteile/germanistische-linguistik/professoren/prof-dr-martin-wengeler/kontroverse-diskurse/individium-gesellschaft)"
)

st.markdown(_TITLE_HTML, unsafe_allow_html=True)
st.markdown(_LINKS_MD)
st.write("Tags indicators and causes in explicit attributions of causality.")

# One interactive tab plus four chart tabs.
tab1, tab2, tab3, tab4, tab5 = st.tabs(
    ["Prompt", "Indicators", "Causes", "Scatter", "Sankey"]
)
# Prompt Tab


def _word_annotations(sentence, encoding, labels):
    """Convert per-token labels into ``annotated_text`` arguments for *sentence*.

    Tokens are grouped into words with ``encoding.word_ids()``: special tokens
    map to ``None`` and are skipped, and each word takes the label of its
    first sub-token. This works for WordPiece and BPE tokenizers alike, but
    requires a *fast* tokenizer (the ``AutoTokenizer`` default).

    Word text is sliced from *sentence* via ``encoding.word_to_chars()``, so
    the user's original spelling is shown (not ``[UNK]`` or normalised
    sub-tokens), and the original gaps between words are preserved. That
    means no spurious space is inserted before punctuation.

    Args:
        sentence: The exact string that was passed to the tokenizer.
        encoding: The ``BatchEncoding`` for *sentence* (batch of one).
        labels: One BIO tag string per token position of *encoding*.

    Returns:
        A list of plain strings (untagged text) and ``(word, label)`` tuples
        for words whose label is not ``"O"``.
    """
    annotations = []
    prev_word = None
    prev_end = 0
    for word_idx, label in zip(encoding.word_ids(), labels):
        # Skip special tokens and every non-first sub-token of a word.
        if word_idx is None or word_idx == prev_word:
            continue
        prev_word = word_idx
        span = encoding.word_to_chars(word_idx)
        if span is None:
            continue
        if span.start > prev_end:
            # Original text between words (normally a single space).
            annotations.append(sentence[prev_end:span.start])
        word = sentence[span.start:span.end]
        annotations.append(word if label == "O" else (word, label))
        prev_end = span.end
    # Keep any text the model never saw (e.g. cut off by truncation) untagged
    # rather than silently dropping it.
    if prev_end < len(sentence):
        annotations.append(sentence[prev_end:])
    return annotations


with tab1:
    sentences_input = st.text_area("*Sentences (one per line)*", "\n".join([
        "Autos stehen im Verdacht, Waldsterben zu verursachen.",
        "Fußball führt zu Waldschäden.",
        "Haustüren tragen zum Betonsterben bei.",
    ]), placeholder="German only (currently)")
    sentences = [sentence.strip() for sentence in sentences_input.splitlines() if sentence.strip()]

    if st.button("Analyze"):
        for sentence in sentences:
            inputs = tokenizer(sentence, return_tensors="pt", truncation=True)
            with torch.no_grad():
                logits = model(**inputs).logits
            # Highest-scoring class per token for the single sequence in the batch.
            predicted_ids = logits.argmax(dim=-1)[0].tolist()
            predicted_labels = [label_map[label_id] for label_id in predicted_ids]
            annotations = _word_annotations(sentence, inputs, predicted_labels)

            st.write(f"**Sentence:** {sentence}")
            annotated_text(*annotations)
            st.write("---")
# "Indicators" tab: an aggregate chart followed by a per-indicator breakdown,
# each under its own subheader.
with tab2:
    for heading, chart_type in (("Overall", "overall"), ("Individual", "individual")):
        st.subheader(heading)
        figure = indicator_chart(chart_type=chart_type)
        st.plotly_chart(figure, use_container_width=True)
# "Causes" and "Scatter" tabs: one full-width chart each, built by the
# corresponding helper from the local `plot` module.
with tab3:
    st.plotly_chart(causes_chart(), use_container_width=True)

with tab4:
    st.plotly_chart(scatter(), use_container_width=True)
# "Sankey" tab: a Sankey chart with two threshold sliders rendered beneath it.
with tab5:
    # Reserve the chart's slot first so it is drawn above the sliders, even
    # though the sliders must run first to supply the chart's parameters.
    chart_slot = st.container()

    with st.container():
        # Keys are kept so the values remain available in st.session_state.
        # NOTE(review): presumably minimum counts for a cause/indicator to be
        # included in the diagram; confirm against plot.sankey.
        cause_threshold = st.slider("Cause >", min_value=1, max_value=100, value=default_cause_threshold, key="cause_threshold")
        indicator_threshold = st.slider("Indicator >", min_value=1, max_value=100, value=default_indicator_threshold, key="indicator_threshold")

    fig_sankey = sankey(cause_threshold=cause_threshold, indicator_threshold=indicator_threshold)
    with chart_slot:
        st.plotly_chart(fig_sankey, use_container_width=True)