|
import gradio as gr |
|
import pandas as pd |
|
from datasets import load_dataset |
|
import jiwer |
|
import numpy as np |
|
|
|
|
|
def load_data():
    """Load the post-ASR text-correction dataset.

    Returns:
        Whatever ``load_dataset`` returns for the
        "GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction" identifier.
    """
    return load_dataset("GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction")
|
|
|
|
|
def calculate_wer(examples):
    """Compute the word error rate over a collection of examples.

    Args:
        examples: Iterable of mappings, each providing the string fields
            "hypothesis_concatenated" and "transcription". It is iterated
            twice, so it must be re-iterable.

    Returns:
        0.0 when ``examples`` is falsy; otherwise the value returned by
        ``jiwer.wer`` with the transcriptions passed first and the
        hypotheses second.
    """
    if not examples:
        return 0.0

    # Only the text before the first '.' of the concatenated hypotheses is
    # scored; surrounding whitespace is removed from both sides.
    predicted = [
        row["hypothesis_concatenated"].partition(".")[0].strip()
        for row in examples
    ]
    references = [row["transcription"].strip() for row in examples]

    return jiwer.wer(references, predicted)
|
|
|
|
|
def get_wer_metrics(dataset):
    """Build a per-source WER summary table for the train and test splits.

    Args:
        dataset: Mapping with "train" and "test" entries. Each split must be
            iterable, support ``len()``, and yield examples that expose a
            "source" key plus the fields ``calculate_wer`` reads.

    Returns:
        pandas.DataFrame with columns "Source", "Train Count", "Train WER",
        "Test Count" and "Test WER": one row per distinct source (sorted),
        followed by a final "OVERALL" row computed over each whole split.
        A split with no examples for a source gets a count of 0 and a WER
        of NaN in that row.
    """
    train_split = dataset["train"]
    test_split = dataset["test"]

    def group_by_source(split):
        # One pass over the split, preserving example order within a source.
        grouped = {}
        for example in split:
            grouped.setdefault(example["source"], []).append(example)
        return grouped

    # Grouping once avoids rescanning both splits for every source.
    train_groups = group_by_source(train_split)
    test_groups = group_by_source(test_split)

    results = []
    for source in sorted(train_groups.keys() | test_groups.keys()):
        train_examples = train_groups.get(source, [])
        test_examples = test_groups.get(source, [])
        results.append({
            "Source": source,
            "Train Count": len(train_examples),
            # NaN (not 0.0) marks "no data for this source in this split".
            "Train WER": calculate_wer(train_examples) if train_examples else np.nan,
            "Test Count": len(test_examples),
            "Test WER": calculate_wer(test_examples) if test_examples else np.nan,
        })

    # Summary row: WER over every example of each split, regardless of source.
    results.append({
        "Source": "OVERALL",
        "Train Count": len(train_split),
        "Train WER": calculate_wer(train_split),
        "Test Count": len(test_split),
        "Test WER": calculate_wer(test_split),
    })

    return pd.DataFrame(results)
|
|
|
|
|
def format_dataframe(df):
    """Render the WER columns of ``df`` as display strings, in place.

    Args:
        df: DataFrame containing "Train WER" and "Test WER" columns.

    Returns:
        The same DataFrame object, with both WER columns replaced by
        strings: four-decimal fixed-point text for present values and
        "N/A" for missing ones.
    """
    def render(value):
        if pd.isna(value):
            return "N/A"
        return f"{value:.4f}"

    for column in ("Train WER", "Test WER"):
        df[column] = df[column].apply(render)

    return df
|
|
|
|
|
def create_leaderboard():
    """Produce the leaderboard table shown in the UI.

    Loads the dataset, computes the WER metrics, and formats them.

    Returns:
        The formatted metrics DataFrame on success. If any step raises an
        ``Exception``, a one-row DataFrame with a single "Error" column
        holding the exception's string form is returned instead.
    """
    try:
        raw_dataset = load_data()
        wer_table = get_wer_metrics(raw_dataset)
        display_table = format_dataframe(wer_table)
    except Exception as err:
        # Surface the failure in the table rather than raising to the caller.
        error_payload = {"Error": [str(err)]}
        return pd.DataFrame(error_payload)
    else:
        return display_table
|
|
|
|
|
# Gradio UI: a heading, a refresh button, and the leaderboard table.
with gr.Blocks(title="ASR Text Correction Leaderboard") as demo:
    gr.Markdown(value="# ASR Text Correction Baseline WER Leaderboard")
    gr.Markdown(
        value="Word Error Rate (WER) metrics for GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction dataset"
    )

    with gr.Row():
        refresh_btn = gr.Button(value="Refresh Leaderboard")

    with gr.Row():
        # The initial table is computed once, while the UI is being built.
        leaderboard = gr.DataFrame(value=create_leaderboard())

    # Clicking the button recomputes the table and replaces its contents.
    refresh_btn.click(fn=create_leaderboard, outputs=leaderboard)


if __name__ == "__main__":
    demo.launch()