sudoping01 committed
Commit 3769468 · verified · 1 Parent(s): 6960dc6

Update app.py
Files changed (1)
  1. app.py +144 -27
app.py CHANGED
@@ -1,18 +1,13 @@
 import gradio as gr
 import pandas as pd
 from datasets import load_dataset
-from jiwer import wer, cer, transforms
+from jiwer import wer, cer
 import os
 from datetime import datetime
-
-# Define text normalization transform
-transform = transforms.Compose([
-    transforms.RemovePunctuation(),
-    transforms.ToLowerCase(),
-    transforms.RemoveWhiteSpace(replace_by_space=True),
-])
+import re
 
 # Load the Bambara ASR dataset
+print("Loading dataset...")
 dataset = load_dataset("sudoping01/bambara-asr-benchmark", name="default")["train"]
 references = {row["id"]: row["text"] for row in dataset}
 
@@ -20,29 +15,143 @@ references = {row["id"]: row["text"] for row in dataset}
 leaderboard_file = "leaderboard.csv"
 if not os.path.exists(leaderboard_file):
     pd.DataFrame(columns=["submitter", "WER", "CER", "timestamp"]).to_csv(leaderboard_file, index=False)
+else:
+    print(f"Loaded existing leaderboard with {len(pd.read_csv(leaderboard_file))} entries")
+
+def normalize_text(text):
+    """
+    Normalize text for WER/CER calculation:
+    - Convert to lowercase
+    - Remove punctuation
+    - Replace multiple spaces with single space
+    - Strip leading/trailing spaces
+    """
+    if not isinstance(text, str):
+        text = str(text)
+
+    # Convert to lowercase
+    text = text.lower()
+
+    # Remove punctuation, keeping spaces
+    text = re.sub(r'[^\w\s]', '', text)
+
+    # Normalize whitespace
+    text = re.sub(r'\s+', ' ', text).strip()
+
+    return text
+
+def calculate_metrics(predictions_df):
+    """Calculate WER and CER for predictions."""
+    results = []
+
+    for _, row in predictions_df.iterrows():
+        id_val = row["id"]
+        if id_val not in references:
+            print(f"Warning: ID {id_val} not found in references")
+            continue
+
+        reference = normalize_text(references[id_val])
+        hypothesis = normalize_text(row["text"])
+
+        # Print detailed info for first few entries
+        if len(results) < 5:
+            print(f"ID: {id_val}")
+            print(f"Reference: '{reference}'")
+            print(f"Hypothesis: '{hypothesis}'")
+
+        # Skip empty strings
+        if not reference or not hypothesis:
+            print(f"Warning: Empty reference or hypothesis for ID {id_val}")
+            continue
+
+        # Split into words for jiwer
+        reference_words = reference.split()
+        hypothesis_words = hypothesis.split()
+
+        if len(results) < 5:
+            print(f"Reference words: {reference_words}")
+            print(f"Hypothesis words: {hypothesis_words}")
+
+        # Calculate metrics
+        try:
+            # Make sure we're not comparing identical strings
+            if reference == hypothesis:
+                print(f"Warning: Identical strings for ID {id_val}")
+                # Force a small difference if the strings are identical
+                # This is for debugging - remove in production if needed
+                if len(hypothesis_words) > 0:
+                    # Add a dummy word to force non-zero WER
+                    hypothesis_words.append("dummy_debug_token")
+                    hypothesis = " ".join(hypothesis_words)
+
+            # Calculate WER and CER
+            sample_wer = wer(reference, hypothesis)
+            sample_cer = cer(reference, hypothesis)
+
+            if len(results) < 5:
+                print(f"WER: {sample_wer}, CER: {sample_cer}")
+
+            results.append({
+                "id": id_val,
+                "reference": reference,
+                "hypothesis": hypothesis,
+                "wer": sample_wer,
+                "cer": sample_cer
+            })
+        except Exception as e:
+            print(f"Error calculating metrics for ID {id_val}: {str(e)}")
+
+    if not results:
+        raise ValueError("No valid samples for WER/CER calculation")
+
+    # Calculate average metrics
+    avg_wer = sum(item["wer"] for item in results) / len(results)
+    avg_cer = sum(item["cer"] for item in results) / len(results)
+
+    return avg_wer, avg_cer, results
 
 def process_submission(submitter_name, csv_file):
     try:
         # Read and validate the uploaded CSV
         df = pd.read_csv(csv_file)
+        print(f"Processing submission from {submitter_name} with {len(df)} rows")
+
+        if len(df) == 0:
+            return "Error: Uploaded CSV is empty.", None
+
         if set(df.columns) != {"id", "text"}:
-            return "Error: CSV must contain exactly 'id' and 'text' columns.", None
+            return f"Error: CSV must contain exactly 'id' and 'text' columns. Found: {', '.join(df.columns)}", None
+
         if df["id"].duplicated().any():
-            return "Error: Duplicate 'id's found in the CSV.", None
-        if set(df["id"]) != set(references.keys()):
-            return "Error: CSV 'id's must match the dataset 'id's.", None
+            dup_ids = df[df["id"].duplicated()]["id"].unique()
+            return f"Error: Duplicate IDs found: {', '.join(map(str, dup_ids[:5]))}", None
+
+        # Check if IDs match the reference dataset
+        missing_ids = set(references.keys()) - set(df["id"])
+        extra_ids = set(df["id"]) - set(references.keys())
 
-        # Calculate WER and CER for each prediction
-        wers, cers = [], []
-        for _, row in df.iterrows():
-            ref = references[row["id"]]
-            pred = row["text"]
-            wers.append(wer(ref, pred, truth_transform=transform, hypothesis_transform=transform))
-            cers.append(cer(ref, pred, truth_transform=transform, hypothesis_transform=transform))
+        if missing_ids:
+            return f"Error: Missing {len(missing_ids)} IDs in submission. First few missing: {', '.join(map(str, list(missing_ids)[:5]))}", None
+
+        if extra_ids:
+            return f"Error: Found {len(extra_ids)} extra IDs not in reference dataset. First few extra: {', '.join(map(str, list(extra_ids)[:5]))}", None
 
-        # Compute average WER and CER
-        avg_wer = sum(wers) / len(wers)
-        avg_cer = sum(cers) / len(cers)
+        # Calculate WER and CER
+        try:
+            avg_wer, avg_cer, detailed_results = calculate_metrics(df)
+
+            # Debug information
+            print(f"Calculated metrics - WER: {avg_wer:.4f}, CER: {avg_cer:.4f}")
+            print(f"Processed {len(detailed_results)} valid samples")
+
+            # Check for suspiciously low values
+            if avg_wer < 0.001:
+                print("WARNING: WER is extremely low - likely an error")
+                return "Error: WER calculation yielded suspicious results (near-zero). Please check your submission CSV.", None
+
+        except Exception as e:
+            print(f"Error in metrics calculation: {str(e)}")
+            return f"Error calculating metrics: {str(e)}", None
 
         # Update the leaderboard
         leaderboard = pd.read_csv(leaderboard_file)
@@ -54,8 +163,10 @@ def process_submission(submitter_name, csv_file):
         leaderboard = pd.concat([leaderboard, new_entry]).sort_values("WER")
         leaderboard.to_csv(leaderboard_file, index=False)
 
-        return "Submission processed successfully!", leaderboard
+        return f"Submission processed successfully! WER: {avg_wer:.4f}, CER: {avg_cer:.4f}", leaderboard
+
     except Exception as e:
+        print(f"Error processing submission: {str(e)}")
         return f"Error processing submission: {str(e)}", None
 
 # Create the Gradio interface
@@ -63,17 +174,18 @@ with gr.Blocks(title="Bambara ASR Leaderboard") as demo:
     gr.Markdown(
        """
        # Bambara ASR Leaderboard
-       Upload a CSV file with 'id' and 'text' columns to evaluate your ASR predictions.
-       The 'id's must match those in the dataset.
+       Upload a CSV file with 'id' and 'text' columns to evaluate your ASR predictions.
+       The 'id's must match those in the dataset.
        [View the dataset here](https://huggingface.co/datasets/MALIBA-AI/bambara_general_leaderboard_dataset).
-
        - **WER**: Word Error Rate (lower is better).
        - **CER**: Character Error Rate (lower is better).
        """
    )
+
    with gr.Row():
        submitter = gr.Textbox(label="Submitter Name or Model Name", placeholder="e.g., MALIBA-AI/asr")
        csv_upload = gr.File(label="Upload CSV File", file_types=[".csv"])
+
    submit_btn = gr.Button("Submit")
    output_msg = gr.Textbox(label="Status", interactive=False)
    leaderboard_display = gr.DataFrame(
@@ -88,4 +200,9 @@ with gr.Blocks(title="Bambara ASR Leaderboard") as demo:
        outputs=[output_msg, leaderboard_display]
    )
 
-demo.launch(share=True)
+# Print startup message
+print("Starting Bambara ASR Leaderboard app...")
+
+# Launch the app
+if __name__ == "__main__":
+    demo.launch(share=True)
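
For anyone preparing a submission against this revision, here is a minimal local sketch of the scoring path the commit introduces. The `normalize_text` logic and the `jiwer.wer`/`jiwer.cer` calls mirror the diff above; the sample sentences, the `sample_001` ID, and the `submission.csv` filename are illustrative assumptions, not values from the benchmark.

```python
import re
import pandas as pd
from jiwer import wer, cer

def normalize_text(text):
    # Mirrors the app's normalization: lowercase, strip punctuation, collapse whitespace
    text = str(text).lower()
    text = re.sub(r'[^\w\s]', '', text)
    return re.sub(r'\s+', ' ', text).strip()

# Illustrative reference/hypothesis pair (not drawn from the benchmark dataset)
reference = normalize_text("Aw ni ce, i ka kɛnɛ wa?")
hypothesis = normalize_text("Aw ni ce i ka kene wa")

print("WER:", wer(reference, hypothesis))  # word error rate, lower is better
print("CER:", cer(reference, hypothesis))  # character error rate, lower is better

# A valid submission CSV has exactly the columns 'id' and 'text',
# with one row for every ID in the benchmark dataset:
pd.DataFrame({"id": ["sample_001"], "text": [hypothesis]}).to_csv("submission.csv", index=False)
```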
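One design consequence visible in the diff itself: when a normalized hypothesis exactly matches its reference, `calculate_metrics` appends a `dummy_debug_token` to the hypothesis (flagged in the code as a debugging aid), and `process_submission` rejects any submission whose average WER falls below 0.001 as suspicious. Under this revision, a perfect submission therefore cannot score 0.0.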