# Gradio Space: "See ASR Outputs"
import sys
from importlib.metadata import version

import evaluate
import gradio as gr
import polars as pl
import polars_distance as pld

# Load evaluators once at import time so every request reuses them.
wer = evaluate.load("wer")
cer = evaluate.load("cer")
# Config
title = "See ASR Outputs"

# Table generated with https://www.tablesgenerator.com/markdown_tables
authors_table = """
## Authors

Follow them on social networks and **contact** if you need any help or have any questions:

| <img src="https://avatars.githubusercontent.com/u/7875085?v=4" width="100"> **Yehor Smoliakov** |
|-------------------------------------------------------------------------------------------------|
| https://t.me/smlkw in Telegram |
| https://x.com/yehor_smoliakov at X |
| https://github.com/egorsmkv at GitHub |
| https://huggingface.co/Yehor at Hugging Face |
| or use [email protected] |
""".strip()

# Each example row maps to the UI inputs:
# [jsonl_file, batch_mode, calculate_distance, calculate_metrics]
examples = [
    ["evaluation_results.jsonl", False, True, False],
    ["evaluation_results_batch.jsonl", True, False, False],
]

description_head = f"""
# {title}

## Overview

See generated JSONL files made by ASR models as a dataframe. Also, this app calculates WER and CER metrics for each row.
""".strip()

description_foot = f"""
{authors_table}
""".strip()

# NOTE(review): metrics_value is not referenced anywhere in this module —
# confirm it is unused before removing.
metrics_value = """
Metrics will appear here.
""".strip()

tech_env = f"""
#### Environment

- Python: {sys.version}
""".strip()

tech_libraries = f"""
#### Libraries

- gradio: {version("gradio")}
- jiwer: {version("jiwer")}
- evaluate: {version("evaluate")}
- pandas: {version("pandas")}
- polars: {version("polars")}
- polars-distance: {version("polars_distance")}
""".strip()
def compute_wer(prediction: str, reference: str) -> float:
    """Return the word error rate of one prediction vs. its reference, rounded to 4 decimals."""
    return round(wer.compute(predictions=[prediction], references=[reference]), 4)
def compute_cer(prediction: str, reference: str) -> float:
    """Return the character error rate of one prediction vs. its reference, rounded to 4 decimals."""
    return round(cer.compute(predictions=[prediction], references=[reference]), 4)
def process_file(file_name, _batch_mode, _calculate_distance, _calculate_metrics):
    """Load an ASR evaluation JSONL file and return a polars DataFrame for display.

    Parameters:
        file_name: path of the uploaded JSONL file.
        _batch_mode: if True, each row holds lists (``predictions``/``references``)
            that are flattened to one utterance per row.
        _calculate_distance: if True, prepend a Levenshtein ``distance`` column.
        _calculate_metrics: if True, add per-row ``wer`` and ``cer`` columns.

    Raises:
        gr.Error: when no file is given or a required column is missing.
    """
    if not file_name:
        raise gr.Error("Please upload a JSONL file.")

    df = pl.read_ndjson(file_name)

    # Expected schema for one-utterance-per-row files.
    required_columns = [
        "filename",
        "inference_start",
        "inference_end",
        "inference_total",
        "duration",
        "reference",
        "prediction",
    ]
    # Expected schema for batched files (lists per row).
    required_columns_batch = [
        "inference_start",
        "inference_end",
        "inference_total",
        "filenames",
        "durations",
        "references",
        "predictions",
    ]

    expected = required_columns_batch if _batch_mode else required_columns
    if not all(col in df.columns for col in expected):
        raise gr.Error(
            f"Please provide a JSONL file with the following columns: {expected}"
        )

    # Timing and filename columns are not useful for display — drop them.
    if _batch_mode:
        df = df.drop(
            ["inference_total", "inference_start", "inference_end", "filenames"]
        )
    else:
        df = df.drop(
            ["inference_total", "inference_start", "inference_end", "filename"]
        )

    if _batch_mode:
        # Flatten the per-batch lists into one (prediction, reference) pair per row.
        predictions = []
        references = []
        for row in df.iter_rows(named=True):
            for idx, prediction in enumerate(row["predictions"]):
                predictions.append(prediction)
                references.append(row["references"][idx])
        df = pl.DataFrame(
            {
                "prediction": predictions,
                "reference": references,
            }
        )

    if _calculate_metrics:
        # Pandas is needed for applying row-wise functions.
        df_pd = df.to_pandas()
        df_pd["wer"] = df_pd.apply(
            lambda row: compute_wer(row["prediction"], row["reference"]),
            axis=1,
        )
        df_pd["cer"] = df_pd.apply(
            lambda row: compute_cer(row["prediction"], row["reference"]),
            axis=1,
        )
        fields = ["wer", "cer", "prediction", "reference"]
        df = pl.DataFrame(df_pd)
    else:
        fields = ["prediction", "reference"]

    df = df.select(fields)

    if _calculate_distance:
        df = df.with_columns(
            pld.col("prediction").dist_str.levenshtein("reference").alias("distance")
        )
        # Show the distance column first.
        fields = ["distance", *fields]
        df = df.select(fields)

    return df
# Build the Gradio UI.
demo = gr.Blocks(
    title=title,
    analytics_enabled=False,
    theme=gr.themes.Base(),
)

with demo:
    gr.Markdown(description_head)
    gr.Markdown("## Usage")

    with gr.Row():
        # Output table; first column stays pinned while scrolling.
        df = gr.DataFrame(
            label="Dataframe",
            show_search="search",
            show_row_numbers=True,
            pinned_columns=1,
        )

    with gr.Row():
        with gr.Column():
            jsonl_file = gr.File(label="A JSONL file")
            batch_mode = gr.Checkbox(
                label="Use batch mode",
            )
            calculate_distance = gr.Checkbox(
                label="Calculate Levenshtein distance",
                value=False,
            )
            calculate_metrics = gr.Checkbox(
                label="Calculate WER/CER metrics",
                value=False,
            )

            gr.Button("Show").click(
                process_file,
                inputs=[jsonl_file, batch_mode, calculate_distance, calculate_metrics],
                outputs=df,
            )

    with gr.Row():
        gr.Examples(
            label="Choose an example",
            inputs=[jsonl_file, batch_mode, calculate_distance, calculate_metrics],
            examples=examples,
        )

    gr.Markdown(description_foot)
    gr.Markdown("### Gradio app uses:")
    gr.Markdown(tech_env)
    gr.Markdown(tech_libraries)

if __name__ == "__main__":
    demo.queue()
    demo.launch()