# NOTE(review): the lines "Spaces: / Running / Running" are Hugging Face
# Spaces page-status chrome captured by the scrape, not application code.
# Kept here as a comment so the file stays valid Python.
import sys
import re
from importlib.metadata import version

import evaluate
import polars as pl
import gradio as gr
from natsort import natsorted

# Load evaluators once at import time (evaluate caches the metric scripts).
wer = evaluate.load("wer")
cer = evaluate.load("cer")

# Config
concurrency_limit = 5

title = "Evaluate ASR Outputs"

# Table built with https://www.tablesgenerator.com/markdown_tables
authors_table = """
## Authors

Follow them on social networks and **contact** if you need any help or have any questions:

| <img src="https://avatars.githubusercontent.com/u/7875085?v=4" width="100"> **Yehor Smoliakov** |
|-------------------------------------------------------------------------------------------------|
| https://t.me/smlkw in Telegram |
| https://x.com/yehor_smoliakov at X |
| https://github.com/egorsmkv at GitHub |
| https://huggingface.co/Yehor at Hugging Face |
| or use [email protected] |
""".strip()

# [file, clear_punctuation, show_chars, batch_mode] rows for gr.Examples.
examples = [
    ["evaluation_results.jsonl", True, False, False],
    ["evaluation_results_batch.jsonl", True, False, True],
]

description_head = f"""
# {title}

## Overview

Upload a JSONL file generated by the ASR model.
""".strip()

description_foot = f"""
{authors_table}
""".strip()

# Placeholder text shown in the metrics textbox before the first run.
metrics_value = """
Metrics will appear here.
""".strip()

tech_env = f"""
#### Environment

- Python: {sys.version}
""".strip()

tech_libraries = f"""
#### Libraries

- evaluate: {version("evaluate")}
- gradio: {version("gradio")}
- jiwer: {version("jiwer")}
- polars: {version("polars")}
""".strip()
# Characters that are replaced with a space during normalization.
_SPACE_CHARS = ':,.?!–«»—…/\\()"'

# One-pass translation table: typographic apostrophe -> ASCII apostrophe,
# combining acute accent (U+0301) -> removed, punctuation -> space.
_CLEAN_TABLE = str.maketrans(
    {**{c: " " for c in _SPACE_CHARS}, "’": "'", "\u0301": None}
)

# Compiled once: collapses runs of spaces produced by the table above.
_MULTISPACE_RE = re.compile(r" +")


def clean_value(x):
    """Normalize a transcription string for WER/CER computation.

    Lowercases the text, maps the typographic apostrophe to ASCII, strips
    the combining acute accent (stress mark), replaces punctuation and
    quote/bracket characters with spaces, and collapses repeated spaces.

    Args:
        x: Raw transcription text.

    Returns:
        The cleaned, lowercased, space-normalized string.
    """
    s = x.strip().lower().translate(_CLEAN_TABLE)
    return _MULTISPACE_RE.sub(" ", s).strip()
def inference(file_name, _clear_punctuation, _show_chars, _batch_mode):
    """Compute WER/CER and timing statistics for an uploaded JSONL file.

    Args:
        file_name: Path of the uploaded JSONL file.
        _clear_punctuation: If True, normalize predictions with clean_value.
        _show_chars: If True, append the set of distinct characters found
            in the predictions to the report.
        _batch_mode: If True, expect batched rows (list-valued columns
            ``filenames``/``durations``/``references``/``predictions``)
            instead of one utterance per row.

    Returns:
        A human-readable multi-line metrics report.

    Raises:
        gr.Error: If no file was provided, required columns are missing,
            or the total audio duration is zero.
    """
    if not file_name:
        raise gr.Error("Please paste your JSON file.")

    df = pl.read_ndjson(file_name)

    required_columns = [
        "filename",
        "inference_start",
        "inference_end",
        "inference_total",
        "duration",
        "reference",
        "prediction",
    ]
    required_columns_batch = [
        "inference_start",
        "inference_end",
        "inference_total",
        "filenames",
        "durations",
        "references",
        "predictions",
    ]

    # Validate the schema before touching any column, so users get a clear
    # gr.Error instead of a raw polars ColumnNotFoundError.
    expected = required_columns_batch if _batch_mode else required_columns
    if not all(col in df.columns for col in expected):
        raise gr.Error(
            f"Please provide a JSONL file with the following columns: {expected}"
        )

    inference_seconds = df["inference_total"].sum()

    if _batch_mode:
        # Each row holds per-utterance lists; flatten them into flat lists.
        duration_seconds = sum(durations.sum() for durations in df["durations"])

        predictions = []
        for prediction in df["predictions"]:
            if _clear_punctuation:
                prediction = prediction.map_elements(
                    clean_value, return_dtype=pl.String
                )
            predictions.extend(prediction)

        references = []
        for reference in df["references"]:
            references.extend(reference)
    else:
        duration_seconds = df["duration"].sum()

        references = df["reference"]
        if _clear_punctuation:
            predictions = df["prediction"].map_elements(
                clean_value, return_dtype=pl.String
            )
        else:
            predictions = df["prediction"]

    n_predictions = len(predictions)
    n_references = len(references)

    # Evaluate
    wer_value = round(wer.compute(predictions=predictions, references=references), 4)
    cer_value = round(cer.compute(predictions=predictions, references=references), 4)

    inference_time = inference_seconds
    audio_duration = duration_seconds

    # Real-time factor; guard against empty/zero-duration uploads.
    if not audio_duration:
        raise gr.Error("Total audio duration is zero, cannot compute the RTF.")
    rtf = inference_time / audio_duration

    results = [
        f"- Number of references / predictions: {n_references} / {n_predictions}",
        "",
        f"- WER: {wer_value} metric, {round(wer_value * 100, 4)}%",
        f"- CER: {cer_value} metric, {round(cer_value * 100, 4)}%",
        "",
        f"- Accuracy on words: {round(100 - 100 * wer_value, 4)}%",
        f"- Accuracy on chars: {round(100 - 100 * cer_value, 4)}%",
        "",
        f"- Inference time: {round(inference_time, 4)} seconds, {round(inference_time / 60, 4)} mins, {round(inference_time / 60 / 60, 4)} hours",
        f"- Audio duration: {round(audio_duration, 4)} seconds, {round(audio_duration / 60 / 60, 4)} hours",
        "",
        f"- RTF: {round(rtf, 4)}",
    ]

    if _show_chars:
        # Collect the distinct characters across all predictions.
        all_chars = set()
        for pred in predictions:
            all_chars.update(pred)
        sorted_chars = natsorted(all_chars)

        results.append("")
        results.append("Chars in predictions:")
        results.append(f"{sorted_chars}")

    return "\n".join(results)
demo = gr.Blocks(
    title=title,
    analytics_enabled=False,
    theme=gr.themes.Base(),
)

with demo:
    gr.Markdown(description_head)

    gr.Markdown("## Usage")

    # NOTE(review): leading indentation was lost in the captured source; the
    # Row/Column nesting below (inputs column + metrics box side by side) is
    # reconstructed from the typical Gradio layout — confirm against the
    # rendered app.
    with gr.Row():
        with gr.Column():
            jsonl_file = gr.File(label="A JSONL file")
            clear_punctuation = gr.Checkbox(
                label="Clear punctuation, some chars and convert to lowercase",
            )
            show_chars = gr.Checkbox(
                label="Show chars in predictions",
            )
            batch_mode = gr.Checkbox(
                label="Use batch mode",
            )

        metrics = gr.Textbox(
            label="Metrics",
            placeholder=metrics_value,
            show_copy_button=True,
        )

    gr.Button("Calculate").click(
        inference,
        concurrency_limit=concurrency_limit,
        inputs=[jsonl_file, clear_punctuation, show_chars, batch_mode],
        outputs=metrics,
    )

    with gr.Row():
        gr.Examples(
            label="Choose an example",
            inputs=[jsonl_file, clear_punctuation, show_chars, batch_mode],
            examples=examples,
        )

    gr.Markdown(description_foot)

    gr.Markdown("### Gradio app uses:")
    gr.Markdown(tech_env)
    gr.Markdown(tech_libraries)

if __name__ == "__main__":
    demo.queue()
    demo.launch()