Yehor committed
Commit 8f2bb45 · 1 Parent(s): 1248b75

Add ability to calculate WER/CER values per row

Files changed (4):
  1. app.py +97 -13
  2. justfile +5 -0
  3. requirements.txt +4 -0
  4. ruff.toml +2 -0
app.py CHANGED
@@ -1,14 +1,17 @@
 import sys
-import re
 
 from importlib.metadata import version
 
+import evaluate
 import polars as pl
 import gradio as gr
+from joblib import Parallel, delayed
 
-# Config
-concurrency_limit = 5
+# Load evaluators
+wer = evaluate.load("wer")
+cer = evaluate.load("cer")
 
+# Config
 title = "See ASR Outputs"
 
 # https://www.tablesgenerator.com/markdown_tables
@@ -27,8 +30,8 @@ Follow them on social networks and **contact** if you need any help or have any
 """.strip()
 
 examples = [
-    ["evaluation_results.jsonl", False],
-    ["evaluation_results_batch.jsonl", True],
+    ["evaluation_results.jsonl", False, True],
+    ["evaluation_results_batch.jsonl", True, True],
 ]
 
 description_head = f"""
@@ -36,7 +39,7 @@ description_head = f"""
 
 ## Overview
 
-See generated JSONL files made by ASR models as a dataframe.
+See generated JSONL files made by ASR models as a dataframe. Also, this app calculates WER and CER metrics for each row.
 """.strip()
 
 description_foot = f"""
@@ -57,17 +60,34 @@ tech_libraries = f"""
 #### Libraries
 
 - gradio: {version("gradio")}
+- jiwer: {version("jiwer")}
+- evaluate: {version("evaluate")}
 - polars: {version("polars")}
 """.strip()
 
 
-def inference(file_name, _batch_mode):
+def compute_wer(prediction, reference):
+    return round(wer.compute(predictions=[prediction], references=[reference]), 4)
+
+
+def compute_cer(prediction, reference):
+    return round(cer.compute(predictions=[prediction], references=[reference]), 4)
+
+
+def compute_batch_wer(predictions, references):
+    return round(wer.compute(predictions=predictions, references=references), 4)
+
+
+def compute_batch_cer(predictions, references):
+    return round(cer.compute(predictions=predictions, references=references), 4)
+
+
+def inference(file_name, _batch_mode, _calculate_metrics):
     if not file_name:
         raise gr.Error("Please paste your JSON file.")
 
     df = pl.read_ndjson(file_name)
 
-
     required_columns = [
         "filename",
         "inference_start",
@@ -105,9 +125,70 @@ def inference(file_name, _batch_mode):
     df = df.drop(["inference_start", "inference_end", "filename"])
 
     # round "inference_total" field to 2 decimal places
-    df = df.with_columns(pl.col("inference_total").round(2))
+    df = df.with_columns(pl.col("inference_total").round(2).alias("elapsed"))
+    df = df.drop(["inference_total"])
 
-    return df
+    # reassign columns
+    if _batch_mode:
+        if _calculate_metrics:
+            wer_values = Parallel(n_jobs=-1)(
+                delayed(compute_batch_wer)(row["predictions"], row["references"])
+                for row in df.iter_rows(named=True)
+            )
+            cer_values = Parallel(n_jobs=-1)(
+                delayed(compute_batch_cer)(row["predictions"], row["references"])
+                for row in df.iter_rows(named=True)
+            )
+
+            df.insert_column(2, pl.Series("wer", wer_values))
+            df.insert_column(3, pl.Series("cer", cer_values))
+
+            fields = [
+                "elapsed",
+                "durations",
+                "wer",
+                "cer",
+                "predictions",
+                "references",
+            ]
+        else:
+            fields = [
+                "elapsed",
+                "durations",
+                "predictions",
+                "references",
+            ]
+    else:
+        if _calculate_metrics:
+            wer_values = Parallel(n_jobs=-1)(
+                delayed(compute_wer)(row["prediction"], row["reference"])
+                for row in df.iter_rows(named=True)
+            )
+            cer_values = Parallel(n_jobs=-1)(
+                delayed(compute_cer)(row["prediction"], row["reference"])
+                for row in df.iter_rows(named=True)
+            )
+
+            df.insert_column(2, pl.Series("wer", wer_values))
+            df.insert_column(3, pl.Series("cer", cer_values))
+
+            fields = [
+                "elapsed",
+                "duration",
+                "wer",
+                "cer",
+                "prediction",
+                "reference",
+            ]
+        else:
+            fields = [
+                "elapsed",
+                "duration",
+                "prediction",
+                "reference",
+            ]
+
+    return df.select(fields)
 
 
 demo = gr.Blocks(
@@ -134,18 +215,21 @@ with demo:
         label="Use batch mode",
     )
 
+    calculate_metrics = gr.Checkbox(
+        label="Calculate WER/CER metrics",
+        value=True,
+    )
 
    gr.Button("Show").click(
        inference,
-        concurrency_limit=concurrency_limit,
-        inputs=[jsonl_file, batch_mode],
+        inputs=[jsonl_file, batch_mode, calculate_metrics],
        outputs=df,
    )

    with gr.Row():
        gr.Examples(
            label="Choose an example",
-            inputs=[jsonl_file, batch_mode],
+            inputs=[jsonl_file, batch_mode, calculate_metrics],
            examples=examples,
        )
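For reference, the per-row metric path added above boils down to the minimal, self-contained sketch below. It mirrors the non-batch column names (prediction, reference) and the Parallel/delayed pattern from the commit; the two sample rows are purely illustrative, not taken from the evaluation files.

# Minimal sketch of the per-row WER/CER computation, assuming the pinned
# dependencies from requirements.txt (evaluate, jiwer, joblib, polars).
import evaluate
import polars as pl
from joblib import Parallel, delayed

wer = evaluate.load("wer")  # word error rate, backed by jiwer
cer = evaluate.load("cer")  # character error rate


def compute_wer(prediction, reference):
    return round(wer.compute(predictions=[prediction], references=[reference]), 4)


def compute_cer(prediction, reference):
    return round(cer.compute(predictions=[prediction], references=[reference]), 4)


# Illustrative rows instead of a JSONL file.
df = pl.DataFrame(
    {
        "prediction": ["the cat sat", "hello word"],
        "reference": ["the cat sat down", "hello world"],
    }
)

# One metric value per row, computed across CPU cores as in the commit.
wer_values = Parallel(n_jobs=-1)(
    delayed(compute_wer)(row["prediction"], row["reference"])
    for row in df.iter_rows(named=True)
)
cer_values = Parallel(n_jobs=-1)(
    delayed(compute_cer)(row["prediction"], row["reference"])
    for row in df.iter_rows(named=True)
)

df = df.with_columns(pl.Series("wer", wer_values), pl.Series("cer", cer_values))
print(df)

evaluate.load() fetches the metric implementation on first use and relies on jiwer under the hood, which is why both packages are added to requirements.txt; the batch path in the commit works the same way but passes list-valued predictions/references cells for each row.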
justfile ADDED
@@ -0,0 +1,5 @@
+check:
+    ruff check
+
+fmt: check
+    ruff format
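Assuming the just command runner is installed, "just check" runs ruff check, and "just fmt" first runs the check recipe it depends on and then ruff format.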
requirements.txt CHANGED
@@ -1,3 +1,7 @@
 gradio==5.23.0
 
 polars==1.26.0
+evaluate==0.4.3
+jiwer==3.1.0
+
+joblib==1.4.2
ruff.toml ADDED
@@ -0,0 +1,2 @@
+[lint]
+ignore = ["F403"]
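F403 is the Pyflakes rule that flags "from module import *" usage (undefined names cannot be detected), so ruff check will not report star imports under this configuration.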