anamargarida commited on
Commit
ebc577a
·
verified ·
1 Parent(s): b990576

Rename app_7.py to app_8.py

Browse files
Files changed (1) hide show
  1. app_7.py → app_8.py +67 -11
app_7.py → app_8.py RENAMED
@@ -88,7 +88,7 @@ def extract_arguments(text, tokenizer, model, beam_search=True):
88
 
89
  # Beam Search for position selection
90
  if beam_search:
91
- indices1, indices2, _, _, _ = model.beam_search_position_selector(
92
  start_cause_logits=start_cause_logits,
93
  end_cause_logits=end_cause_logits,
94
  start_effect_logits=start_effect_logits,
@@ -178,19 +178,63 @@ def extract_arguments(text, tokenizer, model, beam_search=True):
178
  else:
179
  cause_text2 = None
180
  effect_text2 = None
181
- return cause_text1, effect_text1, signal_text, cause_text2, effect_text2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
  st.title("Causal Relation Extraction")
184
  input_text = st.text_area("Enter your text here:", height=300)
185
- beam_search = st.radio("Enable Position Selector & Beam Search?", ('No', 'Yes')) == 'Yes'
186
 
187
 
188
  if st.button("Extract"):
189
  if input_text:
190
- cause_text1, effect_text1, signal_text, cause_text2, effect_text2 = extract_arguments(input_text, tokenizer, model, beam_search=beam_search)
191
 
192
  # Display first relation
193
- st.markdown(f"<strong>Relation 1:</strong>", unsafe_allow_html=True)
194
 
195
  if cause_text1 is None or effect_text1 is None:
196
  st.write("The prediction is not correct for at least one span: The position of the predicted end token comes before the position of the start token.")
@@ -199,14 +243,26 @@ if st.button("Extract"):
199
  st.markdown(f"**Effect:** {effect_text1}", unsafe_allow_html=True)
200
  st.markdown(f"**Signal:** {signal_text}", unsafe_allow_html=True)
201
 
 
 
 
 
 
 
202
 
203
- # Display second relation if beam search is enabled
204
- if beam_search:
 
 
 
 
 
 
205
 
206
- st.markdown(f"<strong>Relation 2:</strong>", unsafe_allow_html=True)
207
- st.markdown(f"**Cause:** {cause_text2}", unsafe_allow_html=True)
208
- st.markdown(f"**Effect:** {effect_text2}", unsafe_allow_html=True)
209
- st.markdown(f"**Signal:** {signal_text}", unsafe_allow_html=True)
210
 
211
  else:
212
  st.warning("Please enter some text before extracting.")
 
88
 
89
  # Beam Search for position selection
90
  if beam_search:
91
+ indices1, indices2, score1, score2, topk_scores = model.beam_search_position_selector(
92
  start_cause_logits=start_cause_logits,
93
  end_cause_logits=end_cause_logits,
94
  start_effect_logits=start_effect_logits,
 
178
  else:
179
  cause_text2 = None
180
  effect_text2 = None
181
+
182
+ if beam_search:
183
+ start_cause_probs = torch.softmax(start_cause_logits, dim=-1)
184
+ end_cause_probs = torch.softmax(end_cause_logits, dim=-1)
185
+ start_effect_probs = torch.softmax(start_effect_logits, dim=-1)
186
+ end_effect_probs = torch.softmax(end_effect_logits, dim=-1)
187
+
188
+ best_start_cause_score = start_cause_probs[start_cause1].item()
189
+ best_end_cause_score = end_cause_probs[end_cause1].item()
190
+ best_start_effect_score = start_effect_probs[start_effect1].item()
191
+ best_end_effect_score = end_effect_probs[end_effect1].item()
192
+
193
+ second_start_cause_score = start_cause_probs[start_cause2].item()
194
+ second_end_cause_score = end_cause_probs[end_cause2].item()
195
+ second_start_effect_score = start_effect_probs[start_effect2].item()
196
+ second_end_effect_score = end_effect_probs[end_effect2].item()
197
+
198
+ best_scores = {
199
+ "Start Cause Score": round(best_start_cause_score, 4),
200
+ "End Cause Score": round(best_end_cause_score, 4),
201
+ "Start Effect Score": round(best_start_effect_score, 4),
202
+ "End Effect Score": round(best_end_effect_score, 4),
203
+ "Total Best Score [sum of log-probability scores]": round(score1, 4)
204
+ }
205
+
206
+ second_best_scores = {
207
+ "Start Cause Score": round(second_start_cause_score, 4),
208
+ "End Cause Score": round(second_end_cause_score, 4),
209
+ "Start Effect Score": round(second_start_effect_score, 4),
210
+ "End Effect Score": round(second_end_effect_score, 4),
211
+ "Total Second Best Score [sum of log-probability scores]": round(score2, 4)
212
+ }
213
+
214
+ top5_scores = sorted(topk_scores.items(), key=lambda x: x[1], reverse=True)[:5]
215
+ top5_scores = [(k, round(v, 4)) for k, v in top5_scores]
216
+
217
+
218
+ else:
219
+ best_scores = {}
220
+ second_best_scores = {}
221
+ top5_scores = {}
222
+
223
+
224
+ return cause_text1, effect_text1, signal_text, cause_text2, effect_text2, best_scores, second_best_scores, top5_scores
225
+
226
 
227
  st.title("Causal Relation Extraction")
228
  input_text = st.text_area("Enter your text here:", height=300)
229
+ beam_search = st.radio("Enable Position Selector & Beam Search?", ('Yes', 'No')) == 'Yes'
230
 
231
 
232
  if st.button("Extract"):
233
  if input_text:
234
+ cause_text1, effect_text1, signal_text, cause_text2, effect_text2, best_scores, second_best_scores, top5_scores = extract_arguments(input_text, tokenizer, model, beam_search=beam_search)
235
 
236
  # Display first relation
237
+ st.write("## Relation 1:")
238
 
239
  if cause_text1 is None or effect_text1 is None:
240
  st.write("The prediction is not correct for at least one span: The position of the predicted end token comes before the position of the start token.")
 
243
  st.markdown(f"**Effect:** {effect_text1}", unsafe_allow_html=True)
244
  st.markdown(f"**Signal:** {signal_text}", unsafe_allow_html=True)
245
 
246
+ if beam_search:
247
+
248
+ # Display dictionary in Streamlit
249
+ st.markdown(f"<strong>Best Tuple Scores:</strong>", unsafe_allow_html=True)
250
+ st.json(best_scores)
251
+
252
 
253
+ # Display second relation if beam search is enabled
254
+ st.write("## Relation 2:")
255
+ st.markdown(f"**Cause:** {cause_text2}", unsafe_allow_html=True)
256
+ st.markdown(f"**Effect:** {effect_text2}", unsafe_allow_html=True)
257
+ st.markdown(f"**Signal:** {signal_text}", unsafe_allow_html=True)
258
+
259
+ st.markdown(f"<strong>Second best Tuple Scores:</strong>", unsafe_allow_html=True)
260
+ st.json(second_best_scores)
261
 
262
+ st.markdown(f"<strong>top5_scores [sum of log-probability scores]:</strong>", unsafe_allow_html=True)
263
+ # Unpack top 5 scores
264
+ # first, second, third, fourth, fifth = top_5_scores
265
+ st.json(top5_scores)
266
 
267
  else:
268
  st.warning("Please enter some text before extracting.")