Spaces:
Running
Running
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import torch
|
3 |
+
from safetensors.torch import load_file
|
4 |
+
from transformers import AutoConfig, AutoTokenizer, AutoModel
|
5 |
+
from ST2ModelV2_6 import ST2ModelV2
|
6 |
+
from huggingface_hub import login
|
7 |
+
import re
|
8 |
+
import copy
|
9 |
+
|
10 |
+
# Authenticate this process with the Hugging Face Hub before any checkpoint is
# requested; the token is read from Streamlit's secrets store so it never
# appears in the source file.
hf_token = st.secrets["HUGGINGFACE_TOKEN"]
login(token=hf_token)
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
# Load model & tokenizer once (cached for efficiency)
|
16 |
+
@st.cache_resource
def load_model():
    """Build the tokenizer and the ST2ModelV2 extraction model.

    Both are fetched from the ``anamargarida/Final`` Hub repository. The
    function is wrapped in ``st.cache_resource`` so the returned objects are
    created once and reused afterwards.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    repo_id = "anamargarida/Final"

    class Args:
        """Plain attribute holder handed to ``ST2ModelV2.from_pretrained``."""

        def __init__(self):
            self.model_name = repo_id
            self.dropout = 0.1
            self.signal_classification = True
            self.pretrained_signal_detector = False

    hub_config = AutoConfig.from_pretrained(repo_id)
    hub_tokenizer = AutoTokenizer.from_pretrained(repo_id)

    # Load the model directly from Hugging Face
    st2_model = ST2ModelV2.from_pretrained(repo_id, config=hub_config, args=Args())

    return hub_tokenizer, st2_model
|
38 |
+
|
39 |
+
# Load the model and tokenizer (load_model is cached, so this is cheap on reruns)
tokenizer, model = load_model()


@st.cache_resource
def _load_reference_roberta():
    """Return the stock ``roberta-large`` encoder used for the debug comparison.

    Streamlit re-executes the script on every widget interaction. Without
    caching, the checkpoint would be rebuilt in memory on each rerun.
    """
    return AutoModel.from_pretrained("roberta-large")


# Debug output: dump the loaded model, its config and a few weight tensors so
# they can be compared against the untouched roberta-large checkpoint.
st.write("model_", model)
st.write("model_weights", model.model)
st.write("config", model.config)
st.write("Signal_classifier_weights", model.signal_classifier.weight)
st.write(model.model.embeddings.LayerNorm.weight)
roberta_model = _load_reference_roberta()
st.write(roberta_model.embeddings.LayerNorm.weight)

model.eval()  # Set model to evaluation mode
|
52 |
+
def extract_arguments(text, tokenizer, model, beam_search=True):
    """Run the extraction model on *text* and decode cause/effect/signal spans.

    Args:
        text: Input string to analyse.
        tokenizer: Called as ``tokenizer(text, return_tensors="pt")``; must
            also provide ``convert_ids_to_tokens`` and
            ``convert_tokens_to_string``.
        model: Called as ``model(**inputs)``. Its result is indexed with the
            keys ``start_arg0_logits``, ``end_arg0_logits``,
            ``start_arg1_logits``, ``end_arg1_logits``, ``start_sig_logits``,
            ``end_sig_logits`` and ``signal_classification_logits``. When
            *beam_search* is true, ``model.beam_search_position_selector`` is
            used as well.
        beam_search: If true, two candidate relations are decoded through the
            model's beam-search selector; otherwise a single relation is
            chosen by independent argmax and the second relation is empty.

    Returns:
        tuple: ``(cause1, cause2, effect1, effect2, signal, list1, list2)``.
        The first five items are strings (``signal`` is ``'NA'`` when the
        signal classifier predicts class 0). ``list1`` and ``list2`` hold the
        token positions ``[start_cause, end_cause, start_effect, end_effect,
        start_signal, end_signal]`` of relation 1 and 2; the signal positions
        are ``'NA'`` when no signal is predicted, and the relation-2 argument
        positions are ``None`` without beam search.
    """
    # Large negative logit used to exclude a position from argmax selection.
    # (The previous value, -1e-4, is ~0 and did not suppress anything when
    # the real logits were negative.)
    masked = -1e4

    inputs = tokenizer(text, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # Extract logits (index 0: the single sequence that was tokenized)
    start_cause_logits = outputs["start_arg0_logits"][0]
    end_cause_logits = outputs["end_arg0_logits"][0]
    start_effect_logits = outputs["start_arg1_logits"][0]
    end_effect_logits = outputs["end_arg1_logits"][0]
    start_signal_logits = outputs["start_sig_logits"][0]
    end_signal_logits = outputs["end_sig_logits"][0]

    last_position = len(inputs["input_ids"][0]) - 1

    # Set the first and last token logits to a very low value to ignore them
    for logits in (
        start_cause_logits,
        end_cause_logits,
        start_effect_logits,
        end_effect_logits,
    ):
        logits[0] = masked
        logits[last_position] = masked

    st.write("start_cause_logits", start_cause_logits)
    st.write("end_cause_logits", end_cause_logits)
    st.write("start_effect_logits", start_effect_logits)
    st.write("end_effect_logits", end_effect_logits)
    st.write("start_signal_logits", start_signal_logits)
    st.write("end_signal_logits", end_signal_logits)

    # Beam Search for position selection
    if beam_search:
        indices1, indices2, _, _, _ = model.beam_search_position_selector(
            start_cause_logits=start_cause_logits,
            end_cause_logits=end_cause_logits,
            start_effect_logits=start_effect_logits,
            end_effect_logits=end_effect_logits,
            topk=5,
        )
        start_cause1, end_cause1, start_effect1, end_effect1 = indices1
        start_cause2, end_cause2, start_effect2, end_effect2 = indices2
    else:
        start_cause1 = start_cause_logits.argmax().item()
        end_cause1 = end_cause_logits.argmax().item()
        start_effect1 = start_effect_logits.argmax().item()
        end_effect1 = end_effect_logits.argmax().item()

        start_cause2, end_cause2, start_effect2, end_effect2 = None, None, None, None

    # Signal presence comes from the model's own classification head. The
    # former "pretrained signal detector" branch referenced undefined names
    # and could never run, so it has been removed.
    has_signal = outputs["signal_classification_logits"].argmax().item()

    if has_signal:
        start_signal_logits[0] = masked
        end_signal_logits[0] = masked
        start_signal_logits[last_position] = masked
        end_signal_logits[last_position] = masked

        start_signal = start_signal_logits.argmax().item()
        # The signal must end at or after its start and span at most 5 tokens.
        end_signal_logits[:start_signal] = masked
        end_signal_logits[start_signal + 5:] = masked
        end_signal = end_signal_logits.argmax().item()
    else:
        start_signal = 'NA'
        end_signal = 'NA'

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    st.write(f"Start Cause 1: {start_cause1}, End Cause: {end_cause1}")
    st.write(f"Start Effect 1: {start_effect1}, End Effect: {end_effect1}")
    st.write(f"Start Signal: {start_signal}, End Signal: {end_signal}")

    def extract_span(start, end):
        """Decode tokens ``start..end`` (inclusive); empty if either is None."""
        if start is None or end is None:
            return ""
        return tokenizer.convert_tokens_to_string(tokens[start:end + 1])

    cause1 = extract_span(start_cause1, end_cause1)
    cause2 = extract_span(start_cause2, end_cause2)
    effect1 = extract_span(start_effect1, end_effect1)
    effect2 = extract_span(start_effect2, end_effect2)
    signal = extract_span(start_signal, end_signal) if has_signal else 'NA'

    list1 = [start_cause1, end_cause1, start_effect1, end_effect1, start_signal, end_signal]
    list2 = [start_cause2, end_cause2, start_effect2, end_effect2, start_signal, end_signal]
    return cause1, cause2, effect1, effect2, signal, list1, list2
|
169 |
+
|
170 |
+
def mark_text(original_text, span, color):
    """Wrap each occurrence of *span* in *original_text* in a coloured ``<mark>``.

    Matching is literal (the span is regex-escaped) and case-insensitive. The
    text inside the tag is the matched slice of *original_text*, so the
    original casing is kept.

    Args:
        original_text: Text to search.
        span: Substring to highlight; a falsy value disables highlighting.
        color: CSS colour inserted into the ``background-color`` style.

    Returns:
        str: The text with HTML ``<mark>`` wrappers, or *original_text*
        unchanged when *span* is falsy or does not occur.
    """
    if not span:
        return original_text  # Return unchanged text if no span is found

    # NOTE(review): neither the text nor the colour is HTML-escaped here, and
    # the result is rendered with unsafe_allow_html=True by this file's UI.

    def _highlight(match):
        # A callable replacement is inserted verbatim. A replacement *string*
        # would have its backslashes processed as escapes, corrupting or
        # rejecting spans that contain "\".
        return (
            f"<mark style='background-color:{color}; padding:2px; border-radius:4px;'>"
            f"{match.group(0)}</mark>"
        )

    return re.sub(re.escape(span), _highlight, original_text, flags=re.IGNORECASE)
|
175 |
+
|
176 |
+
def _render_relation(title, text, cause, effect, signal):
    """Render one relation: a heading plus three highlighted copies of *text*."""
    st.markdown(
        f"<span style='font-size: 24px;'><strong>{title}</strong></span>",
        unsafe_allow_html=True,
    )
    highlights = (
        ("Cause", cause, "#FFD700"),    # Gold for cause
        ("Effect", effect, "#90EE90"),  # Light green for effect
        ("Signal", signal, "#FF6347"),  # Tomato red for signal
    )
    for label, span, color in highlights:
        st.markdown(
            f"**{label}:**<br>{mark_text(text, span, color)}",
            unsafe_allow_html=True,
        )


st.title("Causal Relation Extraction")
input_text = st.text_area("Enter your text here:", height=300)
beam_search = st.radio("Enable Beam Search?", ('No', 'Yes')) == 'Yes'


if st.button("Extract1"):
    # Whitespace-only input is treated the same as no input.
    if input_text.strip():
        cause1, cause2, effect1, effect2, signal, list1, list2 = extract_arguments(
            input_text, tokenizer, model, beam_search=beam_search
        )

        # list1[4] is the signal start position; extract_arguments sets it to
        # 'NA' when no signal was predicted. In that case pass an empty span
        # so the text is shown unmarked, instead of highlighting every "na"
        # substring of the input.
        signal_span = "" if list1[4] == 'NA' else signal

        _render_relation("Relation 1:", input_text, cause1, effect1, signal_span)

        if beam_search:
            _render_relation("Relation 2:", input_text, cause2, effect2, signal_span)
    else:
        st.warning("Please enter some text before extracting.")
|