Spaces:
Running
Running
Update app_8.py
Browse files
app_8.py
CHANGED
@@ -12,7 +12,6 @@ hf_token = st.secrets["HUGGINGFACE_TOKEN"]
|
|
12 |
login(token=hf_token)
|
13 |
|
14 |
|
15 |
-
# Load model & tokenizer once (cached for efficiency)
|
16 |
@st.cache_resource
|
17 |
def load_model():
|
18 |
|
@@ -35,20 +34,19 @@ def load_model():
|
|
35 |
repo_id = "anamargarida/SpanExtractionWithSignalCls_2"
|
36 |
filename = "model.safetensors"
|
37 |
|
38 |
-
|
39 |
model_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
40 |
|
41 |
-
# Load the model weights
|
42 |
state_dict = load_file(model_path)
|
43 |
|
44 |
model.load_state_dict(state_dict)
|
45 |
|
46 |
return tokenizer, model
|
47 |
|
48 |
-
|
49 |
tokenizer, model = load_model()
|
50 |
|
51 |
-
model.eval()
|
52 |
def extract_arguments(text, tokenizer, model, beam_search=True):
|
53 |
|
54 |
class Args:
|
@@ -59,14 +57,14 @@ def extract_arguments(text, tokenizer, model, beam_search=True):
|
|
59 |
args = Args()
|
60 |
inputs = tokenizer(text, return_offsets_mapping=True, return_tensors="pt")
|
61 |
|
62 |
-
|
63 |
word_ids = inputs.word_ids()
|
64 |
|
65 |
with torch.no_grad():
|
66 |
outputs = model(**inputs)
|
67 |
|
68 |
|
69 |
-
|
70 |
start_cause_logits = outputs["start_arg0_logits"][0]
|
71 |
end_cause_logits = outputs["end_arg0_logits"][0]
|
72 |
start_effect_logits = outputs["start_arg1_logits"][0]
|
@@ -164,11 +162,11 @@ def extract_arguments(text, tokenizer, model, beam_search=True):
|
|
164 |
|
165 |
|
166 |
|
167 |
-
cause_text1 = mark_text_by_position(input_text, start_cause1, end_cause1, "#FFD700") #
|
168 |
-
effect_text1 = mark_text_by_position(input_text, start_effect1, end_effect1, "#90EE90") #
|
169 |
|
170 |
if start_signal is not None and end_signal is not None:
|
171 |
-
signal_text = mark_text_by_position(input_text, start_signal, end_signal, "#FF6347") #
|
172 |
else:
|
173 |
signal_text = None
|
174 |
|
@@ -225,7 +223,7 @@ def extract_arguments(text, tokenizer, model, beam_search=True):
|
|
225 |
|
226 |
|
227 |
st.title("Causal Relation Extraction")
|
228 |
-
input_text = st.text_area("Enter your text here:", height=
|
229 |
beam_search = st.radio("Enable Position Selector & Beam Search?", ('Yes', 'No')) == 'Yes'
|
230 |
|
231 |
|
|
|
12 |
login(token=hf_token)
|
13 |
|
14 |
|
|
|
15 |
@st.cache_resource
|
16 |
def load_model():
|
17 |
|
|
|
34 |
repo_id = "anamargarida/SpanExtractionWithSignalCls_2"
|
35 |
filename = "model.safetensors"
|
36 |
|
37 |
+
|
38 |
model_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
39 |
|
|
|
40 |
state_dict = load_file(model_path)
|
41 |
|
42 |
model.load_state_dict(state_dict)
|
43 |
|
44 |
return tokenizer, model
|
45 |
|
46 |
+
|
47 |
tokenizer, model = load_model()
|
48 |
|
49 |
+
model.eval()
|
50 |
def extract_arguments(text, tokenizer, model, beam_search=True):
|
51 |
|
52 |
class Args:
|
|
|
57 |
args = Args()
|
58 |
inputs = tokenizer(text, return_offsets_mapping=True, return_tensors="pt")
|
59 |
|
60 |
+
|
61 |
word_ids = inputs.word_ids()
|
62 |
|
63 |
with torch.no_grad():
|
64 |
outputs = model(**inputs)
|
65 |
|
66 |
|
67 |
+
|
68 |
start_cause_logits = outputs["start_arg0_logits"][0]
|
69 |
end_cause_logits = outputs["end_arg0_logits"][0]
|
70 |
start_effect_logits = outputs["start_arg1_logits"][0]
|
|
|
162 |
|
163 |
|
164 |
|
165 |
+
cause_text1 = mark_text_by_position(input_text, start_cause1, end_cause1, "#FFD700") # yellow for cause
|
166 |
+
effect_text1 = mark_text_by_position(input_text, start_effect1, end_effect1, "#90EE90") # green for effect
|
167 |
|
168 |
if start_signal is not None and end_signal is not None:
|
169 |
+
signal_text = mark_text_by_position(input_text, start_signal, end_signal, "#FF6347") # red for signal
|
170 |
else:
|
171 |
signal_text = None
|
172 |
|
|
|
223 |
|
224 |
|
225 |
st.title("Causal Relation Extraction")
|
226 |
+
input_text = st.text_area("Enter your text here:", height=100)
|
227 |
beam_search = st.radio("Enable Position Selector & Beam Search?", ('Yes', 'No')) == 'Yes'
|
228 |
|
229 |
|