norygano committed
Commit 460aadf · 1 Parent(s): 1766897
Files changed (1):
  causenv.py +0 -82
causenv.py DELETED
@@ -1,82 +0,0 @@
- import streamlit as st
- import torch
- from transformers import AutoTokenizer, AutoModelForTokenClassification
-
- # Load the trained model and tokenizer
- model_directory = "norygano/causalBERT"
- tokenizer = AutoTokenizer.from_pretrained(model_directory, add_prefix_space=True)
- model = AutoModelForTokenClassification.from_pretrained(model_directory)
-
- # Set model to evaluation mode
- model.eval()
-
- # Define the label map
- label_map = {0: "O", 1: "B-INDICATOR", 2: "I-INDICATOR", 3: "B-CAUSE", 4: "I-CAUSE"}
-
- # Streamlit App
- st.title("Attribution of Causality")
- st.write("Tags indicators and causes. GER only (for now)")
-
- # Text input for sentences
- sentences_input = st.text_area("Sentences (one per line)", "\n".join([
-     "Laub könnte verantwortlich für den Klimawandel sein.",
-     "Nach dem Verursachergrundsatz spielt das keine Rolle.",
-     #"Backenzähne verursachen Artensterben.",
-     "Fußball führt zu Waldschäden.",
-     #"Das hängt mit vielen Faktoren zusammen.",
-     "Haustüren tragen zum Betonsterben bei.",
-     #"Autos stehen im verdacht, Bienensterben auszulösen.",
-     #"Lösen Straßen Waldsterben aus?"
- ]))
-
- # Split the input text into individual sentences
- sentences = [sentence.strip() for sentence in sentences_input.splitlines() if sentence.strip()]
-
- # Button to run the model
- if st.button("Analyze Sentences"):
-     for sentence in sentences:
-         # Tokenize the sentence
-         inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
-
-         # Run inference
-         with torch.no_grad():
-             outputs = model(**inputs)
-
-         # Get the logits and predicted label IDs
-         logits = outputs.logits
-         predicted_label_ids = torch.argmax(logits, dim=2)
-
-         # Convert token IDs back to tokens
-         tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
-
-         # Map label IDs to human-readable labels
-         predicted_labels = [label_map[label_id.item()] for label_id in predicted_label_ids[0]]
-
-         # Reconstruct words from subwords
-         reconstructed_tokens = []
-         reconstructed_labels = []
-         for token, label in zip(tokens, predicted_labels):
-             if token in ['[CLS]', '[SEP]']:  # Exclude special tokens
-                 continue
-             if token.startswith("##"):
-                 reconstructed_tokens[-1] += token[2:]  # Append subword
-             else:
-                 reconstructed_tokens.append(token)
-                 reconstructed_labels.append(label)
-
-         # Format output with square brackets
-         formatted_output = []
-         for token, label in zip(reconstructed_tokens, reconstructed_labels):
-             if label != "O":
-                 # Use square brackets around label names
-                 formatted_output.append(f"[{label}] <b>{token}</b> [/{label}]")
-             else:
-                 formatted_output.append(token)
-
-         # Join tokens for display
-         output_sentence = " ".join(formatted_output)
-
-         # Display formatted sentence with Streamlit
-         st.write(f"**Original Sentence:** {sentence}")
-         st.markdown(f"**Labeled Output:** {output_sentence}", unsafe_allow_html=True)
-         st.write("---")
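
For reference: the subword merging the deleted script did by hand (stitching "##" pieces back onto the preceding token and keeping the first piece's label) is what the transformers token-classification pipeline's aggregation_strategy option automates. Below is a minimal sketch of the same inference without the Streamlit UI. It assumes the norygano/causalBERT checkpoint's config carries an id2label mapping; if the pipeline only yields generic LABEL_n names, the manual label_map from the deleted script is still needed.

from transformers import pipeline

# Hypothetical standalone equivalent of the deleted app's inference loop.
# aggregation_strategy="simple" merges B-/I- subword pieces into whole spans.
tagger = pipeline(
    "token-classification",
    model="norygano/causalBERT",
    aggregation_strategy="simple",
)

for sentence in ["Fußball führt zu Waldschäden."]:
    print(sentence)
    for span in tagger(sentence):
        # Each span dict carries entity_group, score, and the covered text.
        print(f"  [{span['entity_group']}] {span['word']} ({span['score']:.2f})")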