import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification

# Load the trained model and tokenizer
model_directory = "norygano/causalBERT"
tokenizer = AutoTokenizer.from_pretrained(model_directory, add_prefix_space=True)
model = AutoModelForTokenClassification.from_pretrained(model_directory)

# Set model to evaluation mode
model.eval()

# Define the label map
label_map = {0: "O", 1: "B-INDICATOR", 2: "I-INDICATOR", 3: "B-CAUSE", 4: "I-CAUSE"}

# Streamlit App
st.title("Attribution of Causality")
st.write("Tags causal indicators and causes. German only (for now).")

# Text input for sentences
sentences_input = st.text_area("Sentences (one per line)", "\n".join([
    "Laub könnte verantwortlich für den Klimawandel sein.",
    "Nach dem Verursachergrundsatz spielt das keine Rolle.",
    # "Backenzähne verursachen Artensterben.",
    "Fußball führt zu Waldschäden.",
    # "Das hängt mit vielen Faktoren zusammen.",
    "Haustüren tragen zum Betonsterben bei.",
    # "Autos stehen im Verdacht, Bienensterben auszulösen.",
    # "Lösen Straßen Waldsterben aus?"
]))

# Split the input text into individual sentences
sentences = [sentence.strip() for sentence in sentences_input.splitlines() if sentence.strip()]

# Button to run the model
if st.button("Analyze Sentences"):
    for sentence in sentences:
        # Tokenize the sentence
        inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)

        # Run inference without tracking gradients
        with torch.no_grad():
            outputs = model(**inputs)

        # Get the logits and predicted label IDs
        logits = outputs.logits
        predicted_label_ids = torch.argmax(logits, dim=2)

        # Convert token IDs back to tokens
        tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

        # Map label IDs to human-readable labels
        predicted_labels = [label_map[label_id.item()] for label_id in predicted_label_ids[0]]

        # Reconstruct words from subwords, carrying the first subword's label
        reconstructed_tokens = []
        reconstructed_labels = []
        for token, label in zip(tokens, predicted_labels):
            if token in ['[CLS]', '[SEP]']:  # Exclude special tokens
                continue
            if token.startswith("##"):
                reconstructed_tokens[-1] += token[2:]  # Append subword to the previous token
            else:
                reconstructed_tokens.append(token)
                reconstructed_labels.append(label)

        # Format output with square brackets around labeled tokens
        formatted_output = []
        for token, label in zip(reconstructed_tokens, reconstructed_labels):
            if label != "O":
                formatted_output.append(f"[{label}] {token} [/{label}]")
            else:
                formatted_output.append(token)

        # Join tokens for display
        output_sentence = " ".join(formatted_output)

        # Display the original and labeled sentence
        st.write(f"**Original Sentence:** {sentence}")
        st.markdown(f"**Labeled Output:** {output_sentence}", unsafe_allow_html=True)
        st.write("---")