norygano committed
Commit 04d4fc6 · 1 Parent(s): 96be409

Rename causenv.py to app.py for Hugging Face Spaces

Files changed (1)
  1. app.py +82 -0
app.py ADDED
@@ -0,0 +1,82 @@
+ import streamlit as st
+ import torch
+ from transformers import AutoTokenizer, AutoModelForTokenClassification
+
+ # Load the trained model and tokenizer
+ model_directory = "norygano/causalBERT"
+ tokenizer = AutoTokenizer.from_pretrained(model_directory, add_prefix_space=True)
+ model = AutoModelForTokenClassification.from_pretrained(model_directory)
+
+ # Set model to evaluation mode
+ model.eval()
+
+ # Define the label map
+ label_map = {0: "O", 1: "B-INDICATOR", 2: "I-INDICATOR", 3: "B-CAUSE", 4: "I-CAUSE"}
+
+ # Streamlit App
+ st.title("Attribution of Causality")
+ st.write("Tags indicators and causes. GER only (for now)")
+
+ # Text input for sentences
+ sentences_input = st.text_area("Sentences (one per line)", "\n".join([
+     "Laub könnte verantwortlich für den Klimawandel sein.",
+     "Nach dem Verursachergrundsatz spielt das keine Rolle.",
+     #"Backenzähne verursachen Artensterben.",
+     "Fußball führt zu Waldschäden.",
+     #"Das hängt mit vielen Faktoren zusammen.",
+     "Haustüren tragen zum Betonsterben bei.",
+     #"Autos stehen im verdacht, Bienensterben auszulösen.",
+     #"Lösen Straßen Waldsterben aus?"
+ ]))
+
+ # Split the input text into individual sentences
+ sentences = [sentence.strip() for sentence in sentences_input.splitlines() if sentence.strip()]
+
+ # Button to run the model
+ if st.button("Analyze Sentences"):
+     for sentence in sentences:
+         # Tokenize the sentence
+         inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
+
+         # Run inference
+         with torch.no_grad():
+             outputs = model(**inputs)
+
+         # Get the logits and predicted label IDs
+         logits = outputs.logits
+         predicted_label_ids = torch.argmax(logits, dim=2)
+
+         # Convert token IDs back to tokens
+         tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
+
+         # Map label IDs to human-readable labels
+         predicted_labels = [label_map[label_id.item()] for label_id in predicted_label_ids[0]]
+
+         # Reconstruct words from subwords
+         reconstructed_tokens = []
+         reconstructed_labels = []
+         for token, label in zip(tokens, predicted_labels):
+             if token in ['[CLS]', '[SEP]']:  # Exclude special tokens
+                 continue
+             if token.startswith("##"):
+                 reconstructed_tokens[-1] += token[2:]  # Append subword
+             else:
+                 reconstructed_tokens.append(token)
+                 reconstructed_labels.append(label)
+
+         # Format output with square brackets
+         formatted_output = []
+         for token, label in zip(reconstructed_tokens, reconstructed_labels):
+             if label != "O":
+                 # Use square brackets around label names
+                 formatted_output.append(f"[{label}] <b>{token}</b> [/{label}]")
+             else:
+                 formatted_output.append(token)
+
+         # Join tokens for display
+         output_sentence = " ".join(formatted_output)
+
+         # Display formatted sentence with Streamlit
+         st.write(f"**Original Sentence:** {sentence}")
+         st.markdown(f"**Labeled Output:** {output_sentence}", unsafe_allow_html=True)
+         st.write("---")
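
The same model can also be exercised outside the Streamlit app. The snippet below is a minimal sketch, not part of this commit: it queries norygano/causalBERT through the transformers token-classification pipeline, assumes the uploaded model config carries the id2label mapping shown above, and uses aggregation_strategy="simple" as one possible way to merge subword pieces into word-level spans.

from transformers import pipeline

# Token-classification pipeline over the same model used by app.py.
# aggregation_strategy="simple" groups subword pieces into word-level spans
# (an assumption about the preferred output format, not taken from this commit).
nlp = pipeline(
    "token-classification",
    model="norygano/causalBERT",
    aggregation_strategy="simple",
)

for span in nlp("Fußball führt zu Waldschäden."):
    print(span["entity_group"], span["word"], round(float(span["score"]), 3))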