huckiyang commited on
Commit
44ea2d4
·
1 Parent(s): 4e73867

optz the data loading

Browse files
Files changed (1) hide show
  1. app.py +109 -27
app.py CHANGED
@@ -5,6 +5,7 @@ import jiwer
5
  import numpy as np
6
  from functools import lru_cache
7
  import traceback
 
8
 
9
  # Cache the dataset loading to avoid reloading on refresh
10
  @lru_cache(maxsize=1)
@@ -24,31 +25,69 @@ def load_data():
24
  print(f"Error loading with explicit path: {str(e2)}")
25
  raise
26
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  # Calculate WER for a group of examples
28
  def calculate_wer(examples):
29
  if not examples:
30
  return 0.0
31
 
32
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  # Filter valid examples in a single pass
34
  valid_pairs = []
 
35
  for ex in examples:
36
  try:
37
- # Print a sample example to debug
38
- if len(valid_pairs) == 0:
39
- print(f"Sample example keys: {ex.keys()}")
 
 
 
 
 
 
 
 
 
 
40
 
41
- transcription = ex.get("transcription", "")
42
- input1 = ex.get("input1", "")
 
43
 
44
- # Only add valid pairs with non-empty strings
45
- if transcription and input1 and isinstance(transcription, str) and isinstance(input1, str):
46
- # Limit text length to avoid potential issues
47
- transcription = transcription.strip()[:1000] # Limit to 1000 chars
48
- input1 = input1.strip()[:1000]
49
- valid_pairs.append((transcription, input1))
50
  except Exception as ex_error:
51
- # Skip problematic examples but continue processing
52
  print(f"Error processing example: {str(ex_error)}")
53
  continue
54
 
@@ -57,20 +96,55 @@ def calculate_wer(examples):
57
  return np.nan
58
 
59
  # Print sample pairs for debugging
60
- print(f"Sample pair for WER calculation: {valid_pairs[0]}")
 
 
61
  print(f"Total valid pairs: {len(valid_pairs)}")
62
 
63
- # Unzip the pairs in one operation
64
- references, hypotheses = zip(*valid_pairs) if valid_pairs else ([], [])
 
 
 
 
65
 
66
- # Calculate WER
 
 
 
67
  try:
68
- wer = jiwer.wer(references, hypotheses)
69
- print(f"Calculated WER: {wer}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  return wer
71
  except Exception as wer_error:
72
- print(f"Error calculating WER: {str(wer_error)}")
73
- return np.nan
 
 
 
 
 
 
 
 
 
 
74
 
75
  except Exception as e:
76
  print(f"Error in calculate_wer: {str(e)}")
@@ -80,6 +154,11 @@ def calculate_wer(examples):
80
  # Get WER metrics by source
81
  def get_wer_metrics(dataset):
82
  try:
 
 
 
 
 
83
  # Group examples by source
84
  examples_by_source = {}
85
 
@@ -96,6 +175,7 @@ def get_wer_metrics(dataset):
96
 
97
  # Get all unique sources
98
  all_sources = sorted(examples_by_source.keys())
 
99
 
100
  # Calculate metrics for each source
101
  results = []
@@ -105,8 +185,8 @@ def get_wer_metrics(dataset):
105
  count = len(examples)
106
 
107
  if count > 0:
108
- print(f"Calculating WER for source {source} with {count} examples")
109
- wer = calculate_wer(examples)
110
  else:
111
  wer = np.nan
112
 
@@ -123,11 +203,13 @@ def get_wer_metrics(dataset):
123
  "WER": np.nan
124
  })
125
 
126
- # Calculate overall metrics once
127
  try:
128
  total_count = len(dataset)
129
- print(f"Calculating overall WER for {total_count} examples")
130
- overall_wer = calculate_wer(dataset)
 
 
131
 
132
  results.append({
133
  "Source": "OVERALL",
@@ -187,7 +269,7 @@ with gr.Blocks(title="ASR Text Correction Test Leaderboard") as demo:
187
  refresh_btn = gr.Button("Refresh Leaderboard")
188
 
189
  with gr.Row():
190
- error_output = gr.Textbox(label="Debug Information", visible=True)
191
 
192
  with gr.Row():
193
  try:
@@ -202,7 +284,7 @@ with gr.Blocks(title="ASR Text Correction Test Leaderboard") as demo:
202
  def refresh_and_report():
203
  try:
204
  df = create_leaderboard()
205
- debug_info = "Leaderboard refreshed successfully."
206
  return df, debug_info
207
  except Exception as e:
208
  error_msg = f"Error refreshing leaderboard: {str(e)}\n{traceback.format_exc()}"
 
5
  import numpy as np
6
  from functools import lru_cache
7
  import traceback
8
+ import re
9
 
10
  # Cache the dataset loading to avoid reloading on refresh
11
  @lru_cache(maxsize=1)
 
25
  print(f"Error loading with explicit path: {str(e2)}")
26
  raise
27
 
28
+ # Preprocess text for better WER calculation
29
+ def preprocess_text(text):
30
+ if not text or not isinstance(text, str):
31
+ return ""
32
+ # Convert to lowercase
33
+ text = text.lower()
34
+ # Remove punctuation
35
+ text = re.sub(r'[^\w\s]', '', text)
36
+ # Remove extra whitespace
37
+ text = re.sub(r'\s+', ' ', text).strip()
38
+ return text
39
+
40
  # Calculate WER for a group of examples
41
  def calculate_wer(examples):
42
  if not examples:
43
  return 0.0
44
 
45
  try:
46
+ # First, let's examine the first example in detail
47
+ if examples and len(examples) > 0:
48
+ example = examples[0]
49
+ print("\n===== EXAMPLE DATA INSPECTION =====")
50
+ print(f"Keys in example: {example.keys()}")
51
+
52
+ # Try different possible field names
53
+ possible_reference_fields = ["transcription", "reference", "ground_truth", "target"]
54
+ possible_hypothesis_fields = ["input1", "hypothesis", "asr_output", "source_text"]
55
+
56
+ for field in possible_reference_fields:
57
+ if field in example:
58
+ print(f"Reference field '{field}' found with value: {str(example[field])[:100]}...")
59
+
60
+ for field in possible_hypothesis_fields:
61
+ if field in example:
62
+ print(f"Hypothesis field '{field}' found with value: {str(example[field])[:100]}...")
63
+
64
  # Filter valid examples in a single pass
65
  valid_pairs = []
66
+
67
  for ex in examples:
68
  try:
69
+ # First try the expected field names
70
+ if "transcription" in ex and "input1" in ex:
71
+ reference = ex["transcription"]
72
+ hypothesis = ex["input1"]
73
+ # Try alternate field pairs if the standard ones don't exist
74
+ elif "transcription" in ex and "hypothesis_concatenated" in ex and ex["hypothesis_concatenated"]:
75
+ reference = ex["transcription"]
76
+ hypothesis = ex["hypothesis_concatenated"].split('.')[0] # Take first sentence
77
+ elif "reference" in ex and "hypothesis" in ex:
78
+ reference = ex["reference"]
79
+ hypothesis = ex["hypothesis"]
80
+ else:
81
+ continue # Skip this example if we can't find matching fields
82
 
83
+ # Clean and preprocess the text
84
+ reference = preprocess_text(reference)
85
+ hypothesis = preprocess_text(hypothesis)
86
 
87
+ # Only add if both have valid content
88
+ if reference and hypothesis:
89
+ valid_pairs.append((reference, hypothesis))
 
 
 
90
  except Exception as ex_error:
 
91
  print(f"Error processing example: {str(ex_error)}")
92
  continue
93
 
 
96
  return np.nan
97
 
98
  # Print sample pairs for debugging
99
+ print(f"\nSample pair for WER calculation:")
100
+ print(f"Reference: '{valid_pairs[0][0]}'")
101
+ print(f"Hypothesis: '{valid_pairs[0][1]}'")
102
  print(f"Total valid pairs: {len(valid_pairs)}")
103
 
104
+ # Make sure we have enough valid examples
105
+ if len(valid_pairs) < 5:
106
+ print("WARNING: Very few valid pairs for WER calculation")
107
+ if len(valid_pairs) < 2:
108
+ print("Not enough data for reliable WER calculation")
109
+ return np.nan
110
 
111
+ # Unzip the pairs
112
+ references, hypotheses = zip(*valid_pairs)
113
+
114
+ # Calculate WER with additional transforms
115
  try:
116
+ # Set up transformation pipeline for jiwer
117
+ transformation = jiwer.Compose([
118
+ jiwer.ToLowerCase(),
119
+ jiwer.RemoveMultipleSpaces(),
120
+ jiwer.Strip(),
121
+ jiwer.RemovePunctuation(),
122
+ jiwer.ReduceToListOfWords()
123
+ ])
124
+
125
+ # Calculate WER with transformations
126
+ wer = jiwer.wer(
127
+ references,
128
+ hypotheses,
129
+ truth_transform=transformation,
130
+ hypothesis_transform=transformation
131
+ )
132
+
133
+ print(f"Successfully calculated WER: {wer}")
134
  return wer
135
  except Exception as wer_error:
136
+ print(f"Error calculating WER with jiwer: {str(wer_error)}")
137
+
138
+ # Fallback: Calculate character error rate manually for one sample
139
+ try:
140
+ if valid_pairs:
141
+ ref = valid_pairs[0][0]
142
+ hyp = valid_pairs[0][1]
143
+ distance = jiwer.transforms.cer(ref, hyp)
144
+ print(f"Fallback CER for first sample: {distance}")
145
+ return np.nan
146
+ except:
147
+ return np.nan
148
 
149
  except Exception as e:
150
  print(f"Error in calculate_wer: {str(e)}")
 
154
  # Get WER metrics by source
155
  def get_wer_metrics(dataset):
156
  try:
157
+ # Print dataset info
158
+ print(f"\n===== DATASET INFO =====")
159
+ print(f"Dataset size: {len(dataset)}")
160
+ print(f"Dataset features: {dataset.features}")
161
+
162
  # Group examples by source
163
  examples_by_source = {}
164
 
 
175
 
176
  # Get all unique sources
177
  all_sources = sorted(examples_by_source.keys())
178
+ print(f"Found sources: {all_sources}")
179
 
180
  # Calculate metrics for each source
181
  results = []
 
185
  count = len(examples)
186
 
187
  if count > 0:
188
+ print(f"\nCalculating WER for source {source} with {count} examples")
189
+ wer = calculate_wer(examples[:100]) # Start with a sample for debugging
190
  else:
191
  wer = np.nan
192
 
 
203
  "WER": np.nan
204
  })
205
 
206
+ # Calculate overall metrics with a sample
207
  try:
208
  total_count = len(dataset)
209
+ print(f"\nCalculating overall WER with a sample of examples")
210
+ # Use a sample for overall calculation to avoid overloading
211
+ sample_size = min(1000, total_count)
212
+ overall_wer = calculate_wer(dataset[:sample_size])
213
 
214
  results.append({
215
  "Source": "OVERALL",
 
269
  refresh_btn = gr.Button("Refresh Leaderboard")
270
 
271
  with gr.Row():
272
+ error_output = gr.Textbox(label="Debug Information", visible=True, lines=10)
273
 
274
  with gr.Row():
275
  try:
 
284
  def refresh_and_report():
285
  try:
286
  df = create_leaderboard()
287
+ debug_info = "Leaderboard refreshed successfully. Check console for detailed debug information."
288
  return df, debug_info
289
  except Exception as e:
290
  error_msg = f"Error refreshing leaderboard: {str(e)}\n{traceback.format_exc()}"