sudoping01 committed
Commit 6960dc6 · verified · 1 Parent(s): 60c60cf

Update app.py

Files changed (1)
  1. app.py +17 -45
app.py CHANGED
@@ -1,10 +1,17 @@
 import gradio as gr
 import pandas as pd
 from datasets import load_dataset
-from jiwer import wer, cer
+from jiwer import wer, cer, transforms
 import os
 from datetime import datetime
 
+# Define text normalization transform
+transform = transforms.Compose([
+    transforms.RemovePunctuation(),
+    transforms.ToLowerCase(),
+    transforms.RemoveWhiteSpace(replace_by_space=True),
+])
+
 # Load the Bambara ASR dataset
 dataset = load_dataset("sudoping01/bambara-asr-benchmark", name="default")["train"]
 references = {row["id"]: row["text"] for row in dataset}
@@ -14,59 +21,26 @@ leaderboard_file = "leaderboard.csv"
 if not os.path.exists(leaderboard_file):
     pd.DataFrame(columns=["submitter", "WER", "CER", "timestamp"]).to_csv(leaderboard_file, index=False)
 
-def preprocess_text(text):
-    """
-    Custom text preprocessing to handle Bambara text properly
-    """
-    # Convert to string in case it's not
-    text = str(text)
-
-    # Remove punctuation
-    for punct in [',', '.', '!', '?', ';', ':', '"', "'"]:
-        text = text.replace(punct, '')
-
-    # Convert to lowercase
-    text = text.lower()
-
-    # Normalize whitespace
-    text = ' '.join(text.split())
-
-    return text
-
 def process_submission(submitter_name, csv_file):
     try:
         # Read and validate the uploaded CSV
         df = pd.read_csv(csv_file)
-
         if set(df.columns) != {"id", "text"}:
             return "Error: CSV must contain exactly 'id' and 'text' columns.", None
-
         if df["id"].duplicated().any():
             return "Error: Duplicate 'id's found in the CSV.", None
-
         if set(df["id"]) != set(references.keys()):
             return "Error: CSV 'id's must match the dataset 'id's.", None
-
+
         # Calculate WER and CER for each prediction
         wers, cers = [], []
-
         for _, row in df.iterrows():
-            ref = preprocess_text(references[row["id"]])
-            pred = preprocess_text(row["text"])
-
-            # Check if either text is empty after preprocessing
-            if not ref or not pred:
-                continue
-
-            # Calculate metrics with no transform (we did preprocessing already)
-            # This avoids the error with jiwer's transform
-            wers.append(wer(ref, pred))
-            cers.append(cer(ref, pred))
-
+            ref = references[row["id"]]
+            pred = row["text"]
+            wers.append(wer(ref, pred, truth_transform=transform, hypothesis_transform=transform))
+            cers.append(cer(ref, pred, truth_transform=transform, hypothesis_transform=transform))
+
         # Compute average WER and CER
-        if not wers or not cers:
-            return "Error: No valid text pairs for evaluation after preprocessing.", None
-
         avg_wer = sum(wers) / len(wers)
         avg_cer = sum(cers) / len(cers)
 
@@ -81,7 +55,6 @@ def process_submission(submitter_name, csv_file):
         leaderboard.to_csv(leaderboard_file, index=False)
 
         return "Submission processed successfully!", leaderboard
-
     except Exception as e:
         return f"Error processing submission: {str(e)}", None
 
@@ -90,18 +63,17 @@ with gr.Blocks(title="Bambara ASR Leaderboard") as demo:
     gr.Markdown(
         """
         # Bambara ASR Leaderboard
-        Upload a CSV file with 'id' and 'text' columns to evaluate your ASR predictions.
-        The 'id's must match those in the dataset.
+        Upload a CSV file with 'id' and 'text' columns to evaluate your ASR predictions.
+        The 'id's must match those in the dataset.
         [View the dataset here](https://huggingface.co/datasets/MALIBA-AI/bambara_general_leaderboard_dataset).
+
        - **WER**: Word Error Rate (lower is better).
        - **CER**: Character Error Rate (lower is better).
         """
     )
-
     with gr.Row():
         submitter = gr.Textbox(label="Submitter Name or Model Name", placeholder="e.g., MALIBA-AI/asr")
         csv_upload = gr.File(label="Upload CSV File", file_types=[".csv"])
-
     submit_btn = gr.Button("Submit")
     output_msg = gr.Textbox(label="Status", interactive=False)
     leaderboard_display = gr.DataFrame(
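
For context, a minimal sketch of what the new jiwer transform chain does to a pair of strings before scoring. The sample sentences and variable names below are illustrative assumptions, not part of the commit; the truth_transform/hypothesis_transform keyword names follow the jiwer 2.x API that the updated app.py uses.

# Minimal sketch (not part of the commit): the same normalization chain the new
# app.py builds, applied to hypothetical sample strings.
from jiwer import transforms

transform = transforms.Compose([
    transforms.RemovePunctuation(),                      # drop punctuation marks
    transforms.ToLowerCase(),                            # compare case-insensitively
    transforms.RemoveWhiteSpace(replace_by_space=True),  # replace tabs/newlines with plain spaces
])

ref = "Aw ni ce, i ka kene wa?"   # hypothetical reference transcription
hyp = "aw ni ce i ka kene wa"     # hypothetical ASR output

print(transform(ref))  # "aw ni ce i ka kene wa"
print(transform(hyp))  # "aw ni ce i ka kene wa"

In the updated process_submission, this same transform object is passed as both truth_transform and hypothesis_transform to wer() and cer(), replacing the hand-rolled preprocess_text helper that this commit removes.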