# see-asr-outputs / app.py
# Author: Yehor — commit "Add search" (9484ee3)
import sys
from importlib.metadata import version
import evaluate
import polars as pl
import gradio as gr
from joblib import Parallel, delayed
# Load evaluators
# Hugging Face `evaluate` metric objects; loading downloads/caches the metric
# scripts on first use, so this happens once at import time rather than per request.
wer = evaluate.load("wer")
cer = evaluate.load("cer")
# Config
# UI strings below are rendered as Markdown by Gradio; keep their content intact.
title = "See ASR Outputs"
# https://www.tablesgenerator.com/markdown_tables
authors_table = """
## Authors
Follow them on social networks and **contact** if you need any help or have any questions:
| <img src="https://avatars.githubusercontent.com/u/7875085?v=4" width="100"> **Yehor Smoliakov** |
|-------------------------------------------------------------------------------------------------|
| https://t.me/smlkw in Telegram |
| https://x.com/yehor_smoliakov at X |
| https://github.com/egorsmkv at GitHub |
| https://huggingface.co/Yehor at Hugging Face |
| or use [email protected] |
""".strip()
# Example rows for gr.Examples: [jsonl file path, batch_mode, calculate_metrics].
examples = [
    ["evaluation_results.jsonl", False, True],
    ["evaluation_results_batch.jsonl", True, True],
]
description_head = f"""
# {title}
## Overview
See generated JSONL files made by ASR models as a dataframe. Also, this app calculates WER and CER metrics for each row.
""".strip()
description_foot = f"""
{authors_table}
""".strip()
# Placeholder text shown before any metrics are computed.
metrics_value = """
Metrics will appear here.
""".strip()
# Environment / library versions shown at the bottom of the page.
tech_env = f"""
#### Environment
- Python: {sys.version}
""".strip()
tech_libraries = f"""
#### Libraries
- gradio: {version("gradio")}
- jiwer: {version("jiwer")}
- evaluate: {version("evaluate")}
- polars: {version("polars")}
""".strip()
def compute_wer(prediction, reference):
    """Word error rate for a single prediction/reference pair, rounded to 4 dp."""
    score = wer.compute(predictions=[prediction], references=[reference])
    return round(score, 4)
def compute_cer(prediction, reference):
    """Character error rate for a single prediction/reference pair, rounded to 4 dp."""
    score = cer.compute(predictions=[prediction], references=[reference])
    return round(score, 4)
def compute_batch_wer(predictions, references):
    """Aggregate word error rate over lists of predictions/references, rounded to 4 dp."""
    score = wer.compute(predictions=predictions, references=references)
    return round(score, 4)
def compute_batch_cer(predictions, references):
    """Aggregate character error rate over lists of predictions/references, rounded to 4 dp."""
    score = cer.compute(predictions=predictions, references=references)
    return round(score, 4)
def _require_columns(df, columns):
    """Raise gr.Error unless every column in `columns` is present in `df`."""
    if not all(col in df.columns for col in columns):
        raise gr.Error(
            f"Please provide a JSONL file with the following columns: {columns}"
        )


def inference(file_name, _batch_mode, _calculate_metrics):
    """Load an ASR results JSONL file and return a display-ready dataframe.

    Args:
        file_name: path to the uploaded JSONL file (one JSON object per line).
        _batch_mode: if True, expect batch-style columns (plural names:
            "predictions"/"references"/"durations"); otherwise single-utterance
            columns ("prediction"/"reference"/"duration").
        _calculate_metrics: if True, compute a WER and CER value per row,
            parallelized across rows with joblib.

    Returns:
        A polars DataFrame with columns ordered for display:
        elapsed, duration(s)[, wer, cer], prediction(s), reference(s).

    Raises:
        gr.Error: when no file is given or required columns are missing.
    """
    if not file_name:
        # Fixed message: the input is an uploaded JSONL file, not pasted JSON.
        raise gr.Error("Please upload your JSONL file.")

    df = pl.read_ndjson(file_name)

    # Expected schemas for the two evaluation-output formats.
    required_columns = [
        "filename",
        "inference_start",
        "inference_end",
        "inference_total",
        "duration",
        "reference",
        "prediction",
    ]
    required_columns_batch = [
        "inference_start",
        "inference_end",
        "inference_total",
        "filenames",
        "durations",
        "references",
        "predictions",
    ]

    _require_columns(df, required_columns_batch if _batch_mode else required_columns)

    # Drop bookkeeping columns not shown in the UI (timestamps, file names).
    if _batch_mode:
        df = df.drop(["inference_start", "inference_end", "filenames"])
    else:
        df = df.drop(["inference_start", "inference_end", "filename"])

    # Round "inference_total" to 2 decimal places and expose it as "elapsed".
    df = df.with_columns(pl.col("inference_total").round(2).alias("elapsed"))
    df = df.drop(["inference_total"])

    # Column names and per-row metric functions differ only by mode.
    if _batch_mode:
        pred_col, ref_col, dur_col = "predictions", "references", "durations"
        wer_fn, cer_fn = compute_batch_wer, compute_batch_cer
    else:
        pred_col, ref_col, dur_col = "prediction", "reference", "duration"
        wer_fn, cer_fn = compute_wer, compute_cer

    if _calculate_metrics:
        # Materialize rows once; both metric passes iterate the same data.
        rows = list(df.iter_rows(named=True))
        wer_values = Parallel(n_jobs=-1)(
            delayed(wer_fn)(row[pred_col], row[ref_col]) for row in rows
        )
        cer_values = Parallel(n_jobs=-1)(
            delayed(cer_fn)(row[pred_col], row[ref_col]) for row in rows
        )
        df.insert_column(2, pl.Series("wer", wer_values))
        df.insert_column(3, pl.Series("cer", cer_values))
        fields = ["elapsed", dur_col, "wer", "cer", pred_col, ref_col]
    else:
        fields = ["elapsed", dur_col, pred_col, ref_col]

    return df.select(fields)
# Build the Gradio UI. Layout: results dataframe on top, controls below,
# examples, then footer/environment info.
demo = gr.Blocks(
    title=title,
    analytics_enabled=False,
    theme=gr.themes.Base(),
)
with demo:
    gr.Markdown(description_head)
    gr.Markdown("## Usage")
    with gr.Row():
        # Output table; pinned first column and row numbers aid scanning long files.
        df = gr.DataFrame(
            label="Dataframe",
            show_search="search",
            show_copy_button=True,
            show_row_numbers=True,
            pinned_columns=1,
        )
    with gr.Row():
        with gr.Column():
            # Inputs: the JSONL file plus the two mode toggles passed to inference().
            jsonl_file = gr.File(label="A JSONL file")
            batch_mode = gr.Checkbox(
                label="Use batch mode",
            )
            calculate_metrics = gr.Checkbox(
                label="Calculate WER/CER metrics",
                value=False,
            )
            gr.Button("Show").click(
                inference,
                inputs=[jsonl_file, batch_mode, calculate_metrics],
                outputs=df,
            )
    with gr.Row():
        # Clicking an example only populates the inputs; the user still presses "Show".
        gr.Examples(
            label="Choose an example",
            inputs=[jsonl_file, batch_mode, calculate_metrics],
            examples=examples,
        )
    gr.Markdown(description_foot)
    gr.Markdown("### Gradio app uses:")
    gr.Markdown(tech_env)
    gr.Markdown(tech_libraries)
if __name__ == "__main__":
    # queue() enables request queuing so long metric computations don't block the UI.
    demo.queue()
    demo.launch()