anamargarida committed (verified)
Commit 1f69012 · 1 Parent(s): 5eec8e8

Delete app_27.py

Files changed (1)
  1. app_27.py +0 -502
app_27.py DELETED
@@ -1,502 +0,0 @@
- import streamlit as st
- import torch
- from transformers import AutoConfig, AutoTokenizer, AutoModel
- from huggingface_hub import login
- import re
- import copy
- from modeling_st2 import ST2ModelV2, SignalDetector
- from huggingface_hub import hf_hub_download
- from safetensors.torch import load_file
-
- hf_token = st.secrets["HUGGINGFACE_TOKEN"]
- login(token=hf_token)
-
-
-
- # Load model & tokenizer once (cached for efficiency)
- @st.cache_resource
- def load_model():
-
-     config = AutoConfig.from_pretrained("roberta-large")
-
-     tokenizer = AutoTokenizer.from_pretrained("roberta-large", use_fast=True, add_prefix_space=True)
-
-     class Args:
-         def __init__(self):
-
-             self.dropout = 0.1
-             self.signal_classification = True
-             self.pretrained_signal_detector = False
-
-     args = Args()
-
-     model = ST2ModelV2(args)
-
-
-     repo_id = "anamargarida/SpanExtractionWithSignalCls_2"
-     filename = "model.safetensors"
-
-     # Download the model file
-     model_path = hf_hub_download(repo_id=repo_id, filename=filename)
-
-     # Load the model weights
-     state_dict = load_file(model_path)
-
-     model.load_state_dict(state_dict)
-
-     return tokenizer, model
-
- # Load the model and tokenizer
- tokenizer, model = load_model()
-
-
-
- model.eval()  # Set model to evaluation mode
- def extract_arguments(text, tokenizer, model, beam_search=True):
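-     # Runs the span-extraction model on `text` and returns two tagged copies of
-     # the sentence, built from the top two predicted cause/effect span pairs.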
-
-     class Args:
-         def __init__(self):
-             self.signal_classification = True
-             self.pretrained_signal_detector = False
-
-     args = Args()
-     inputs = tokenizer(text, return_offsets_mapping=True, return_tensors="pt")
-
-     # Get tokenized words (for reconstruction later)
-     word_ids = inputs.word_ids()
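-     # word_ids maps each sub-word token to the index of the word it came from
-     # (None for special tokens); it is used below to snap spans to word boundaries.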
-
-     with torch.no_grad():
-         outputs = model(**inputs)
-
-
-     # Extract logits
-     start_cause_logits = outputs["start_arg0_logits"][0]
-     end_cause_logits = outputs["end_arg0_logits"][0]
-     start_effect_logits = outputs["start_arg1_logits"][0]
-     end_effect_logits = outputs["end_arg1_logits"][0]
-     start_signal_logits = outputs["start_sig_logits"][0]
-     end_signal_logits = outputs["end_sig_logits"][0]
-
-
-     # Set the first and last token logits to a very low value to ignore them
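-     # (for RoBERTa, position 0 is the <s> token and the last position is </s>;
-     # neither should ever be chosen as a span boundary)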
-     start_cause_logits[0] = -1e4
-     end_cause_logits[0] = -1e4
-     start_effect_logits[0] = -1e4
-     end_effect_logits[0] = -1e4
-     start_cause_logits[len(inputs["input_ids"][0]) - 1] = -1e4
-     end_cause_logits[len(inputs["input_ids"][0]) - 1] = -1e4
-     start_effect_logits[len(inputs["input_ids"][0]) - 1] = -1e4
-     end_effect_logits[len(inputs["input_ids"][0]) - 1] = -1e4
-
-
-     # Beam Search for position selection
-     if beam_search:
-         indices1, indices2, _, _, _ = model.beam_search_position_selector(
-             start_cause_logits=start_cause_logits,
-             end_cause_logits=end_cause_logits,
-             start_effect_logits=start_effect_logits,
-             end_effect_logits=end_effect_logits,
-             topk=5
-         )
-         start_cause1, end_cause1, start_effect1, end_effect1 = indices1
-         start_cause2, end_cause2, start_effect2, end_effect2 = indices2
-     else:
-         start_cause1 = start_cause_logits.argmax().item()
-         end_cause1 = end_cause_logits.argmax().item()
-         start_effect1 = start_effect_logits.argmax().item()
-         end_effect1 = end_effect_logits.argmax().item()
-
-         start_cause2, end_cause2, start_effect2, end_effect2 = None, None, None, None
-
-
-     has_signal = 1
-     if args.signal_classification:
-         if not args.pretrained_signal_detector:
-             has_signal = outputs["signal_classification_logits"].argmax().item()
-         else:
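-             # This branch requires a pretrained SignalDetector instance; it is never
-             # taken here because pretrained_signal_detector is set to False above.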
-             has_signal = signal_detector.predict(text=text)
-
-     if has_signal:
-         start_signal_logits[0] = -1e4
-         end_signal_logits[0] = -1e4
-
-         start_signal_logits[len(inputs["input_ids"][0]) - 1] = -1e4
-         end_signal_logits[len(inputs["input_ids"][0]) - 1] = -1e4
-
-         start_signal = start_signal_logits.argmax().item()
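-         # Constrain the signal end to fall between the chosen start and at most
-         # five positions after it before taking the argmax.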
-         end_signal_logits[:start_signal] = -1e4
-         end_signal_logits[start_signal + 5:] = -1e4
-         end_signal = end_signal_logits.argmax().item()
-
-     if not has_signal:
-         start_signal, end_signal = None, None
-
-
-     tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
-     token_ids = inputs["input_ids"][0]
-     offset_mapping = inputs["offset_mapping"][0].tolist()
-
-     for i, (token, word_id) in enumerate(zip(tokens, word_ids)):
-         st.write(f"Token {i}: {token}, Word ID: {word_id}")
-
-     st.write("Token & offset:")
-     for i, (token, offset) in enumerate(zip(tokens, offset_mapping)):
-         st.write(f"Token {i}: {token}, Offset: {offset}")
-
-
-     st.write("Token Positions, IDs, and Corresponding Tokens:")
-     for position, (token_id, token) in enumerate(zip(token_ids, tokens)):
-         st.write(f"Position: {position}, ID: {token_id}, Token: {token}")
-
-     st.write(f"Start Cause 1: {start_cause1}, End Cause: {end_cause1}")
-     st.write(f"Start Effect 1: {start_effect1}, End Effect: {end_effect1}")
-     st.write(f"Start Signal: {start_signal}, End Signal: {end_signal}")
-
-     def extract_span(start, end):
-         return tokenizer.convert_tokens_to_string(tokens[start:end+1]) if start is not None and end is not None else ""
-
-     cause1 = extract_span(start_cause1, end_cause1)
-     cause2 = extract_span(start_cause2, end_cause2)
-     effect1 = extract_span(start_effect1, end_effect1)
-     effect2 = extract_span(start_effect2, end_effect2)
-     if has_signal:
-         signal = extract_span(start_signal, end_signal)
-     if not has_signal:
-         signal = 'NA'
-     list1 = [start_cause1, end_cause1, start_effect1, end_effect1, start_signal, end_signal]
-     list2 = [start_cause2, end_cause2, start_effect2, end_effect2, start_signal, end_signal]
-     #return cause1, cause2, effect1, effect2, signal, list1, list2
-     #return start_cause1, end_cause1, start_cause2, end_cause2, start_effect1, end_effect1, start_effect2, end_effect2, start_signal, end_signal
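-     # The spans and index lists above are kept only for debugging/inspection; the
-     # function's actual return value is the pair of tagged sentences built below.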
-
-     # Find the first valid token in a multi-token word
-     def find_valid_start(position):
-         while position > 0 and word_ids[position] == word_ids[position - 1]:
-             position -= 1
-         return position
-
-     def find_valid_end(position):
-         while position < len(word_ids) - 1 and word_ids[position] == word_ids[position + 1]:
-             position += 1
-         return position
-
-
-     # Add the argument tags in the sentence directly
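-     # NOTE: add_tags, add_tags_find, add_tags_offset and add_tags_offset_2 are
-     # alternative tagging strategies kept for reference; only add_tags_offset_3
-     # is actually called at the end of this function.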
-     def add_tags(original_text, word_ids, start_cause, end_cause, start_effect, end_effect, start_signal, end_signal):
-         space_splitted_tokens = original_text.split(" ")
-         this_space_splitted_tokens = copy.deepcopy(space_splitted_tokens)
-
-         def safe_insert(tag, position, start=True):
-             """Safely insert a tag, checking for None values and index validity."""
-             if position is not None and word_ids[position] is not None:
-                 word_index = word_ids[position]
-
-                 # Ensure word_index is within range
-                 if 0 <= word_index < len(this_space_splitted_tokens):
-                     if start:
-                         this_space_splitted_tokens[word_index] = tag + this_space_splitted_tokens[word_index]
-                     else:
-                         this_space_splitted_tokens[word_index] += tag
-
-         # Add argument tags safely
-         safe_insert('<ARG0>', start_cause, start=True)
-         safe_insert('</ARG0>', end_cause, start=False)
-         safe_insert('<ARG1>', start_effect, start=True)
-         safe_insert('</ARG1>', end_effect, start=False)
-
-         # Add signal tags safely (if signal exists)
-         if start_signal is not None and end_signal is not None:
-             safe_insert('<SIG0>', start_signal, start=True)
-             safe_insert('</SIG0>', end_signal, start=False)
-
-         # Join tokens back into a string
-         return ' '.join(this_space_splitted_tokens)
-
-     def add_tags_find(original_text, word_ids, start_cause, end_cause, start_effect, end_effect, start_signal, end_signal):
-         space_splitted_tokens = original_text.split(" ")
-         this_space_splitted_tokens = copy.deepcopy(space_splitted_tokens)
-
-         def safe_insert(tag, position, start=True):
-             """Safely insert a tag, checking for None values and index validity."""
-             if position is not None and word_ids[position] is not None:
-                 word_index = word_ids[position]
-
-                 # Ensure word_index is within range
-                 if 0 <= word_index < len(this_space_splitted_tokens):
-                     if start:
-                         this_space_splitted_tokens[word_index] = tag + this_space_splitted_tokens[word_index]
-                     else:
-                         this_space_splitted_tokens[word_index] += tag
-
-         # Find valid start and end positions for words
-         start_cause = find_valid_start(start_cause)
-         end_cause = find_valid_end(end_cause)
-         start_effect = find_valid_start(start_effect)
-         end_effect = find_valid_end(end_effect)
-         if start_signal is not None:
-             start_signal = find_valid_start(start_signal)
-             end_signal = find_valid_end(end_signal)
-
-         # Adjust for punctuation shifts
-         if tokens[end_cause] in [".", ",", "-", ":", ";"]:
-             end_cause -= 1
-         if tokens[end_effect] in [".", ",", "-", ":", ";"]:
-             end_effect -= 1
-
-         # Add argument tags safely
-         safe_insert('<ARG0>', start_cause, start=True)
-         safe_insert('</ARG0>', end_cause, start=False)
-         safe_insert('<ARG1>', start_effect, start=True)
-         safe_insert('</ARG1>', end_effect, start=False)
-
-         # Add signal tags safely (if signal exists)
-         if start_signal is not None and end_signal is not None:
-             safe_insert('<SIG0>', start_signal, start=True)
-             safe_insert('</SIG0>', end_signal, start=False)
-
-         # Join tokens back into a string
-         return ' '.join(this_space_splitted_tokens)
-
-     def add_tags_offset(text, start_cause, end_cause, start_effect, end_effect, start_signal, end_signal):
-         """
-         Inserts tags into the original text based on token offsets.
-
-         Args:
-             text (str): The original input text.
-             offset_mapping (list of tuples): Maps token indices to character spans (from the enclosing scope).
-             start_cause (int): Start token index of the cause span.
-             end_cause (int): End token index of the cause span.
-             start_effect (int): Start token index of the effect span.
-             end_effect (int): End token index of the effect span.
-             start_signal (int, optional): Start token index of the signal span.
-             end_signal (int, optional): End token index of the signal span.
-
-         Returns:
-             str: The modified text with annotated spans.
-         """
-
-
-
-         # Convert token-based indices to character-based indices
-         start_cause_char, end_cause_char = offset_mapping[start_cause][0], offset_mapping[end_cause][1]
-         start_effect_char, end_effect_char = offset_mapping[start_effect][0], offset_mapping[end_effect][1]
-
-         # Insert tags into the original text
-         annotated_text = text[:start_cause_char] + "<ARG0>" + text[start_cause_char:end_cause_char] + "</ARG0>" + text[end_cause_char:start_effect_char] + "<ARG1>" + text[start_effect_char:end_effect_char] + "</ARG1>" + text[end_effect_char:]
-
-         # If signal span exists, insert signal tags
-         if start_signal is not None and end_signal is not None:
-             start_signal_char, end_signal_char = offset_mapping[start_signal][0], offset_mapping[end_signal][1]
-             annotated_text = (
-                 annotated_text[:start_signal_char]
-                 + "<SIG0>" + annotated_text[start_signal_char:end_signal_char] + "</SIG0>"
-                 + annotated_text[end_signal_char:]
-             )
-
-         return annotated_text
-
-     def add_tags_offset_2(text, start_cause, end_cause, start_effect, end_effect, start_signal, end_signal):
-         """
-         Inserts tags into the original text based on token offsets.
-
-         Args:
-             text (str): The original input text.
-             offset_mapping (list of tuples): Maps token indices to character spans.
-             start_cause (int): Start token index of the cause span.
-             end_cause (int): End token index of the cause span.
-             start_effect (int): Start token index of the effect span.
-             end_effect (int): End token index of the effect span.
-             start_signal (int, optional): Start token index of the signal span.
-             end_signal (int, optional): End token index of the signal span.
-
-         Returns:
-             str: The modified text with annotated spans.
-         """
-
-         # Convert token indices to character indices
-         spans = [
-             (offset_mapping[start_cause][0], offset_mapping[end_cause][1], "<ARG0>", "</ARG0>"),
-             (offset_mapping[start_effect][0], offset_mapping[end_effect][1], "<ARG1>", "</ARG1>")
-         ]
-
-         # Include signal tags if available
-         if start_signal is not None and end_signal is not None:
-             spans.append((offset_mapping[start_signal][0], offset_mapping[end_signal][1], "<SIG0>", "</SIG0>"))
-
-         # Sort spans in reverse order based on start index (to avoid shifting issues)
-         spans.sort(reverse=True, key=lambda x: x[0])
-
-         # Insert tags
-         for start, end, open_tag, close_tag in spans:
-             text = text[:start] + open_tag + text[start:end] + close_tag + text[end:]
-
-         return text
-
-
-     def add_tags_offset_3(text, start_cause, end_cause, start_effect, end_effect, start_signal, end_signal):
-         """
-         Inserts tags into the original text based on token offsets, ensuring correct nesting,
-         avoiding empty tags, preventing duplication, and handling punctuation placement.
-
-         Args:
-             text (str): The original input text.
-             offset_mapping (list of tuples): Maps token indices to character spans.
-             start_cause (int): Start token index of the cause span.
-             end_cause (int): End token index of the cause span.
-             start_effect (int): Start token index of the effect span.
-             end_effect (int): End token index of the effect span.
-             start_signal (int, optional): Start token index of the signal span.
-             end_signal (int, optional): End token index of the signal span.
-
-         Returns:
-             str: The modified text with correctly positioned annotated spans.
-         """
-
-         # Convert token indices to character indices
-         spans = []
-
-         # Function to adjust start position to avoid punctuation issues
-         def adjust_start(text, start):
-             while start < len(text) and text[start] in {',', ' ', '.', ';', ':'}:
-                 start += 1  # Move past punctuation
-             return start
-
-         # Ensure valid spans (avoid empty tags)
-         if start_cause is not None and end_cause is not None and start_cause < end_cause:
-             start_cause_char, end_cause_char = offset_mapping[start_cause][0], offset_mapping[end_cause][1]
-             spans.append((start_cause_char, end_cause_char, "<ARG0>", "</ARG0>"))
-
-         if start_effect is not None and end_effect is not None and start_effect < end_effect:
-             start_effect_char, end_effect_char = offset_mapping[start_effect][0], offset_mapping[end_effect][1]
-             start_effect_char = adjust_start(text, start_effect_char)  # Skip punctuation
-             spans.append((start_effect_char, end_effect_char, "<ARG1>", "</ARG1>"))
-
-         if start_signal is not None and end_signal is not None and start_signal < end_signal:
-             start_signal_char, end_signal_char = offset_mapping[start_signal][0], offset_mapping[end_signal][1]
-             spans.append((start_signal_char, end_signal_char, "<SIG0>", "</SIG0>"))
-
-         # Sort spans in reverse order based on start index (to avoid shifting issues)
-         spans.sort(reverse=True, key=lambda x: x[0])
-
-         # Insert tags correctly
-         modified_text = text
-         inserted_positions = []
-
-         for start, end, open_tag, close_tag in spans:
-             # Adjust positions based on previous insertions
-             shift = sum(len(tag) for pos, tag in inserted_positions if pos <= start)
-             start += shift
-             end += shift
-
-             # Ensure valid start/end to prevent empty tags
-             if start < end:
-                 modified_text = modified_text[:start] + open_tag + modified_text[start:end] + close_tag + modified_text[end:]
-                 inserted_positions.append((start, open_tag))
-                 inserted_positions.append((end + len(open_tag), close_tag))
-
-         return modified_text
-
-
-
-
-     tagged_sentence1 = add_tags_offset_3(text, start_cause1, end_cause1, start_effect1, end_effect1, start_signal, end_signal)
-     tagged_sentence2 = add_tags_offset_3(text, start_cause2, end_cause2, start_effect2, end_effect2, start_signal, end_signal)
-     return tagged_sentence1, tagged_sentence2
-
-
-
-
-
- def mark_text_by_position(original_text, start_idx, end_idx, color):
-     """Marks text in the original string based on character positions."""
-     if start_idx is not None and end_idx is not None and start_idx <= end_idx:
-         return (
-             original_text[:start_idx]
-             + f"<mark style='background-color:{color}; padding:2px; border-radius:4px;'>"
-             + original_text[start_idx:end_idx]
-             + "</mark>"
-             + original_text[end_idx:]
-         )
-     return original_text  # Return unchanged if indices are invalid
-
- def mark_text_by_tokens(tokenizer, tokens, start_idx, end_idx, color):
-     """Highlights a span in tokenized text using HTML."""
-     highlighted_tokens = copy.deepcopy(tokens)  # Avoid modifying original tokens
-     if start_idx is not None and end_idx is not None and start_idx <= end_idx:
-         highlighted_tokens[start_idx] = f"<span style='background-color:{color}; padding:2px; border-radius:4px;'>{highlighted_tokens[start_idx]}"
-         highlighted_tokens[end_idx] = f"{highlighted_tokens[end_idx]}</span>"
-     return tokenizer.convert_tokens_to_string(highlighted_tokens)
-
- def mark_text_by_word_ids(original_text, token_ids, start_word_id, end_word_id, color):
-     """Marks words in the original text based on word IDs from tokenized input."""
-     words = original_text.split()  # Split text into words
-     if start_word_id is not None and end_word_id is not None and start_word_id <= end_word_id:
-         words[start_word_id] = f"<mark style='background-color:{color}; padding:2px; border-radius:4px;'>{words[start_word_id]}"
-         words[end_word_id] = f"{words[end_word_id]}</mark>"
-
-     return " ".join(words)
-
-
-
-
- st.title("Causal Relation Extraction")
- input_text = st.text_area("Enter your text here:", height=300)
- beam_search = st.radio("Enable Beam Search?", ('No', 'Yes')) == 'Yes'
-
- if st.button("Add Argument Tags"):
-     if input_text:
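-         # Beam search is always enabled for this button; the radio selection
-         # above only affects the "Extract" and "Extract1" buttons below.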
-         tagged_sentence1, tagged_sentence2 = extract_arguments(input_text, tokenizer, model, beam_search=True)
-
-         st.write("**Tagged Sentence_1:**")
-         st.write(tagged_sentence1)
-         st.write("**Tagged Sentence_2:**")
-         st.write(tagged_sentence2)
-     else:
-         st.warning("Please enter some text to analyze.")
-
-
- if st.button("Extract"):
-     if input_text:
-         start_cause_id, end_cause_id, start_effect_id, end_effect_id, start_signal_id, end_signal_id = extract_arguments(input_text, tokenizer, model, beam_search=beam_search)
-
-         cause_text = mark_text_by_word_ids(input_text, inputs["input_ids"][0], start_cause_id, end_cause_id, "#FFD700")  # Gold for cause
-         effect_text = mark_text_by_word_ids(input_text, inputs["input_ids"][0], start_effect_id, end_effect_id, "#90EE90")  # Light green for effect
-         signal_text = mark_text_by_word_ids(input_text, inputs["input_ids"][0], start_signal_id, end_signal_id, "#FF6347")  # Tomato red for signal
-
-         st.markdown(f"**Cause:**<br>{cause_text}", unsafe_allow_html=True)
-         st.markdown(f"**Effect:**<br>{effect_text}", unsafe_allow_html=True)
-         st.markdown(f"**Signal:**<br>{signal_text}", unsafe_allow_html=True)
-     else:
-         st.warning("Please enter some text before extracting.")
-
-
-
-
- if st.button("Extract1"):
-     if input_text:
-         start_cause1, end_cause1, start_cause2, end_cause2, start_effect1, end_effect1, start_effect2, end_effect2, start_signal, end_signal = extract_arguments(input_text, tokenizer, model, beam_search=beam_search)
-
-         # Convert text to tokenized format
-         tokenized_input = tokenizer.tokenize(input_text)
-
-         cause_text1 = mark_text_by_tokens(tokenizer, tokenized_input, start_cause1, end_cause1, "#FFD700")  # Gold for cause
-         effect_text1 = mark_text_by_tokens(tokenizer, tokenized_input, start_effect1, end_effect1, "#90EE90")  # Light green for effect
-         signal_text = mark_text_by_tokens(tokenizer, tokenized_input, start_signal, end_signal, "#FF6347")  # Tomato red for signal
-
-         # Display first relation
-         st.markdown("<strong>Relation 1:</strong>", unsafe_allow_html=True)
-         st.markdown(f"**Cause:** {cause_text1}", unsafe_allow_html=True)
-         st.markdown(f"**Effect:** {effect_text1}", unsafe_allow_html=True)
-         st.markdown(f"**Signal:** {signal_text}", unsafe_allow_html=True)
-
-         # Display second relation if beam search is enabled
-         if beam_search:
-             cause_text2 = mark_text_by_tokens(tokenizer, tokenized_input, start_cause2, end_cause2, "#FFD700")
-             effect_text2 = mark_text_by_tokens(tokenizer, tokenized_input, start_effect2, end_effect2, "#90EE90")
-
-             st.markdown("<strong>Relation 2:</strong>", unsafe_allow_html=True)
-             st.markdown(f"**Cause:** {cause_text2}", unsafe_allow_html=True)
-             st.markdown(f"**Effect:** {effect_text2}", unsafe_allow_html=True)
-             st.markdown(f"**Signal:** {signal_text}", unsafe_allow_html=True)
-     else:
-         st.warning("Please enter some text before extracting.")