File size: 4,244 Bytes
04d4fc6 4be4c1c 45d0933 04d4fc6 60e75a3 04d4fc6 45d0933 3a965ad 45d0933 3a965ad 45d0933 04d4fc6 45d0933 3a965ad 45d0933 04d4fc6 45d0933 04d4fc6 45d0933 04d4fc6 45d0933 04d4fc6 45d0933 04d4fc6 45d0933 04d4fc6 45d0933 04d4fc6 45d0933 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from annotated_text import annotated_text
import pandas as pd
import plotly.express as px
from plot import indicator_chart, causes_chart, scatter_plot
import os
# Load the trained model and tokenizer
model_directory = "norygano/causalBERT"


@st.cache_resource
def _load_model(directory):
    """Load the tokenizer and token-classification model for *directory*.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` keeps the loaded objects alive across reruns and
    sessions so the weights are not re-read on each click.

    Args:
        directory: Hugging Face Hub id or local path of the model.

    Returns:
        ``(tokenizer, model)`` with the model already in evaluation mode.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(directory, add_prefix_space=True)
    loaded_model = AutoModelForTokenClassification.from_pretrained(directory)
    loaded_model.eval()  # inference only (e.g. disables dropout)
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_model(model_directory)
# Maps predicted class indices (argmax over the model's logits) to BIO tags:
# "O" = outside any span, "B-"/"I-" = beginning/inside of a span of that type.
label_map = dict(enumerate([
    "O",
    "B-INDICATOR", "I-INDICATOR",
    "B-CAUSE", "I-CAUSE",
    "B-EFFECT", "I-EFFECT",
]))
# Page header: the "CAUSENV" wordmark, with its final "V" rotated, as raw HTML.
_LOGO_HTML = """
<div style="display: flex; align-items: center; justify-content: left; font-size: 60px; font-weight: bold;">
<span>CAUSEN</span>
<span style="transform: rotate(270deg); display: inline-block; margin-left: 5px;">V</span>
</div>
"""
st.markdown(_LOGO_HTML, unsafe_allow_html=True)

# Links to the model, the dataset and the project page, then a one-line summary.
st.markdown("[Model](https://huggingface.co/norygano/causalBERT) | [Data](https://huggingface.co/datasets/norygano/causenv) | [Project](https://www.uni-trier.de/universitaet/fachbereiche-faecher/fachbereich-ii/faecher/germanistik/professurenfachteile/germanistische-linguistik/professoren/prof-dr-martin-wengeler/kontroverse-diskurse/individium-gesellschaft)")
st.write("Tags indicators and causes in explicit attributions of causality. GER only (atm)")

# One tab for interactive tagging, followed by three chart views.
_TAB_LABELS = ["Prompt", "Indicators", "Causes", "Scatter"]
tab1, tab2, tab3, tab4 = st.tabs(_TAB_LABELS)
# Prompt tab: free-text input; each non-empty line is tagged by the model.


def _annotate_sentence(sentence):
    """Tag *sentence* with the token classifier and build ``annotated_text`` args.

    Sub-word tokens are grouped into the words produced by the tokenizer's
    pre-tokenizer via ``BatchEncoding.word_ids``. Each word receives the label
    predicted for its first sub-word token. Word text and the gaps between
    words are sliced from ``sentence`` itself, so the rendered text matches
    the input exactly: no "##" fragments, no "[UNK]" placeholders and no
    spurious space before punctuation.

    ``word_ids``/``word_to_chars`` need a fast (Rust-backed) tokenizer, which
    ``AutoTokenizer`` returns by default when one exists for the model.
    If the sentence exceeds the model's maximum length, ``truncation=True``
    drops the tail, and that tail is not annotated.

    Args:
        sentence: One line of user input (already stripped).

    Returns:
        A list of plain strings (unlabelled words and the separators between
        words) and ``(word, label)`` tuples for words whose tag is not "O".
    """
    inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    # One predicted class index per token of the single (batch index 0) sequence.
    predicted_ids = torch.argmax(logits, dim=2)[0].tolist()

    annotations = []
    cursor = 0  # offset in `sentence` just past the last emitted word
    previous_word = None
    for token_index, word_index in enumerate(inputs.word_ids(0)):
        # Special tokens ([CLS], [SEP], padding) belong to no word; later
        # sub-words of the current word are skipped so the first one decides.
        if word_index is None or word_index == previous_word:
            continue
        previous_word = word_index
        span = inputs.word_to_chars(0, word_index)
        if span is None:
            continue
        if span.start > cursor:
            # Reproduce the original separator (usually a space, possibly none).
            annotations.append(sentence[cursor:span.start])
        word = sentence[span.start:span.end]
        label = label_map[predicted_ids[token_index]]
        annotations.append(word if label == "O" else (word, label))
        cursor = span.end
    return annotations


with tab1:
    sentences_input = st.text_area("*Sentences (one per line)*", "\n".join([
        "Autos stehen im Verdacht, Waldsterben zu verursachen.",
        "Fußball führt zu Waldschäden.",
        "Haustüren tragen zum Betonsterben bei.",
    ]), placeholder="Your Sentences here.")
    sentences = [sentence.strip() for sentence in sentences_input.splitlines() if sentence.strip()]
    if st.button("Analyze"):
        if not sentences:
            # Give feedback instead of silently rendering nothing.
            st.warning("Please enter at least one sentence.")
        for sentence in sentences:
            annotations = _annotate_sentence(sentence)
            st.write(f"**Sentence:** {sentence}")
            annotated_text(*annotations)
            st.write("---")
def _show_figure(figure):
    """Render a Plotly figure stretched to the width of its container."""
    st.plotly_chart(figure, use_container_width=True)


# Indicators tab: aggregate chart first, then the per-indicator breakdown.
with tab2:
    st.write("## Indicators")

    st.subheader("Overall")
    fig_overall = indicator_chart(chart_type='overall')
    _show_figure(fig_overall)

    st.subheader("Individual")
    fig_individual = indicator_chart(chart_type='individual')
    _show_figure(fig_individual)

# Causes tab: a single chart produced by plot.causes_chart().
with tab3:
    st.write("## Causes")
    fig_causes = causes_chart()
    _show_figure(fig_causes)

# Scatter tab: a single chart produced by plot.scatter_plot().
with tab4:
    st.write("## Scatter")
    fig_scatter = scatter_plot()
    _show_figure(fig_scatter)