anamargarida committed on
Commit
24f27aa
·
verified ·
1 Parent(s): 1f69012

Update app_8.py

Browse files
Files changed (1) hide show
  1. app_8.py +9 -11
app_8.py CHANGED
@@ -12,7 +12,6 @@ hf_token = st.secrets["HUGGINGFACE_TOKEN"]
12
  login(token=hf_token)
13
 
14
 
15
- # Load model & tokenizer once (cached for efficiency)
16
  @st.cache_resource
17
  def load_model():
18
 
@@ -35,20 +34,19 @@ def load_model():
35
  repo_id = "anamargarida/SpanExtractionWithSignalCls_2"
36
  filename = "model.safetensors"
37
 
38
- # Download the model file
39
  model_path = hf_hub_download(repo_id=repo_id, filename=filename)
40
 
41
- # Load the model weights
42
  state_dict = load_file(model_path)
43
 
44
  model.load_state_dict(state_dict)
45
 
46
  return tokenizer, model
47
 
48
- # Load the model and tokenizer
49
  tokenizer, model = load_model()
50
 
51
- model.eval() # Set model to evaluation mode
52
  def extract_arguments(text, tokenizer, model, beam_search=True):
53
 
54
  class Args:
@@ -59,14 +57,14 @@ def extract_arguments(text, tokenizer, model, beam_search=True):
59
  args = Args()
60
  inputs = tokenizer(text, return_offsets_mapping=True, return_tensors="pt")
61
 
62
- # Get tokenized words (for reconstruction later)
63
  word_ids = inputs.word_ids()
64
 
65
  with torch.no_grad():
66
  outputs = model(**inputs)
67
 
68
 
69
- # Extract logits
70
  start_cause_logits = outputs["start_arg0_logits"][0]
71
  end_cause_logits = outputs["end_arg0_logits"][0]
72
  start_effect_logits = outputs["start_arg1_logits"][0]
@@ -164,11 +162,11 @@ def extract_arguments(text, tokenizer, model, beam_search=True):
164
 
165
 
166
 
167
- cause_text1 = mark_text_by_position(input_text, start_cause1, end_cause1, "#FFD700") # Gold for cause
168
- effect_text1 = mark_text_by_position(input_text, start_effect1, end_effect1, "#90EE90") # Light green for effect
169
 
170
  if start_signal is not None and end_signal is not None:
171
- signal_text = mark_text_by_position(input_text, start_signal, end_signal, "#FF6347") # Tomato red for signal
172
  else:
173
  signal_text = None
174
 
@@ -225,7 +223,7 @@ def extract_arguments(text, tokenizer, model, beam_search=True):
225
 
226
 
227
  st.title("Causal Relation Extraction")
228
- input_text = st.text_area("Enter your text here:", height=300)
229
  beam_search = st.radio("Enable Position Selector & Beam Search?", ('Yes', 'No')) == 'Yes'
230
 
231
 
 
12
  login(token=hf_token)
13
 
14
 
 
15
  @st.cache_resource
16
  def load_model():
17
 
 
34
  repo_id = "anamargarida/SpanExtractionWithSignalCls_2"
35
  filename = "model.safetensors"
36
 
37
+
38
  model_path = hf_hub_download(repo_id=repo_id, filename=filename)
39
 
 
40
  state_dict = load_file(model_path)
41
 
42
  model.load_state_dict(state_dict)
43
 
44
  return tokenizer, model
45
 
46
+
47
  tokenizer, model = load_model()
48
 
49
+ model.eval()
50
  def extract_arguments(text, tokenizer, model, beam_search=True):
51
 
52
  class Args:
 
57
  args = Args()
58
  inputs = tokenizer(text, return_offsets_mapping=True, return_tensors="pt")
59
 
60
+
61
  word_ids = inputs.word_ids()
62
 
63
  with torch.no_grad():
64
  outputs = model(**inputs)
65
 
66
 
67
+
68
  start_cause_logits = outputs["start_arg0_logits"][0]
69
  end_cause_logits = outputs["end_arg0_logits"][0]
70
  start_effect_logits = outputs["start_arg1_logits"][0]
 
162
 
163
 
164
 
165
+ cause_text1 = mark_text_by_position(input_text, start_cause1, end_cause1, "#FFD700") # yellow for cause
166
+ effect_text1 = mark_text_by_position(input_text, start_effect1, end_effect1, "#90EE90") # green for effect
167
 
168
  if start_signal is not None and end_signal is not None:
169
+ signal_text = mark_text_by_position(input_text, start_signal, end_signal, "#FF6347") # red for signal
170
  else:
171
  signal_text = None
172
 
 
223
 
224
 
225
  st.title("Causal Relation Extraction")
226
+ input_text = st.text_area("Enter your text here:", height=100)
227
  beam_search = st.radio("Enable Position Selector & Beam Search?", ('Yes', 'No')) == 'Yes'
228
 
229