import os

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from annotated_text import annotated_text

from plot import Plot
from flow import FlowChart

# Define initial threshold values at the top of the script
default_cause_threshold = 25
default_indicator_threshold = 15
default_cause_threshold_sankey = 20
default_indicator_threshold_sankey = 15

# Initialize Plots
plot = Plot()
# NOTE(review): flow_chart is not referenced anywhere else in this script — confirm before removing.
flow_chart = FlowChart()

# Hugging Face Hub repository of the trained token-classification model
model_directory = "norygano/causalBERT"

# Define the label map (model output class id -> BIO tag)
label_map = {0: "O", 1: "B-INDICATOR", 2: "I-INDICATOR", 3: "B-CAUSE", 4: "I-CAUSE", 5: "B-EFFECT", 6: "I-EFFECT"}

# Tokens added by the tokenizer that do not belong to the input sentence
SPECIAL_TOKENS = ("[CLS]", "[SEP]")


@st.cache_resource
def load_model(model_dir):
    """Load the tokenizer and token-classification model stored at *model_dir*.

    Streamlit re-runs this whole script on every widget interaction; caching the
    result with ``st.cache_resource`` keeps one loaded copy instead of
    re-instantiating the weights on each rerun.

    Returns:
        ``(tokenizer, model)`` with the model switched to evaluation mode.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_dir, add_prefix_space=True)
    loaded_model = AutoModelForTokenClassification.from_pretrained(model_dir)
    loaded_model.eval()
    return loaded_tokenizer, loaded_model


# Load the trained model and tokenizer (cached across reruns)
tokenizer, model = load_model(model_directory)


def predict_token_labels(sentence):
    """Run the model on one sentence.

    Returns:
        ``(tokens, labels)``: the sub-word tokens of the encoded sentence
        (special tokens included) and the argmax tag from ``label_map`` for
        each token, aligned index by index.
    """
    inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_label_ids = torch.argmax(logits, dim=2)[0]  # first (only) sequence in the batch
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    labels = [label_map[label_id.item()] for label_id in predicted_label_ids]
    return tokens, labels


def build_annotations(tokens, labels, special_tokens=SPECIAL_TOKENS):
    """Merge WordPiece tokens into words and format them for ``annotated_text``.

    Tokens in *special_tokens* are dropped. Continuation pieces (prefixed with
    ``##``) are appended to the preceding word, and each word keeps the label of
    its first piece. Words labelled ``"O"`` become plain strings, all others
    ``(word, label)`` tuples, with a single ``" "`` inserted between words.
    """
    words = []  # [text, label] per reconstructed word
    for token, label in zip(tokens, labels):
        if token in special_tokens:  # Exclude special tokens
            continue
        if token.startswith("##"):
            if words:
                words[-1][0] += token[2:]
                continue
            token = token[2:]  # orphan continuation piece: start a new word with it
        words.append([token, label])

    annotations = []
    for index, (text, label) in enumerate(words):
        if index:
            annotations.append(" ")  # Add a space between words
        annotations.append(text if label == "O" else (text, label))
    return annotations


# Main application
# NOTE(review): rendered with unsafe_allow_html=True but contains only plain text —
# any HTML markup this header once had may have been lost; confirm.
st.markdown(
    """
CAUSE V
""",
    unsafe_allow_html=True,
)
st.markdown("[Weights](https://huggingface.co/norygano/causalBERT) | [Data](https://huggingface.co/datasets/norygano/causenv)")
st.write("Indicators and causes in explicit attributions of causality.")

# Create tabs
tab1, tab2, tab3, tab4, tab5 = st.tabs(["Prompt", "Indicators", "Causes", "Scatter", "Sankey"])

# Prompt Tab
with tab1:
    sentences_input = st.text_area(
        "*Sentences (one per line)*",
        "\n".join([
            "Autos stehen im Verdacht, Waldsterben verursacht zu haben.",
            "Fußball führt zu Waldschäden.",
            "Haustüren tragen zum Betonsterben bei.",
        ]),
        placeholder="German only (currently)",
    )
    sentences = [sentence.strip() for sentence in sentences_input.splitlines() if sentence.strip()]

    if st.button("Analyze"):
        for sentence in sentences:
            tokens, predicted_labels = predict_token_labels(sentence)
            annotations = build_annotations(tokens, predicted_labels)
            st.write(f"**Sentence:** {sentence}")
            annotated_text(*annotations)
            st.write("---")

# Indicator Tab
with tab2:
    selected_chart_type = st.radio(
        label="Type",
        label_visibility='collapsed',
        options=['Total', 'Year', 'Individual'],
        horizontal=True,
    )

    # Display the chart in a container
    with st.container():
        if selected_chart_type == 'Individual':
            # The slider is rendered below the chart, so read its value from the
            # session state (stored under its key on a previous run) or use the default.
            individual_threshold = st.session_state.get("individual_threshold", default_indicator_threshold)
            fig = plot.get_indicator_chart(chart_type=selected_chart_type.lower(), individual_threshold=individual_threshold)
        else:
            fig = plot.get_indicator_chart(chart_type=selected_chart_type.lower())
        st.plotly_chart(fig, use_container_width=True)

    # Display the slider below the chart container for 'Individual' type
    if selected_chart_type == 'Individual':
        with st.container():
            individual_threshold = st.slider(
                "Indicator >=",
                min_value=1,
                max_value=95,
                value=default_indicator_threshold,
                key="individual_threshold",
            )

# Causes Tab
with tab3:
    # Create a container for the chart and place the slider below it
    with st.container():
        # Display the chart first; the threshold comes from the slider's session-state key
        fig_causes = plot.get_causes_chart(min_value=st.session_state.get("cause_threshold_causes", default_cause_threshold))
        st.plotly_chart(fig_causes, use_container_width=True)

        # Place the slider below the chart with a unique key
        cause_threshold_causes = st.slider(
            "Cause >=",
            min_value=1,
            max_value=75,
            value=default_cause_threshold,
            key="cause_threshold_causes",
        )

# Scatter Tab
with tab4:
    fig_scatter = plot.scatter()
    st.plotly_chart(fig_scatter, use_container_width=True)

# Sankey Tab
with tab5:
    with st.container():
        # Use the unique Sankey threshold variables in session state
        cause_threshold_sankey = st.session_state.get("cause_threshold_sankey", default_cause_threshold_sankey)
        indicator_threshold_sankey = st.session_state.get("indicator_threshold_sankey", default_indicator_threshold_sankey)

        # Generate the Sankey diagram with the new Sankey-specific thresholds
        fig_sankey = plot.sankey(cause_threshold=cause_threshold_sankey, indicator_threshold=indicator_threshold_sankey)
        st.plotly_chart(fig_sankey, use_container_width=True)

    # Place sliders below the chart container with unique keys for the Sankey tab
    with st.container():
        cause_threshold_sankey = st.slider(
            "Cause >=",
            min_value=1,
            max_value=100,
            value=default_cause_threshold_sankey,
            key="cause_threshold_sankey",
        )
        indicator_threshold_sankey = st.slider(
            "Indicator >=",
            min_value=1,
            max_value=100,
            value=default_indicator_threshold_sankey,
            key="indicator_threshold_sankey",
        )