import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from annotated_text import annotated_text

# Load the trained model and tokenizer
model_directory = "norygano/causalBERT"
tokenizer = AutoTokenizer.from_pretrained(model_directory, add_prefix_space=True)
model = AutoModelForTokenClassification.from_pretrained(model_directory)

# Set model to evaluation mode
model.eval()

# Define the label map (BIO scheme for indicator, cause, and effect spans)
label_map = {
    0: "O",
    1: "B-INDICATOR",
    2: "I-INDICATOR",
    3: "B-CAUSE",
    4: "I-CAUSE",
    5: "B-EFFECT",
    6: "I-EFFECT",
}
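# Illustrative, hand-labeled example of the BIO scheme above (an assumption
# for documentation purposes, not actual model output), using the default
# sentence "Fußball führt zu Waldschäden.":
#   Fußball     -> B-CAUSE      (start of the cause span)
#   führt       -> B-INDICATOR  (start of the causal indicator)
#   zu          -> I-INDICATOR  (continuation of the indicator)
#   Waldschäden -> B-EFFECT     (start of the effect span)
#   .           -> O            (outside any span)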
# Streamlit App
st.markdown(
    """
    CAUSEN V
    """,
    unsafe_allow_html=True,
)
""", unsafe_allow_html=True ) st.markdown("[Model](https://huggingface.co/norygano/causalBERT)") # Add a description with a link to the model st.write("Tags indicators and causes of explicit attributions of causality. GER only (atm)") # Text input for sentences with italic placeholder text sentences_input = st.text_area("*Sentences (one per line)*", "\n".join([ "Autos stehen im Verdacht, Waldsterben zu verursachen.", "Fußball führt zu Waldschäden.", "Haustüren tragen zum Betonsterben bei.", ]) , placeholder="Your Sentences here.") # Split the input text into individual sentences sentences = [sentence.strip() for sentence in sentences_input.splitlines() if sentence.strip()] # Button to run the model if st.button("Analyze"): for sentence in sentences: # Tokenize the sentence inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True) # Run inference with torch.no_grad(): outputs = model(**inputs) # Get the logits and predicted label IDs logits = outputs.logits predicted_label_ids = torch.argmax(logits, dim=2) # Convert token IDs back to tokens tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]) # Map label IDs to human-readable labels predicted_labels = [label_map[label_id.item()] for label_id in predicted_label_ids[0]] # Reconstruct words from subwords and prepare for annotated_text annotations = [] current_word = "" current_label = "O" for token, label in zip(tokens, predicted_labels): if token in ['[CLS]', '[SEP]']: # Exclude special tokens continue if token.startswith("##"): # Append subword without "##" prefix to the current word current_word += token[2:] else: # If we have accumulated a word, add it to annotations with a space if current_word: if current_label != "O": annotations.append((current_word, current_label)) else: annotations.append(current_word) annotations.append(" ") # Add a space between words # Start a new word current_word = token current_label = label # Add the last accumulated word if current_word: if current_label != "O": annotations.append((current_word, current_label)) else: annotations.append(current_word) # Display annotated text st.write(f"**Sentence:** {sentence}") annotated_text(*annotations) st.write("---")