Yehor committed on
Commit
3b5c038
·
verified ·
1 Parent(s): 8ccd395

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -30
app.py CHANGED
@@ -1,10 +1,12 @@
1
  import sys
 
2
 
3
  from importlib.metadata import version
4
 
5
  import evaluate
6
  import polars as pl
7
  import gradio as gr
 
8
 
9
  # Load evaluators
10
  wer = evaluate.load("wer")
@@ -59,17 +61,45 @@ tech_env = f"""
59
  tech_libraries = f"""
60
  #### Libraries
61
 
62
- - evaluate: {version('evaluate')}
63
- - gradio: {version('gradio')}
64
- - jiwer: {version('jiwer')}
65
- - polars: {version('polars')}
66
  """.strip()
67
 
68
 
69
  def clean_value(x):
70
- return x.replace('’', "'").strip().lower().replace(',', '').replace('.', '').replace('?', '').replace('!', '').replace('–', '').replace('«', '').replace('»', '')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
-
73
  def inference(file_name, clear_punctuation, show_chars, progress=gr.Progress()):
74
  if not file_name:
75
  raise gr.Error("Please paste your JSON file.")
@@ -78,25 +108,21 @@ def inference(file_name, clear_punctuation, show_chars, progress=gr.Progress()):
78
 
79
  df = pl.read_ndjson(file_name)
80
 
81
- inference_seconds = df['inference_total'].sum()
82
- duration_seconds = df['duration'].sum()
83
 
84
  rtf = inference_seconds / duration_seconds
85
 
86
- references = df['reference']
87
 
88
  if clear_punctuation:
89
- predictions = df['prediction'].map_elements(clean_value)
90
  else:
91
- predictions = df['prediction']
92
 
93
  # Evaluate
94
- wer_value = round(
95
- wer.compute(predictions=predictions, references=references), 4
96
- )
97
- cer_value = round(
98
- cer.compute(predictions=predictions, references=references), 4
99
- )
100
 
101
  inference_time = inference_seconds
102
  audio_duration = duration_seconds
@@ -106,27 +132,33 @@ def inference(file_name, clear_punctuation, show_chars, progress=gr.Progress()):
106
  results = []
107
 
108
  results.append(f"Metrics using `evaluate` library:")
109
- results.append('')
110
- results.append(f"- WER: {wer_value} metric, {round(wer_value*100, 4)}%")
111
- results.append(f"- CER: {cer_value} metric, {round(cer_value*100, 4)}%")
112
- results.append('')
113
  results.append(f"- Accuracy on words: {round(100 - 100 * wer_value, 4)}%")
114
  results.append(f"- Accuracy on chars: {round(100 - 100 * cer_value, 4)}%")
115
- results.append('')
116
- results.append(f"- Inference time: {round(inference_time, 4)} seconds, {round(inference_time/60, 4)} mins, {round(inference_time/60/60, 4)} hours")
117
- results.append(f"- Audio duration: {round(audio_duration, 4)} seconds, {round(audio_duration/60/60, 4)} hours")
118
- results.append('')
 
 
 
 
119
  results.append(f"- RTF: {round(rtf, 4)}")
120
 
121
  if show_chars:
122
  all_chars = set()
123
- for pred in list(df['prediction']):
124
  for c in pred:
125
  all_chars.add(c)
126
 
127
- results.append('')
 
 
128
  results.append(f"Chars in predictions:")
129
- results.append(f"{list(all_chars)}")
130
 
131
  return "\n".join(results)
132
 
@@ -161,12 +193,16 @@ with demo:
161
  gr.Button("Calculate").click(
162
  inference,
163
  concurrency_limit=concurrency_limit,
164
- inputs=[jsonl_file, clear_punctuation],
165
  outputs=metrics,
166
  )
167
 
168
  with gr.Row():
169
- gr.Examples(label="Choose an example", inputs=[jsonl_file, clear_punctuation, show_chars], examples=examples)
 
 
 
 
170
 
171
  gr.Markdown(description_foot)
172
 
 
1
  import sys
2
+ import re
3
 
4
  from importlib.metadata import version
5
 
6
  import evaluate
7
  import polars as pl
8
  import gradio as gr
9
+ from natsort import natsorted
10
 
11
  # Load evaluators
12
  wer = evaluate.load("wer")
 
61
  tech_libraries = f"""
62
  #### Libraries
63
 
64
+ - evaluate: {version("evaluate")}
65
+ - gradio: {version("gradio")}
66
+ - jiwer: {version("jiwer")}
67
+ - polars: {version("polars")}
68
  """.strip()
69
 
70
 
71
  def clean_value(x):
72
+ s = (
73
+ x.replace("’", "'")
74
+ .strip()
75
+ .lower()
76
+ .replace(":", " ")
77
+ .replace(",", " ")
78
+ .replace(".", " ")
79
+ .replace("?", " ")
80
+ .replace("!", " ")
81
+ .replace("–", " ")
82
+ .replace("«", " ")
83
+ .replace("»", " ")
84
+ .replace("—", " ")
85
+ .replace("…", " ")
86
+ .replace("/", " ")
87
+ .replace("\\", " ")
88
+ .replace("(", " ")
89
+ .replace(")", " ")
90
+ .replace("́", "")
91
+ .replace('"', " ")
92
+ )
93
+
94
+ s = re.sub(r" +", " ", s)
95
+
96
+ s = s.strip()
97
+
98
+ # print(s)
99
+
100
+ return s
101
+
102
 
 
103
  def inference(file_name, clear_punctuation, show_chars, progress=gr.Progress()):
104
  if not file_name:
105
  raise gr.Error("Please paste your JSON file.")
 
108
 
109
  df = pl.read_ndjson(file_name)
110
 
111
+ inference_seconds = df["inference_total"].sum()
112
+ duration_seconds = df["duration"].sum()
113
 
114
  rtf = inference_seconds / duration_seconds
115
 
116
+ references = df["reference"]
117
 
118
  if clear_punctuation:
119
+ predictions = df["prediction"].map_elements(clean_value, return_dtype=pl.String)
120
  else:
121
+ predictions = df["prediction"]
122
 
123
  # Evaluate
124
+ wer_value = round(wer.compute(predictions=predictions, references=references), 4)
125
+ cer_value = round(cer.compute(predictions=predictions, references=references), 4)
 
 
 
 
126
 
127
  inference_time = inference_seconds
128
  audio_duration = duration_seconds
 
132
  results = []
133
 
134
  results.append(f"Metrics using `evaluate` library:")
135
+ results.append("")
136
+ results.append(f"- WER: {wer_value} metric, {round(wer_value * 100, 4)}%")
137
+ results.append(f"- CER: {cer_value} metric, {round(cer_value * 100, 4)}%")
138
+ results.append("")
139
  results.append(f"- Accuracy on words: {round(100 - 100 * wer_value, 4)}%")
140
  results.append(f"- Accuracy on chars: {round(100 - 100 * cer_value, 4)}%")
141
+ results.append("")
142
+ results.append(
143
+ f"- Inference time: {round(inference_time, 4)} seconds, {round(inference_time / 60, 4)} mins, {round(inference_time / 60 / 60, 4)} hours"
144
+ )
145
+ results.append(
146
+ f"- Audio duration: {round(audio_duration, 4)} seconds, {round(audio_duration / 60 / 60, 4)} hours"
147
+ )
148
+ results.append("")
149
  results.append(f"- RTF: {round(rtf, 4)}")
150
 
151
  if show_chars:
152
  all_chars = set()
153
+ for pred in predictions:
154
  for c in pred:
155
  all_chars.add(c)
156
 
157
+ sorted_chars = natsorted(list(all_chars))
158
+
159
+ results.append("")
160
  results.append(f"Chars in predictions:")
161
+ results.append(f"{sorted_chars}")
162
 
163
  return "\n".join(results)
164
 
 
193
  gr.Button("Calculate").click(
194
  inference,
195
  concurrency_limit=concurrency_limit,
196
+ inputs=[jsonl_file, clear_punctuation, show_chars],
197
  outputs=metrics,
198
  )
199
 
200
  with gr.Row():
201
+ gr.Examples(
202
+ label="Choose an example",
203
+ inputs=[jsonl_file, clear_punctuation, show_chars],
204
+ examples=examples,
205
+ )
206
 
207
  gr.Markdown(description_foot)
208