Yehor Smoliakov committed on
Commit
11961e0
·
1 Parent(s): 76c65b5

Add batch mode

Browse files
Files changed (2) hide show
  1. app.py +55 -13
  2. evaluation_results.jsonl +0 -0
app.py CHANGED
@@ -33,7 +33,8 @@ Follow them on social networks and **contact** if you need any help or have any
33
  """.strip()
34
 
35
  examples = [
36
- ["evaluation_results.jsonl", True, False],
 
37
  ]
38
 
39
  description_head = f"""
@@ -100,25 +101,58 @@ def clean_value(x):
100
  return s
101
 
102
 
103
- def inference(file_name, clear_punctuation, show_chars, progress=gr.Progress()):
104
  if not file_name:
105
  raise gr.Error("Please paste your JSON file.")
106
 
107
- progress(0, desc="Calculating...")
108
-
109
  df = pl.read_ndjson(file_name)
110
 
111
  inference_seconds = df["inference_total"].sum()
112
- duration_seconds = df["duration"].sum()
113
 
114
- rtf = inference_seconds / duration_seconds
 
 
 
 
 
 
 
 
 
 
115
 
116
- references = df["reference"]
 
117
 
118
- if clear_punctuation:
119
- predictions = df["prediction"].map_elements(clean_value, return_dtype=pl.String)
 
 
 
 
 
 
 
 
 
 
 
120
  else:
121
- predictions = df["prediction"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
  # Evaluate
124
  wer_value = round(wer.compute(predictions=predictions, references=references), 4)
@@ -131,6 +165,10 @@ def inference(file_name, clear_punctuation, show_chars, progress=gr.Progress()):
131
 
132
  results = []
133
 
 
 
 
 
134
  results.append(f"- WER: {wer_value} metric, {round(wer_value * 100, 4)}%")
135
  results.append(f"- CER: {cer_value} metric, {round(cer_value * 100, 4)}%")
136
  results.append("")
@@ -146,7 +184,7 @@ def inference(file_name, clear_punctuation, show_chars, progress=gr.Progress()):
146
  results.append("")
147
  results.append(f"- RTF: {round(rtf, 4)}")
148
 
149
- if show_chars:
150
  all_chars = set()
151
  for pred in predictions:
152
  for c in pred:
@@ -175,12 +213,16 @@ with demo:
175
  with gr.Row():
176
  with gr.Column():
177
  jsonl_file = gr.File(label="A JSONL file")
 
178
  clear_punctuation = gr.Checkbox(
179
  label="Clear punctuation, some chars and convert to lowercase",
180
  )
181
  show_chars = gr.Checkbox(
182
  label="Show chars in predictions",
183
  )
 
 
 
184
 
185
  metrics = gr.Textbox(
186
  label="Metrics",
@@ -191,14 +233,14 @@ with demo:
191
  gr.Button("Calculate").click(
192
  inference,
193
  concurrency_limit=concurrency_limit,
194
- inputs=[jsonl_file, clear_punctuation, show_chars],
195
  outputs=metrics,
196
  )
197
 
198
  with gr.Row():
199
  gr.Examples(
200
  label="Choose an example",
201
- inputs=[jsonl_file, clear_punctuation, show_chars],
202
  examples=examples,
203
  )
204
 
 
33
  """.strip()
34
 
35
  examples = [
36
+ ["evaluation_results.jsonl", True, False, False],
37
+ ["evaluation_results_batch.jsonl", True, False, True],
38
  ]
39
 
40
  description_head = f"""
 
101
  return s
102
 
103
 
104
+ def inference(file_name, _clear_punctuation, _show_chars, _batch_mode):
105
  if not file_name:
106
  raise gr.Error("Please paste your JSON file.")
107
 
 
 
108
  df = pl.read_ndjson(file_name)
109
 
110
  inference_seconds = df["inference_total"].sum()
 
111
 
112
+ if _batch_mode:
113
+ if "durations" not in df.columns:
114
+ raise gr.Error(
115
+ "Please use a JSONL file with 'durations' column for batch mode."
116
+ )
117
+
118
+ duration_seconds = 0
119
+ for durations in df["durations"]:
120
+ duration_seconds += durations.sum()
121
+
122
+ rtf = inference_seconds / duration_seconds
123
 
124
+ references_batch = df["references"]
125
+ predictions_batch = df["predictions"]
126
 
127
+ predictions = []
128
+ for prediction in predictions_batch:
129
+ if _clear_punctuation:
130
+ prediction = prediction.map_elements(
131
+ clean_value, return_dtype=pl.String
132
+ )
133
+ predictions.extend(prediction)
134
+ else:
135
+ predictions.extend(prediction)
136
+
137
+ references = []
138
+ for reference in references_batch:
139
+ references.extend(reference)
140
  else:
141
+ duration_seconds = df["duration"].sum()
142
+
143
+ rtf = inference_seconds / duration_seconds
144
+
145
+ references = df["reference"]
146
+
147
+ if _clear_punctuation:
148
+ predictions = df["prediction"].map_elements(
149
+ clean_value, return_dtype=pl.String
150
+ )
151
+ else:
152
+ predictions = df["prediction"]
153
+
154
+ n_predictions = len(predictions)
155
+ n_references = len(references)
156
 
157
  # Evaluate
158
  wer_value = round(wer.compute(predictions=predictions, references=references), 4)
 
165
 
166
  results = []
167
 
168
+ results.append(
169
+ f"- Number of references / predictions: {n_references} / {n_predictions}"
170
+ )
171
+ results.append(f"")
172
  results.append(f"- WER: {wer_value} metric, {round(wer_value * 100, 4)}%")
173
  results.append(f"- CER: {cer_value} metric, {round(cer_value * 100, 4)}%")
174
  results.append("")
 
184
  results.append("")
185
  results.append(f"- RTF: {round(rtf, 4)}")
186
 
187
+ if _show_chars:
188
  all_chars = set()
189
  for pred in predictions:
190
  for c in pred:
 
213
  with gr.Row():
214
  with gr.Column():
215
  jsonl_file = gr.File(label="A JSONL file")
216
+
217
  clear_punctuation = gr.Checkbox(
218
  label="Clear punctuation, some chars and convert to lowercase",
219
  )
220
  show_chars = gr.Checkbox(
221
  label="Show chars in predictions",
222
  )
223
+ batch_mode = gr.Checkbox(
224
+ label="Use batch mode",
225
+ )
226
 
227
  metrics = gr.Textbox(
228
  label="Metrics",
 
233
  gr.Button("Calculate").click(
234
  inference,
235
  concurrency_limit=concurrency_limit,
236
+ inputs=[jsonl_file, clear_punctuation, show_chars, batch_mode],
237
  outputs=metrics,
238
  )
239
 
240
  with gr.Row():
241
  gr.Examples(
242
  label="Choose an example",
243
+ inputs=[jsonl_file, clear_punctuation, show_chars, batch_mode],
244
  examples=examples,
245
  )
246
 
evaluation_results.jsonl CHANGED
The diff for this file is too large to render. See raw diff