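# NOTE: the following lines are residue from the web file viewer this source
# was copied from; they are not part of the program.
# Yehor's picture
# Init
# 5545d25
# raw
# history blame
# 4.59 kB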
import sys
from importlib.metadata import version
import evaluate
import polars as pl
import gradio as gr
# Load evaluators
# The "wer" and "cer" metric modules are loaded once at import time and reused
# by every call to inference().
wer = evaluate.load("wer")
cer = evaluate.load("cer")
# Config
# Passed as `concurrency_limit` to the "Calculate" button's click handler.
concurrency_limit = 5
# Used both as the gr.Blocks title and as the top-level heading of the page.
title = "Evaluate ASR Outputs"
# Markdown table with author contacts, generated with:
# https://www.tablesgenerator.com/markdown_tables
# NOTE(review): the last table row contains the literal text "[email protected]",
# which looks like an obfuscation placeholder rather than a real address —
# confirm and restore the intended e-mail.
authors_table = """
## Authors
Follow them on social networks and **contact** if you need any help or have any questions:
| <img src="https://avatars.githubusercontent.com/u/7875085?v=4" width="100"> **Yehor Smoliakov** |
|-------------------------------------------------------------------------------------------------|
| https://t.me/smlkw in Telegram |
| https://x.com/yehor_smoliakov at X |
| https://github.com/egorsmkv at GitHub |
| https://huggingface.co/Yehor at Hugging Face |
| or use [email protected] |
""".strip()
# Rows offered by gr.Examples; each row is [value for the file input,
# value for the "clear punctuation" checkbox].
examples = [
["evaluation_results.jsonl", True],
]
# Markdown rendered at the top of the page.
description_head = f"""
# {title}
## Overview
Upload a JSONL file generated by the ASR model.
""".strip()
# Markdown rendered below the controls; currently just the authors table.
description_foot = f"""
{authors_table}
""".strip()
# Placeholder text shown in the "Metrics" textbox.
metrics_value = """
Metrics will appear here.
""".strip()
# Markdown describing the running interpreter, rendered at the bottom of the page.
tech_env = f"""
#### Environment
- Python: {sys.version}
""".strip()
# Markdown listing installed package versions (looked up via
# importlib.metadata.version at import time).
tech_libraries = f"""
#### Libraries
- evaluate: {version('evaluate')}
- gradio: {version('gradio')}
- jiwer: {version('jiwer')}
- polars: {version('polars')}
""".strip()
# Single-pass translation table: maps the typographic apostrophe to a plain
# one and deletes the listed punctuation characters.
_CLEAN_TABLE = str.maketrans("’", "'", ",.?!–«»")


def clean_value(x):
    """Return *x* trimmed, lower-cased and stripped of selected punctuation.

    The text is trimmed of surrounding whitespace and lower-cased first; only
    then are the characters ``, . ? ! – « »`` removed and ``’`` replaced by
    ``'``. Whitespace exposed by removing a character is therefore kept.
    """
    normalized = x.strip().lower()
    return normalized.translate(_CLEAN_TABLE)
def inference(file_name, clear_punctuation, progress=gr.Progress()):
    """Compute WER, CER and the real-time factor for an uploaded JSONL file.

    The file is read with ``pl.read_ndjson`` and must provide the fields
    ``inference_total``, ``duration``, ``reference`` and ``prediction``.

    Args:
        file_name: Value supplied by the file input; handed to
            ``pl.read_ndjson`` unchanged.
        clear_punctuation: When truthy, every prediction is passed through
            ``clean_value`` before scoring.
        progress: Gradio progress tracker used to show a status message.

    Returns:
        A newline-joined text report with WER/CER, word/char accuracy,
        total inference time, total audio duration and RTF.

    Raises:
        gr.Error: If no file was provided, a required field is missing, or
            the total audio duration is zero (RTF would be undefined).
    """
    if not file_name:
        raise gr.Error("Please upload your JSONL file.")

    progress(0, desc="Calculating...")

    df = pl.read_ndjson(file_name)

    # Fail with a readable message instead of an opaque column-lookup error.
    required_fields = ("inference_total", "duration", "reference", "prediction")
    missing_fields = [name for name in required_fields if name not in df.columns]
    if missing_fields:
        raise gr.Error(
            f"The file is missing required fields: {', '.join(missing_fields)}"
        )

    inference_time = df['inference_total'].sum()
    audio_duration = df['duration'].sum()

    # RTF divides by the total duration; guard against an empty/zero total.
    if not audio_duration:
        raise gr.Error("Total audio duration is zero, so RTF cannot be computed.")

    rtf = inference_time / audio_duration

    references = df['reference']

    # NOTE(review): only predictions are normalized; references are scored
    # as-is — presumably they are already normalized upstream, confirm.
    if clear_punctuation:
        predictions = df['prediction'].map_elements(clean_value)
    else:
        predictions = df['prediction']

    # Evaluate
    wer_value = round(
        wer.compute(predictions=predictions, references=references), 4
    )
    cer_value = round(
        cer.compute(predictions=predictions, references=references), 4
    )

    # Accuracy values are rounded so float noise such as
    # 87.66000000000001 does not leak into the report.
    word_accuracy = round(100 - 100 * wer_value, 4)
    char_accuracy = round(100 - 100 * cer_value, 4)

    results = [
        "Metrics using `evaluate` library:",
        '',
        f"- WER: {wer_value} metric, {round(wer_value*100, 4)}%",
        f"- CER: {cer_value} metric, {round(cer_value*100, 4)}%",
        '',
        f"- Accuracy on words: {word_accuracy}%",
        f"- Accuracy on chars: {char_accuracy}%",
        '',
        f"- Inference time: {round(inference_time, 4)} seconds, {round(inference_time/60, 4)} mins, {round(inference_time/60/60, 4)} hours",
        f"- Audio duration: {round(audio_duration, 4)} seconds, {round(audio_duration/60/60, 4)} hours",
        '',
        f"- RTF: {round(rtf, 4)}",
    ]

    return "\n".join(results)
# Build the Gradio UI.
# NOTE(review): indentation was lost in this copy of the file, so the exact
# nesting of components under the Row/Column contexts below cannot be read
# from here — confirm against the original layout.
demo = gr.Blocks(
title=title,
analytics_enabled=False,
theme=gr.themes.Base(),
)
with demo:
# Page header and usage section.
gr.Markdown(description_head)
gr.Markdown("## Usage")
with gr.Row():
with gr.Column():
# Inputs: the JSONL file and the normalization toggle passed to inference().
jsonl_file = gr.File(label="A JSONL file")
clear_punctuation = gr.Checkbox(
label="Clear punctuation, some chars and convert to lowercase",
)
# Output: textbox that receives the report string returned by inference().
metrics = gr.Textbox(
label="Metrics",
placeholder=metrics_value,
show_copy_button=True,
)
# Clicking "Calculate" runs inference(jsonl_file, clear_punctuation) and
# writes its return value into the metrics textbox.
gr.Button("Calculate").click(
inference,
concurrency_limit=concurrency_limit,
inputs=[jsonl_file, clear_punctuation],
outputs=metrics,
)
with gr.Row():
# Example rows populate the same two inputs from the `examples` list.
gr.Examples(label="Choose an example", inputs=[jsonl_file, clear_punctuation], examples=examples)
# Footer: authors table followed by environment and library versions.
gr.Markdown(description_foot)
gr.Markdown("### Gradio app uses:")
gr.Markdown(tech_env)
gr.Markdown(tech_libraries)
if __name__ == "__main__":
    # Enable the request queue, then start the app; only when run as a script.
    demo.queue().launch()