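"""Gradio app: a baseline WER leaderboard for the
GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction dataset, broken down by
source and by train/test split."""
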
import gradio as gr
import pandas as pd
from datasets import load_dataset
import jiwer
import numpy as np

# Load the dataset
def load_data():
    dataset = load_dataset("GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction")
    return dataset
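
# NOTE: downstream code assumes each example exposes three fields:
# "source" (the originating corpus), "transcription" (the reference
# text), and "hypothesis_concatenated" (the ASR n-best hypotheses joined
# into a single string). Adjust these names if the dataset schema differs.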

# Calculate WER for a group of examples
def calculate_wer(examples):
    if not examples:
        return np.nan  # WER is undefined for an empty set of examples

    # Keep only the first (top-1) hypothesis; this assumes the n-best
    # hypotheses in "hypothesis_concatenated" are joined with '.'
    hypotheses = [ex["hypothesis_concatenated"].split('.')[0].strip() for ex in examples]
    transcriptions = [ex["transcription"].strip() for ex in examples]

    # jiwer accepts parallel lists of reference and hypothesis strings
    wer = jiwer.wer(transcriptions, hypotheses)
    return wer
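
# Quick sanity check of the metric (illustrative strings, not taken from
# the dataset): jiwer computes WER as (substitutions + deletions +
# insertions) / reference word count, so
#   jiwer.wer(["the cat sat"], ["the cat sat down"])
# returns ~0.333 (one insertion against a three-word reference).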

# Get WER metrics by source and split
def get_wer_metrics(dataset):
    results = []
    
    # Group examples by source in a single pass over each split, instead of
    # re-scanning the full split once per source
    train_by_source = {}
    for ex in dataset["train"]:
        train_by_source.setdefault(ex["source"], []).append(ex)
    test_by_source = {}
    for ex in dataset["test"]:
        test_by_source.setdefault(ex["source"], []).append(ex)

    all_sources = sorted(set(train_by_source) | set(test_by_source))

    # Calculate WER for each source in both splits
    for source in all_sources:
        train_examples = train_by_source.get(source, [])
        train_count = len(train_examples)
        train_wer = calculate_wer(train_examples) if train_count > 0 else np.nan

        test_examples = test_by_source.get(source, [])
        test_count = len(test_examples)
        test_wer = calculate_wer(test_examples) if test_count > 0 else np.nan
        
        results.append({
            "Source": source,
            "Train Count": train_count,
            "Train WER": train_wer,
            "Test Count": test_count,
            "Test WER": test_wer
        })
    
    # Add overall metrics
    train_wer = calculate_wer(dataset["train"])
    test_wer = calculate_wer(dataset["test"])
    
    results.append({
        "Source": "OVERALL",
        "Train Count": len(dataset["train"]),
        "Train WER": train_wer,
        "Test Count": len(dataset["test"]),
        "Test WER": test_wer
    })
    
    return pd.DataFrame(results)
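
# The frame built above has one row per source plus a final OVERALL row
# (values elided; "<source>" stands for whatever source names the dataset
# actually contains):
#
#   Source    Train Count  Train WER  Test Count  Test WER
#   <source>          ...        ...         ...       ...
#   OVERALL           ...        ...         ...       ...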

# Format the dataframe for display
def format_dataframe(df):
    df = df.copy()  # avoid mutating the caller's frame in place
    df["Train WER"] = df["Train WER"].apply(lambda x: "N/A" if pd.isna(x) else f"{x:.4f}")
    df["Test WER"] = df["Test WER"].apply(lambda x: "N/A" if pd.isna(x) else f"{x:.4f}")
    return df

# Main function to create the leaderboard
def create_leaderboard():
    try:
        dataset = load_data()
        metrics_df = get_wer_metrics(dataset)
        formatted_df = format_dataframe(metrics_df)
        return formatted_df
    except Exception as e:
        return pd.DataFrame({"Error": [str(e)]})

# Create the Gradio interface
with gr.Blocks(title="ASR Text Correction Leaderboard") as demo:
    gr.Markdown("# ASR Text Correction Baseline WER Leaderboard")
    gr.Markdown("Word Error Rate (WER) metrics for GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction dataset")
    
    with gr.Row():
        refresh_btn = gr.Button("Refresh Leaderboard")
    
    with gr.Row():
        leaderboard = gr.DataFrame(create_leaderboard())
    
    # load_dataset caches the download locally, so a refresh re-runs the
    # WER computation without re-fetching the dataset
    refresh_btn.click(create_leaderboard, outputs=leaderboard)

if __name__ == "__main__":
    demo.launch()
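
# To run locally (assuming this file is saved as app.py):
#   pip install gradio pandas datasets jiwer numpy
#   python app.py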