import os

import pandas as pd
import plotly.express as px
import streamlit as st
import torch
from annotated_text import annotated_text
from transformers import AutoModelForTokenClassification, AutoTokenizer

from plot import indicator_chart, causes_chart, scatter_plot

# Load the trained model and tokenizer
model_directory = "norygano/causalBERT"
tokenizer = AutoTokenizer.from_pretrained(model_directory, add_prefix_space=True)
model = AutoModelForTokenClassification.from_pretrained(model_directory)
model.eval()

# Define the label map
label_map = {
    0: "O",
    1: "B-INDICATOR",
    2: "I-INDICATOR",
    3: "B-CAUSE",
    4: "I-CAUSE",
    5: "B-EFFECT",
    6: "I-EFFECT",
}

# Main application
st.markdown(
    """
# CAUSENV
""",
    unsafe_allow_html=True,
)
st.markdown(
    "[Model](https://huggingface.co/norygano/causalBERT) | "
    "[Data](https://huggingface.co/datasets/norygano/causenv) | "
    "[Project](https://www.uni-trier.de/universitaet/fachbereiche-faecher/fachbereich-ii/faecher/germanistik/professurenfachteile/germanistische-linguistik/professoren/prof-dr-martin-wengeler/kontroverse-diskurse/individium-gesellschaft)"
)
st.write("Tags indicators and causes in explicit attributions of causality. German only (for now).")

# Create tabs
tab1, tab2, tab3, tab4 = st.tabs(["Prompt", "Indicators", "Causes", "Scatter"])

# Prompt Tab
with tab1:
    # Default examples are German, since the model currently supports German only
    sentences_input = st.text_area(
        "*Sentences (one per line)*",
        "\n".join([
            "Autos stehen im Verdacht, Waldsterben zu verursachen.",
            "Fußball führt zu Waldschäden.",
            "Haustüren tragen zum Betonsterben bei.",
        ]),
        placeholder="Your sentences here.",
    )
    sentences = [sentence.strip() for sentence in sentences_input.splitlines() if sentence.strip()]

    if st.button("Analyze"):
        for sentence in sentences:
            # Run the token-classification model on the sentence
            inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
            with torch.no_grad():
                outputs = model(**inputs)

            logits = outputs.logits
            predicted_label_ids = torch.argmax(logits, dim=2)
            tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
            predicted_labels = [label_map[label_id.item()] for label_id in predicted_label_ids[0]]

            # Merge WordPiece sub-tokens back into words and attach their labels
            annotations = []
            current_word = ""
            current_label = "O"
            for token, label in zip(tokens, predicted_labels):
                if token in ['[CLS]', '[SEP]']:  # Exclude special tokens
                    continue
                if token.startswith("##"):
                    current_word += token[2:]
                else:
                    if current_word:
                        if current_label != "O":
                            annotations.append((current_word, current_label))
                        else:
                            annotations.append(current_word)
                        annotations.append(" ")  # Add a space between words
                    current_word = token
                    current_label = label
            if current_word:
                if current_label != "O":
                    annotations.append((current_word, current_label))
                else:
                    annotations.append(current_word)

            st.write(f"**Sentence:** {sentence}")
            annotated_text(*annotations)
            st.write("---")

# Research Insights Tab
with tab2:
    st.write("## Indicators")

    # Overall
    st.subheader("Overall")
    fig_overall = indicator_chart(chart_type='overall')
    st.plotly_chart(fig_overall, use_container_width=True)

    # Individual Indicators Chart
    st.subheader("Individual")
    fig_individual = indicator_chart(chart_type='individual')
    st.plotly_chart(fig_individual, use_container_width=True)

with tab3:
    st.write("## Causes")
    fig_causes = causes_chart()
    st.plotly_chart(fig_causes, use_container_width=True)

with tab4:
    st.write("## Scatter")
    fig_scatter = scatter_plot()
    st.plotly_chart(fig_scatter, use_container_width=True)