# NOTE(review): the three lines below were Hugging Face Spaces status-banner
# residue captured by the page scrape ("Spaces: Sleeping"); kept as a comment
# so the module parses.
import sys
from importlib.metadata import version

import evaluate
import gradio as gr
import polars as pl
from joblib import Parallel, delayed
# --- Metric evaluators ------------------------------------------------------
# jiwer-backed WER/CER implementations, loaded once at module import.
wer = evaluate.load("wer")
cer = evaluate.load("cer")

# --- App configuration ------------------------------------------------------
title = "See ASR Outputs"

# Table markup generated with https://www.tablesgenerator.com/markdown_tables
authors_table = """
## Authors
Follow them on social networks and **contact** if you need any help or have any questions:
| <img src="https://avatars.githubusercontent.com/u/7875085?v=4" width="100"> **Yehor Smoliakov** |
|-------------------------------------------------------------------------------------------------|
| https://t.me/smlkw in Telegram |
| https://x.com/yehor_smoliakov at X |
| https://github.com/egorsmkv at GitHub |
| https://huggingface.co/Yehor at Hugging Face |
| or use [email protected] |
""".strip()

# Preset inputs for the Examples widget: [jsonl path, batch_mode, calculate_metrics].
examples = [
    ["evaluation_results.jsonl", False, True],
    ["evaluation_results_batch.jsonl", True, True],
]

description_head = f"""
# {title}
## Overview
See generated JSONL files made by ASR models as a dataframe. Also, this app calculates WER and CER metrics for each row.
""".strip()

description_foot = f"""
{authors_table}
""".strip()

# Placeholder text shown before any metrics are computed.
metrics_value = """
Metrics will appear here.
""".strip()

# Runtime environment details rendered at the bottom of the page.
tech_env = f"""
#### Environment
- Python: {sys.version}
""".strip()

tech_libraries = f"""
#### Libraries
- gradio: {version("gradio")}
- jiwer: {version("jiwer")}
- evaluate: {version("evaluate")}
- polars: {version("polars")}
""".strip()
def compute_wer(prediction, reference):
    """Word error rate for one prediction/reference pair, rounded to 4 places."""
    score = wer.compute(predictions=[prediction], references=[reference])
    return round(score, 4)
def compute_cer(prediction, reference):
    """Character error rate for one prediction/reference pair, rounded to 4 places."""
    score = cer.compute(predictions=[prediction], references=[reference])
    return round(score, 4)
def compute_batch_wer(predictions, references):
    """Aggregate word error rate over parallel lists, rounded to 4 places."""
    score = wer.compute(predictions=predictions, references=references)
    return round(score, 4)
def compute_batch_cer(predictions, references):
    """Aggregate character error rate over parallel lists, rounded to 4 places."""
    score = cer.compute(predictions=predictions, references=references)
    return round(score, 4)
def _require_columns(df, required):
    """Raise a user-facing gr.Error unless *df* contains every column in *required*."""
    if not all(col in df.columns for col in required):
        raise gr.Error(
            f"Please provide a JSONL file with the following columns: {required}"
        )


def inference(file_name, _batch_mode, _calculate_metrics):
    """Read an ASR evaluation JSONL file and build a polars DataFrame for display.

    Args:
        file_name: Path to the uploaded JSONL file.
        _batch_mode: If True, rows are batched (``filenames``/``durations``/
            ``references``/``predictions`` list columns); otherwise each row is
            a single utterance with scalar columns.
        _calculate_metrics: If True, compute per-row ``wer``/``cer`` columns
            (parallelized with joblib).

    Returns:
        A ``pl.DataFrame`` restricted to the display columns, in display order.

    Raises:
        gr.Error: If no file was provided or required columns are missing.
    """
    if not file_name:
        # Fixed message: the input is an uploaded JSONL file, not pasted JSON.
        raise gr.Error("Please upload your JSONL file.")

    df = pl.read_ndjson(file_name)

    # The two modes differ only in column names and metric helpers; resolve
    # them once here instead of duplicating the whole pipeline per branch.
    if _batch_mode:
        required = [
            "inference_start",
            "inference_end",
            "inference_total",
            "filenames",
            "durations",
            "references",
            "predictions",
        ]
        name_col, duration_col = "filenames", "durations"
        pred_col, ref_col = "predictions", "references"
        wer_fn, cer_fn = compute_batch_wer, compute_batch_cer
    else:
        required = [
            "filename",
            "inference_start",
            "inference_end",
            "inference_total",
            "duration",
            "reference",
            "prediction",
        ]
        name_col, duration_col = "filename", "duration"
        pred_col, ref_col = "prediction", "reference"
        wer_fn, cer_fn = compute_wer, compute_cer

    _require_columns(df, required)

    # Timestamps and file identifiers are not shown in the dataframe.
    df = df.drop(["inference_start", "inference_end", name_col])

    # Round "inference_total" to 2 decimal places and expose it as "elapsed".
    df = df.with_columns(pl.col("inference_total").round(2).alias("elapsed"))
    df = df.drop(["inference_total"])

    fields = ["elapsed", duration_col, pred_col, ref_col]
    if _calculate_metrics:
        # One parallel pass per metric; each row yields one rounded score.
        wer_values = Parallel(n_jobs=-1)(
            delayed(wer_fn)(row[pred_col], row[ref_col])
            for row in df.iter_rows(named=True)
        )
        cer_values = Parallel(n_jobs=-1)(
            delayed(cer_fn)(row[pred_col], row[ref_col])
            for row in df.iter_rows(named=True)
        )
        df.insert_column(2, pl.Series("wer", wer_values))
        df.insert_column(3, pl.Series("cer", cer_values))
        fields = ["elapsed", duration_col, "wer", "cer", pred_col, ref_col]

    return df.select(fields)
# --- UI ---------------------------------------------------------------------
demo = gr.Blocks(
    title=title,
    analytics_enabled=False,
    theme=gr.themes.Base(),
)

with demo:
    gr.Markdown(description_head)
    gr.Markdown("## Usage")

    # Results table, populated by the "Show" button below.
    with gr.Row():
        df = gr.DataFrame(
            label="Dataframe",
            show_search="search",
            show_copy_button=True,
            show_row_numbers=True,
            pinned_columns=1,
        )

    # Input controls: file upload plus two mode toggles.
    with gr.Row():
        with gr.Column():
            jsonl_file = gr.File(label="A JSONL file")
            batch_mode = gr.Checkbox(label="Use batch mode")
            calculate_metrics = gr.Checkbox(
                label="Calculate WER/CER metrics",
                value=False,
            )

            show_button = gr.Button("Show")
            show_button.click(
                inference,
                inputs=[jsonl_file, batch_mode, calculate_metrics],
                outputs=df,
            )

    with gr.Row():
        gr.Examples(
            label="Choose an example",
            inputs=[jsonl_file, batch_mode, calculate_metrics],
            examples=examples,
        )

    gr.Markdown(description_foot)
    gr.Markdown("### Gradio app uses:")
    gr.Markdown(tech_env)
    gr.Markdown(tech_libraries)

if __name__ == "__main__":
    demo.queue()
    demo.launch()