huckiyang committed
Commit ad915da · 1 Parent(s): 652bcfd
Files changed (1)
  1. app.py +94 -0
app.py ADDED
@@ -0,0 +1,94 @@
+ import gradio as gr
+ import pandas as pd
+ from datasets import load_dataset
+ import jiwer
+ import numpy as np
+
+ # Load the dataset
+ def load_data():
+     dataset = load_dataset("GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction")
+     return dataset
+
+ # Calculate WER for a group of examples
+ def calculate_wer(examples):
+     if not examples:
+         return 0.0
+
+     # Keep only the text before the first '.' in the concatenated
+     # hypotheses as the 1-best hypothesis for the baseline score
+     hypotheses = [ex["hypothesis_concatenated"].split('.')[0].strip() for ex in examples]
+     transcriptions = [ex["transcription"].strip() for ex in examples]
+
+     # jiwer.wer takes references first, then hypotheses
+     wer = jiwer.wer(transcriptions, hypotheses)
+     return wer
+
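+ # WER = (substitutions + deletions + insertions) / number of reference words;
+ # e.g. jiwer.wer("the cat sat", "the cat sat down") == 1/3 (one insertion against
+ # a three-word reference). jiwer accepts single strings or lists, as used above.
+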
+ # Get WER metrics by source and split
+ def get_wer_metrics(dataset):
+     results = []
+
+     # Get unique sources
+     train_sources = set([ex["source"] for ex in dataset["train"]])
+     test_sources = set([ex["source"] for ex in dataset["test"]])
+     all_sources = sorted(list(train_sources.union(test_sources)))
+
+     # Calculate per-source WER for both the train and test splits
+     for source in all_sources:
+         train_examples = [ex for ex in dataset["train"] if ex["source"] == source]
+         train_count = len(train_examples)
+         train_wer = calculate_wer(train_examples) if train_count > 0 else np.nan
+
+         test_examples = [ex for ex in dataset["test"] if ex["source"] == source]
+         test_count = len(test_examples)
+         test_wer = calculate_wer(test_examples) if test_count > 0 else np.nan
+
+         results.append({
+             "Source": source,
+             "Train Count": train_count,
+             "Train WER": train_wer,
+             "Test Count": test_count,
+             "Test WER": test_wer
+         })
+
+     # Add overall metrics
+     train_wer = calculate_wer(dataset["train"])
+     test_wer = calculate_wer(dataset["test"])
+
+     results.append({
+         "Source": "OVERALL",
+         "Train Count": len(dataset["train"]),
+         "Train WER": train_wer,
+         "Test Count": len(dataset["test"]),
+         "Test WER": test_wer
+     })
+
+     return pd.DataFrame(results)
+
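+ # The returned frame has one row per source plus a final OVERALL row, with
+ # columns: Source, Train Count, Train WER, Test Count, Test WER.
+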
+ # Format the dataframe for display
+ def format_dataframe(df):
+     df["Train WER"] = df["Train WER"].apply(lambda x: f"{x:.4f}" if not pd.isna(x) else "N/A")
+     df["Test WER"] = df["Test WER"].apply(lambda x: f"{x:.4f}" if not pd.isna(x) else "N/A")
+     return df
+
+ # Main function to create the leaderboard
+ def create_leaderboard():
+     try:
+         dataset = load_data()
+         metrics_df = get_wer_metrics(dataset)
+         formatted_df = format_dataframe(metrics_df)
+         return formatted_df
+     except Exception as e:
+         # Surface loading/metric errors in the UI instead of crashing
+         return pd.DataFrame({"Error": [str(e)]})
+
+ # Create the Gradio interface
+ with gr.Blocks(title="ASR Text Correction Leaderboard") as demo:
+     gr.Markdown("# ASR Text Correction Baseline WER Leaderboard")
+     gr.Markdown("Word Error Rate (WER) metrics for GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction dataset")
+
+     with gr.Row():
+         refresh_btn = gr.Button("Refresh Leaderboard")
+
+     with gr.Row():
+         leaderboard = gr.DataFrame(create_leaderboard())
+
+     refresh_btn.click(create_leaderboard, outputs=leaderboard)
+
+ if __name__ == "__main__":
+     demo.launch()
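
A minimal way to try the new file locally, assuming the five imported packages are available from PyPI:

  pip install gradio pandas datasets jiwer numpy
  python app.py

Gradio serves the interface on http://localhost:7860 by default.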