huckiyang committed
Commit 4e73867
1 Parent(s): 7ec068d

optz the data loading

Files changed (1)
  1. app.py +70 -64
app.py CHANGED
@@ -9,7 +9,20 @@ import traceback
 # Cache the dataset loading to avoid reloading on refresh
 @lru_cache(maxsize=1)
 def load_data():
-    return load_dataset("GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction")
+    try:
+        # Load only the test dataset by specifying the split
+        dataset = load_dataset("GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction", split="test")
+        return dataset
+    except Exception as e:
+        print(f"Error loading dataset: {str(e)}")
+        # Try loading with explicit file path if the default loading fails
+        try:
+            dataset = load_dataset("parquet",
+                                   data_files="https://huggingface.co/datasets/GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction/resolve/main/data/test-00000-of-00001.parquet")
+            return dataset
+        except Exception as e2:
+            print(f"Error loading with explicit path: {str(e2)}")
+            raise
 
 # Calculate WER for a group of examples
 def calculate_wer(examples):
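For reference, a minimal sanity check of the new loader; a sketch under the assumption that the test split exposes the transcription, input1, and source columns used later in this diff:

    # Sketch: the first call loads the test split; lru_cache reuses the result afterwards.
    ds = load_data()
    print(len(ds))        # number of test examples
    print(ds[0].keys())   # expect 'transcription', 'input1', 'source' among the columns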
@@ -21,11 +34,15 @@ def calculate_wer(examples):
         valid_pairs = []
         for ex in examples:
             try:
+                # Print a sample example to debug
+                if len(valid_pairs) == 0:
+                    print(f"Sample example keys: {ex.keys()}")
+
                 transcription = ex.get("transcription", "")
                 input1 = ex.get("input1", "")
 
-                # Only add valid pairs
-                if transcription and input1:
+                # Only add valid pairs with non-empty strings
+                if transcription and input1 and isinstance(transcription, str) and isinstance(input1, str):
                     # Limit text length to avoid potential issues
                     transcription = transcription.strip()[:1000]  # Limit to 1000 chars
                     input1 = input1.strip()[:1000]
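A compact sketch of what the tightened validity check keeps and drops (toy rows, not actual dataset content):

    examples = [
        {"transcription": "a b c", "input1": "a b d"},  # kept
        {"transcription": "", "input1": None},          # dropped: empty / non-string
    ]
    valid = [
        (ex["transcription"], ex["input1"])
        for ex in examples
        if ex["transcription"] and ex["input1"]
        and isinstance(ex["transcription"], str) and isinstance(ex["input1"], str)
    ]
    print(valid)  # [('a b c', 'a b d')]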
@@ -36,100 +53,93 @@ def calculate_wer(examples):
                 continue
 
     if not valid_pairs:
+        print("No valid pairs found for WER calculation")
         return np.nan
 
+    # Print sample pairs for debugging
+    print(f"Sample pair for WER calculation: {valid_pairs[0]}")
+    print(f"Total valid pairs: {len(valid_pairs)}")
+
     # Unzip the pairs in one operation
     references, hypotheses = zip(*valid_pairs) if valid_pairs else ([], [])
 
     # Calculate WER
-    return jiwer.wer(references, hypotheses)
+    try:
+        wer = jiwer.wer(references, hypotheses)
+        print(f"Calculated WER: {wer}")
+        return wer
+    except Exception as wer_error:
+        print(f"Error calculating WER: {str(wer_error)}")
+        return np.nan
 
     except Exception as e:
         print(f"Error in calculate_wer: {str(e)}")
         print(traceback.format_exc())
         return np.nan
 
-# Get WER metrics by source and split
+# Get WER metrics by source
 def get_wer_metrics(dataset):
     try:
-        # Pre-process the data to avoid repeated filtering
-        train_by_source = {}
-        test_by_source = {}
+        # Group examples by source
+        examples_by_source = {}
 
-        # Group examples by source in a single pass for each split
-        for ex in dataset["train"]:
+        # Process all examples
+        for ex in dataset:
             try:
                 source = ex.get("source", "unknown")
-                if source not in train_by_source:
-                    train_by_source[source] = []
-                train_by_source[source].append(ex)
+                if source not in examples_by_source:
+                    examples_by_source[source] = []
+                examples_by_source[source].append(ex)
             except Exception as e:
-                print(f"Error processing train example: {str(e)}")
-                continue
-
-        for ex in dataset["test"]:
-            try:
-                source = ex.get("source", "unknown")
-                if source not in test_by_source:
-                    test_by_source[source] = []
-                test_by_source[source].append(ex)
-            except Exception as e:
-                print(f"Error processing test example: {str(e)}")
+                print(f"Error processing example: {str(e)}")
                 continue
 
         # Get all unique sources
-        all_sources = sorted(set(train_by_source.keys()) | set(test_by_source.keys()))
+        all_sources = sorted(examples_by_source.keys())
 
         # Calculate metrics for each source
         results = []
         for source in all_sources:
             try:
-                train_examples = train_by_source.get(source, [])
-                test_examples = test_by_source.get(source, [])
-
-                train_count = len(train_examples)
-                test_count = len(test_examples)
+                examples = examples_by_source.get(source, [])
+                count = len(examples)
 
-                train_wer = calculate_wer(train_examples) if train_count > 0 else np.nan
-                test_wer = calculate_wer(test_examples) if test_count > 0 else np.nan
+                if count > 0:
+                    print(f"Calculating WER for source {source} with {count} examples")
+                    wer = calculate_wer(examples)
+                else:
+                    wer = np.nan
 
                 results.append({
                     "Source": source,
-                    "Train Count": train_count,
-                    "Train WER": train_wer,
-                    "Test Count": test_count,
-                    "Test WER": test_wer
+                    "Count": count,
+                    "WER": wer
                 })
             except Exception as e:
                 print(f"Error processing source {source}: {str(e)}")
                 results.append({
                     "Source": source,
-                    "Train Count": 0,
-                    "Train WER": np.nan,
-                    "Test Count": 0,
-                    "Test WER": np.nan
+                    "Count": 0,
+                    "WER": np.nan
                 })
 
         # Calculate overall metrics once
         try:
-            train_wer = calculate_wer(dataset["train"])
-            test_wer = calculate_wer(dataset["test"])
+            total_count = len(dataset)
+            print(f"Calculating overall WER for {total_count} examples")
+            overall_wer = calculate_wer(dataset)
 
             results.append({
                 "Source": "OVERALL",
-                "Train Count": len(dataset["train"]),
-                "Train WER": train_wer,
-                "Test Count": len(dataset["test"]),
-                "Test WER": test_wer
+                "Count": total_count,
+                "WER": overall_wer
             })
         except Exception as e:
            print(f"Error calculating overall metrics: {str(e)}")
            results.append({
                "Source": "OVERALL",
-                "Train Count": len(dataset["train"]),
-                "Train WER": np.nan,
-                "Test Count": len(dataset["test"]),
-                "Test WER": np.nan
+                "Count": len(dataset),
+                "WER": np.nan
            })
 
        return pd.DataFrame(results)
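To illustrate the corpus-level call the reworked code makes per source, a toy example (made-up sentences; jiwer.wer accepts parallel lists of references and hypotheses):

    import jiwer

    refs = ["the cat sat on the mat", "hello world"]
    hyps = ["the cat sat on mat", "hello word"]
    print(jiwer.wer(refs, hyps))  # 0.25: two errors over eight reference words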
@@ -145,15 +155,10 @@ def format_dataframe(df):
     # Use vectorized operations instead of apply
     df = df.copy()
 
-    if "Train WER" in df.columns:
-        mask = df["Train WER"].notna()
-        df.loc[mask, "Train WER"] = df.loc[mask, "Train WER"].map(lambda x: f"{x:.4f}")
-        df.loc[~mask, "Train WER"] = "N/A"
-
-    if "Test WER" in df.columns:
-        mask = df["Test WER"].notna()
-        df.loc[mask, "Test WER"] = df.loc[mask, "Test WER"].map(lambda x: f"{x:.4f}")
-        df.loc[~mask, "Test WER"] = "N/A"
+    if "WER" in df.columns:
+        mask = df["WER"].notna()
+        df.loc[mask, "WER"] = df.loc[mask, "WER"].map(lambda x: f"{x:.4f}")
+        df.loc[~mask, "WER"] = "N/A"
 
     return df
 
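A small sketch of the simplified formatting path on the new single WER column (toy values):

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({"Source": ["src1", "src2"], "Count": [10, 0], "WER": [0.123456, np.nan]})
    mask = df["WER"].notna()
    df.loc[mask, "WER"] = df.loc[mask, "WER"].map(lambda x: f"{x:.4f}")
    df.loc[~mask, "WER"] = "N/A"
    print(df["WER"].tolist())  # ['0.1235', 'N/A']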
@@ -174,15 +179,15 @@ def create_leaderboard():
         return pd.DataFrame([{"Error": error_msg}])
 
 # Create the Gradio interface
-with gr.Blocks(title="ASR Text Correction Leaderboard") as demo:
-    gr.Markdown("# ASR Text Correction Baseline WER Leaderboard")
-    gr.Markdown("Word Error Rate (WER) metrics for GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction dataset")
+with gr.Blocks(title="ASR Text Correction Test Leaderboard") as demo:
+    gr.Markdown("# ASR Text Correction Baseline WER Leaderboard (Test Data)")
+    gr.Markdown("Word Error Rate (WER) metrics for test data in GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction dataset")
 
     with gr.Row():
         refresh_btn = gr.Button("Refresh Leaderboard")
 
     with gr.Row():
-        error_output = gr.Textbox(label="Errors (if any)")
+        error_output = gr.Textbox(label="Debug Information", visible=True)
 
     with gr.Row():
         try:
@@ -197,7 +202,8 @@ with gr.Blocks(title="ASR Text Correction Leaderboard") as demo:
     def refresh_and_report():
         try:
             df = create_leaderboard()
-            return df, ""
+            debug_info = "Leaderboard refreshed successfully."
+            return df, debug_info
         except Exception as e:
             error_msg = f"Error refreshing leaderboard: {str(e)}\n{traceback.format_exc()}"
             print(error_msg)
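The button wiring itself falls outside the hunks shown; presumably it connects refresh_and_report's (df, debug_info) return value to the table and textbox along these lines (a sketch with an assumed leaderboard_table component name, not the verbatim app code):

    # Hypothetical wiring consistent with refresh_and_report returning (df, debug_info)
    refresh_btn.click(fn=refresh_and_report, outputs=[leaderboard_table, error_output])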
 