anamargarida committed on
Commit
2ce452b
·
verified ·
1 Parent(s): c216dc0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +209 -0
app.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import copy
import re

import streamlit as st
import torch
from huggingface_hub import login
from safetensors.torch import load_file
from transformers import AutoConfig, AutoModel, AutoTokenizer

from ST2ModelV2_6 import ST2ModelV2

# Authenticate against the Hugging Face Hub so the model repository used
# below can be downloaded. The token is kept out of the source via
# Streamlit's secrets store.
hf_token = st.secrets["HUGGINGFACE_TOKEN"]
login(token=hf_token)
12
+
13
+
14
+
15
# Load model & tokenizer once (cached for efficiency)
@st.cache_resource
def load_model():
    """Build the causal-extraction model and its tokenizer from the Hub.

    Cached by Streamlit so the download/construction happens only once per
    process.

    Returns:
        (tokenizer, model): the HF tokenizer and the ST2ModelV2 instance,
        both loaded from the ``anamargarida/Final`` repository.
    """
    repo_id = "anamargarida/Final"

    config = AutoConfig.from_pretrained(repo_id)
    tokenizer = AutoTokenizer.from_pretrained(repo_id)

    # ST2ModelV2 expects an argparse-like namespace; recreate the handful
    # of fields it reads at construction time.
    class Args:
        def __init__(self):
            self.model_name = repo_id
            self.dropout = 0.1
            self.signal_classification = True
            self.pretrained_signal_detector = False

    # Load the model directly from Hugging Face
    model = ST2ModelV2.from_pretrained(repo_id, config=config, args=Args())

    return tokenizer, model
38
+
39
# Load the model and tokenizer
tokenizer, model = load_model()

# --- Debug output: render the loaded modules/weights on the page so the
# --- fine-tuned checkpoint can be inspected visually.
st.write("model_", model)
st.write("model_weights", model.model)
st.write("config", model.config)
st.write("Signal_classifier_weights", model.signal_classifier.weight)
st.write(model.model.embeddings.LayerNorm.weight)
# st.write(model.model.encoder.layer.13.attention.self.value.weight)

# Load a vanilla roberta-large purely to compare its embedding LayerNorm
# weights against the fine-tuned checkpoint printed above.
roberta_model = AutoModel.from_pretrained("roberta-large")
st.write(roberta_model.embeddings.LayerNorm.weight)

model.eval()  # Set model to evaluation mode
52
def extract_arguments(text, tokenizer, model, beam_search=True):
    """Run the model on ``text`` and decode cause/effect/signal spans.

    Args:
        text: raw input sentence.
        tokenizer: HF tokenizer matching ``model``.
        model: ST2ModelV2 instance, already in eval mode.
        beam_search: when True, use the model's beam-search position
            selector, which also yields a second candidate relation.

    Returns:
        (cause1, cause2, effect1, effect2, signal, list1, list2): decoded
        span strings ("" for an absent span, 'NA' for a missing signal)
        plus the raw token-index lists for both candidate relations.
    """
    # Mirrors the training-time flags; only these two fields are read here.
    class Args:
        def __init__(self):
            self.signal_classification = True
            self.pretrained_signal_detector = False

    args = Args()
    inputs = tokenizer(text, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # Per-token span logits (batch dimension stripped).
    start_cause_logits = outputs["start_arg0_logits"][0]
    end_cause_logits = outputs["end_arg0_logits"][0]
    start_effect_logits = outputs["start_arg1_logits"][0]
    end_effect_logits = outputs["end_arg1_logits"][0]
    start_signal_logits = outputs["start_sig_logits"][0]
    end_signal_logits = outputs["end_sig_logits"][0]

    # Mask the special tokens (position 0 and the last position) so they can
    # never be chosen as span boundaries.
    # BUGFIX: the mask value was -1e-4, which is ~0 and barely lowers the
    # logits; use -1e4, consistent with the signal-window masking below.
    last = len(inputs["input_ids"][0]) - 1
    for logits in (start_cause_logits, end_cause_logits,
                   start_effect_logits, end_effect_logits):
        logits[0] = -1e4
        logits[last] = -1e4

    st.write("start_cause_logits", start_cause_logits)
    st.write("end_cause_logits", end_cause_logits)
    st.write("start_effect_logits", start_effect_logits)
    st.write("end_effect_logits", end_effect_logits)
    st.write("start_signal_logits", start_signal_logits)
    st.write("end_signal_logits", end_signal_logits)

    # Beam Search for position selection
    if beam_search:
        indices1, indices2, _, _, _ = model.beam_search_position_selector(
            start_cause_logits=start_cause_logits,
            end_cause_logits=end_cause_logits,
            start_effect_logits=start_effect_logits,
            end_effect_logits=end_effect_logits,
            topk=5,
        )
        start_cause1, end_cause1, start_effect1, end_effect1 = indices1
        start_cause2, end_cause2, start_effect2, end_effect2 = indices2
    else:
        # Greedy selection: take the argmax of each boundary independently.
        start_cause1 = start_cause_logits.argmax().item()
        end_cause1 = end_cause_logits.argmax().item()
        start_effect1 = start_effect_logits.argmax().item()
        end_effect1 = end_effect_logits.argmax().item()
        # No second candidate without beam search.
        start_cause2, end_cause2, start_effect2, end_effect2 = None, None, None, None

    # Decide whether the sentence contains an explicit causal signal.
    has_signal = 1
    if args.signal_classification:
        if not args.pretrained_signal_detector:
            has_signal = outputs["signal_classification_logits"].argmax().item()
        else:
            # NOTE(review): `signal_detector` is not defined anywhere in this
            # file, so this branch would raise NameError; it is unreachable
            # with the flags above. BUGFIX: was `batch["text"]` (`batch` is
            # undefined here); the local `text` is what is meant.
            has_signal = signal_detector.predict(text=text)

    if has_signal:
        # Same special-token masking for the signal logits.
        # BUGFIX: -1e-4 -> -1e4, as above.
        start_signal_logits[0] = -1e4
        end_signal_logits[0] = -1e4
        start_signal_logits[last] = -1e4
        end_signal_logits[last] = -1e4

        start_signal = start_signal_logits.argmax().item()
        # Constrain the signal end to lie in [start, start + 5).
        end_signal_logits[:start_signal] = -1e4
        end_signal_logits[start_signal + 5:] = -1e4
        end_signal = end_signal_logits.argmax().item()
    else:
        start_signal = 'NA'
        end_signal = 'NA'

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    st.write(f"Start Cause 1: {start_cause1}, End Cause: {end_cause1}")
    # BUGFIX: label said "End Cause" for the effect span.
    st.write(f"Start Effect 1: {start_effect1}, End Effect: {end_effect1}")
    st.write(f"Start Signal: {start_signal}, End Signal: {end_signal}")

    def extract_span(start, end):
        # Inclusive token-index span back to a surface string; "" when the
        # span is absent (None boundaries).
        if start is None or end is None:
            return ""
        return tokenizer.convert_tokens_to_string(tokens[start:end + 1])

    cause1 = extract_span(start_cause1, end_cause1)
    cause2 = extract_span(start_cause2, end_cause2)
    effect1 = extract_span(start_effect1, end_effect1)
    effect2 = extract_span(start_effect2, end_effect2)
    signal = extract_span(start_signal, end_signal) if has_signal else 'NA'

    list1 = [start_cause1, end_cause1, start_effect1, end_effect1, start_signal, end_signal]
    list2 = [start_cause2, end_cause2, start_effect2, end_effect2, start_signal, end_signal]
    return cause1, cause2, effect1, effect2, signal, list1, list2
169
+
170
def mark_text(original_text, span, color):
    """Wrap every case-insensitive occurrence of ``span`` in a colored <mark> tag.

    Args:
        original_text: text to annotate.
        span: extracted span to highlight; falsy spans are ignored.
        color: CSS background color for the highlight.

    Returns:
        ``original_text`` with matches wrapped, or unchanged if ``span`` is
        empty/falsy.
    """
    if span:
        # BUGFIX: the span used to be interpolated directly into the re.sub
        # replacement string, where backslashes or group references (e.g.
        # "\1", "\b") in the matched text raise re.error or corrupt output.
        # A callable replacement treats the matched text literally (and
        # preserves the original casing of case-insensitive matches).
        def highlight(m):
            return (f"<mark style='background-color:{color}; padding:2px; "
                    f"border-radius:4px;'>{m.group(0)}</mark>")

        return re.sub(re.escape(span), highlight, original_text, flags=re.IGNORECASE)
    return original_text  # Return unchanged text if no span is found
175
+
176
st.title("Causal Relation Extraction")
input_text = st.text_area("Enter your text here:", height=300)
beam_search = st.radio("Enable Beam Search?", ('No', 'Yes')) == 'Yes'

if st.button("Extract1"):
    if input_text:
        cause1, cause2, effect1, effect2, signal, list1, list2 = extract_arguments(
            input_text, tokenizer, model, beam_search=beam_search
        )

        # First (best-scoring) relation, rendered with inline highlights.
        cause_text1 = mark_text(input_text, cause1, "#FFD700")    # Gold for cause
        effect_text1 = mark_text(input_text, effect1, "#90EE90")  # Light green for effect
        signal_text = mark_text(input_text, signal, "#FF6347")    # Tomato red for signal

        st.markdown(f"<span style='font-size: 24px;'><strong>Relation 1:</strong></span>", unsafe_allow_html=True)
        st.markdown(f"**Cause:**<br>{cause_text1}", unsafe_allow_html=True)
        st.markdown(f"**Effect:**<br>{effect_text1}", unsafe_allow_html=True)
        st.markdown(f"**Signal:**<br>{signal_text}", unsafe_allow_html=True)

        # st.write("List 1:", list1)

        # Beam search also yields a second candidate relation.
        if beam_search:
            cause_text2 = mark_text(input_text, cause2, "#FFD700")    # Gold for cause
            effect_text2 = mark_text(input_text, effect2, "#90EE90")  # Light green for effect
            signal_text = mark_text(input_text, signal, "#FF6347")    # Tomato red for signal

            st.markdown(f"<span style='font-size: 24px;'><strong>Relation 2:</strong></span>", unsafe_allow_html=True)
            st.markdown(f"**Cause:**<br>{cause_text2}", unsafe_allow_html=True)
            st.markdown(f"**Effect:**<br>{effect_text2}", unsafe_allow_html=True)
            st.markdown(f"**Signal:**<br>{signal_text}", unsafe_allow_html=True)

            # st.write("List 2:", list2)
    else:
        st.warning("Please enter some text before extracting.")