# see-asr-outputs / app.py
# Author: Yehor Smoliakov — commit 9581c74 ("Use all metrics")
import sys
from importlib.metadata import version
import evaluate
import polars as pl
import polars_distance as pld
import gradio as gr
# Load evaluators
# Hugging Face `evaluate` metrics — word error rate and character error rate.
# Loaded once at import time and reused by compute_wer()/compute_cer() below.
wer = evaluate.load("wer")
cer = evaluate.load("cer")
# Config
title = "See ASR Outputs"
# https://www.tablesgenerator.com/markdown_tables
authors_table = """
## Authors
Follow them on social networks and **contact** if you need any help or have any questions:
| <img src="https://avatars.githubusercontent.com/u/7875085?v=4" width="100"> **Yehor Smoliakov** |
|-------------------------------------------------------------------------------------------------|
| https://t.me/smlkw in Telegram |
| https://x.com/yehor_smoliakov at X |
| https://github.com/egorsmkv at GitHub |
| https://huggingface.co/Yehor at Hugging Face |
| or use [email protected] |
""".strip()
examples = [
["evaluation_results.jsonl", False, True, False],
["evaluation_results_batch.jsonl", True, False, False],
]
description_head = f"""
# {title}
## Overview
See generated JSONL files made by ASR models as a dataframe. Also, this app calculates WER and CER metrics for each row.
""".strip()
description_foot = f"""
{authors_table}
""".strip()
metrics_value = """
Metrics will appear here.
""".strip()
tech_env = f"""
#### Environment
- Python: {sys.version}
""".strip()
tech_libraries = f"""
#### Libraries
- gradio: {version("gradio")}
- jiwer: {version("jiwer")}
- evaluate: {version("evaluate")}
- pandas: {version("pandas")}
- polars: {version("polars")}
- polars-distance: {version("polars_distance")}
""".strip()
def compute_wer(prediction, reference):
    """Return the word error rate of a single prediction/reference pair, rounded to 4 decimal places."""
    score = wer.compute(predictions=[prediction], references=[reference])
    return round(score, 4)
def compute_cer(prediction, reference):
    """Return the character error rate of a single prediction/reference pair, rounded to 4 decimal places."""
    score = cer.compute(predictions=[prediction], references=[reference])
    return round(score, 4)
def _required_columns(batch_mode):
    """Return the column names the uploaded JSONL must contain for the given mode."""
    if batch_mode:
        return [
            "inference_start",
            "inference_end",
            "inference_total",
            "filenames",
            "durations",
            "references",
            "predictions",
        ]
    return [
        "filename",
        "inference_start",
        "inference_end",
        "inference_total",
        "duration",
        "reference",
        "prediction",
    ]


def _flatten_batch(df):
    """Explode per-batch `predictions`/`references` lists into one row per utterance."""
    predictions = []
    references = []
    for row in df.iter_rows(named=True):
        # Pair each prediction with its reference positionally.
        for prediction, reference in zip(row["predictions"], row["references"]):
            predictions.append(prediction)
            references.append(reference)
    return pl.DataFrame(
        {
            "prediction": predictions,
            "reference": references,
        }
    )


def process_file(file_name, _batch_mode, _calculate_distance, _calculate_metrics):
    """Read an ASR-results JSONL file and return a Polars dataframe for display.

    Parameters:
        file_name: path of the uploaded JSONL file (falsy when nothing uploaded).
        _batch_mode: True when each row holds lists (`predictions`/`references`)
            instead of scalar `prediction`/`reference` columns.
        _calculate_distance: add a Levenshtein `distance` column (first position).
        _calculate_metrics: add per-row `wer` and `cer` columns.

    Returns:
        A Polars DataFrame with `prediction`/`reference` columns plus any
        requested metric/distance columns.

    Raises:
        gr.Error: when no file was uploaded or required columns are missing.
    """
    if not file_name:
        # Original message said "paste your JSON file" — the input is an
        # uploaded JSONL file, so say that.
        raise gr.Error("Please upload a JSONL file.")

    df = pl.read_ndjson(file_name)

    required = _required_columns(_batch_mode)
    if not all(col in df.columns for col in required):
        raise gr.Error(
            f"Please provide a JSONL file with the following columns: {required}"
        )

    # Timing columns and filenames are not useful for inspection — drop them.
    if _batch_mode:
        df = df.drop(
            ["inference_total", "inference_start", "inference_end", "filenames"]
        )
        # Batch rows carry lists; flatten to one utterance per row.
        df = _flatten_batch(df)
    else:
        df = df.drop(
            ["inference_total", "inference_start", "inference_end", "filename"]
        )

    if _calculate_metrics:
        # Pandas is needed for applying row-wise Python functions.
        df_pd = df.to_pandas()
        df_pd["wer"] = df_pd.apply(
            lambda row: compute_wer(row["prediction"], row["reference"]),
            axis=1,
        )
        df_pd["cer"] = df_pd.apply(
            lambda row: compute_cer(row["prediction"], row["reference"]),
            axis=1,
        )
        fields = [
            "wer",
            "cer",
            "prediction",
            "reference",
        ]
        df = pl.DataFrame(df_pd)
    else:
        fields = [
            "prediction",
            "reference",
        ]

    # Also drops any extra columns (e.g. `duration`/`durations`) from the view.
    df = df.select(fields)

    if _calculate_distance:
        df = df.with_columns(
            pld.col("prediction").dist_str.levenshtein("reference").alias("distance")
        )
        # Put distance in the first position.
        fields = [
            "distance",
            *fields,
        ]
        df = df.select(fields)

    return df
# Build the app shell; all components are attached inside the `with demo:`
# context below. Analytics are disabled explicitly.
demo = gr.Blocks(
    title=title,
    analytics_enabled=False,
    theme=gr.themes.Base(),
)
# UI layout: statement order inside this context defines the on-page layout.
with demo:
    gr.Markdown(description_head)
    gr.Markdown("## Usage")
    # Output dataframe: searchable, with row numbers and the first column pinned.
    with gr.Row():
        df = gr.DataFrame(
            label="Dataframe",
            show_search="search",
            show_row_numbers=True,
            pinned_columns=1,
        )
    # Input controls: file upload plus the three toggles passed to process_file.
    with gr.Row():
        with gr.Column():
            jsonl_file = gr.File(label="A JSONL file")
            batch_mode = gr.Checkbox(
                label="Use batch mode",
            )
            calculate_distance = gr.Checkbox(
                label="Calculate Levenshtein distance",
                value=False,
            )
            calculate_metrics = gr.Checkbox(
                label="Calculate WER/CER metrics",
                value=False,
            )
    # Clicking "Show" runs process_file and renders its result into `df`.
    gr.Button("Show").click(
        process_file,
        inputs=[jsonl_file, batch_mode, calculate_distance, calculate_metrics],
        outputs=df,
    )
    # Clickable presets that fill in the inputs (see `examples` above).
    with gr.Row():
        gr.Examples(
            label="Choose an example",
            inputs=[jsonl_file, batch_mode, calculate_distance, calculate_metrics],
            examples=examples,
        )
    # Footer: authors table and environment/library versions.
    gr.Markdown(description_foot)
    gr.Markdown("### Gradio app uses:")
    gr.Markdown(tech_env)
    gr.Markdown(tech_libraries)
# Script entry point: enable the request queue, then start the local server.
if __name__ == "__main__":
    demo.queue()
    demo.launch()