agiats commited on
Commit
26a0a47
·
1 Parent(s): 72953cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -6
app.py CHANGED
@@ -1,5 +1,5 @@
1
  from pathlib import Path
2
- from typing import List, Tuple
3
 
4
  import gradio as gr
5
  import numpy as np
@@ -113,11 +113,14 @@ def highlight_token(token: str, score: float):
113
 
114
 
115
  def create_highlighted_text(
116
- label: str, tokens2scores: List[Tuple[str, float]], mean_surprisal: float
 
 
117
  ):
118
- highlighted_text: str = (
119
- "<h2><b>" + label + f"</b>(サプライザル平均値: {mean_surprisal})</h2>"
120
- )
 
121
  for token, score in tokens2scores:
122
  highlighted_text += highlight_token(token, score)
123
  return highlighted_text
@@ -184,7 +187,9 @@ def main(input_text: str) -> Tuple[str, str, str]:
184
  diff_tokens2surprisal = calculate_surprisal_diff(
185
  tokens2surprisal, baseline_tokens2surprisal, 100.0
186
  )
187
- diff_highlighted_text = create_highlighted_text("学習前後の差分", diff_tokens2surprisal)
 
 
188
  return (
189
  baseline_highlighted_text,
190
  highlighted_text,
 
1
  from pathlib import Path
2
+ from typing import List, Optional, Tuple
3
 
4
  import gradio as gr
5
  import numpy as np
 
113
 
114
 
115
  def create_highlighted_text(
116
+ label: str,
117
+ tokens2scores: List[Tuple[str, float]],
118
+ mean_surprisal: Optional[float] = None,
119
  ):
120
+ if mean_surprisal is None:
121
+ highlighted_text = "<h2><b>" + label + "</b></h2>"
122
+ else:
123
+ highlighted_text = "<h2><b>" + label + f"</b>(サプライザル平均値: {mean_surprisal})</h2>"
124
  for token, score in tokens2scores:
125
  highlighted_text += highlight_token(token, score)
126
  return highlighted_text
 
187
  diff_tokens2surprisal = calculate_surprisal_diff(
188
  tokens2surprisal, baseline_tokens2surprisal, 100.0
189
  )
190
+ diff_highlighted_text = create_highlighted_text(
191
+ "学習前後の差分", diff_tokens2surprisal, None
192
+ )
193
  return (
194
  baseline_highlighted_text,
195
  highlighted_text,