huckiyang committed
Commit 7ec068d · Parent: 3c6aeb7

optz the data loading
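
The commit title refers to the dataset-loading path in app.py: the load is wrapped in functools.lru_cache so the dataset is loaded once per process and reused on every leaderboard refresh (see the "# Cache the dataset loading to avoid reloading on refresh" comment in the diff). A minimal sketch of that pattern; the decorated function's name and the dataset id are not visible in this diff, so load_data and the id below are placeholders:

from functools import lru_cache
from datasets import load_dataset

@lru_cache(maxsize=1)
def load_data():  # hypothetical name; the decorated function is outside the shown hunks
    # "user/asr-correction-data" is a placeholder; the real dataset id lives outside the diff context
    return load_dataset("user/asr-correction-data")

dataset = load_data()  # first call loads the DatasetDict; later calls return the cached object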

Files changed (1): app.py (+157 −73)
app.py CHANGED
@@ -4,6 +4,7 @@ from datasets import load_dataset
 import jiwer
 import numpy as np
 from functools import lru_cache
+import traceback
 
 # Cache the dataset loading to avoid reloading on refresh
 @lru_cache(maxsize=1)
@@ -15,89 +16,151 @@ def calculate_wer(examples):
     if not examples:
         return 0.0
 
-    # Filter valid examples in a single pass
-    valid_pairs = [(ex.get("transcription", "").strip(), ex.get("input1", "").strip())
-                   for ex in examples
-                   if ex.get("transcription") and ex.get("input1")]
+    try:
+        # Filter valid examples in a single pass
+        valid_pairs = []
+        for ex in examples:
+            try:
+                transcription = ex.get("transcription", "")
+                input1 = ex.get("input1", "")
+
+                # Only add valid pairs
+                if transcription and input1:
+                    # Limit text length to avoid potential issues
+                    transcription = transcription.strip()[:1000]  # Limit to 1000 chars
+                    input1 = input1.strip()[:1000]
+                    valid_pairs.append((transcription, input1))
+            except Exception as ex_error:
+                # Skip problematic examples but continue processing
+                print(f"Error processing example: {str(ex_error)}")
+                continue
+
+        if not valid_pairs:
+            return np.nan
+
+        # Unzip the pairs in one operation
+        references, hypotheses = zip(*valid_pairs) if valid_pairs else ([], [])
+
+        # Calculate WER
+        return jiwer.wer(references, hypotheses)
 
-    if not valid_pairs:
+    except Exception as e:
+        print(f"Error in calculate_wer: {str(e)}")
+        print(traceback.format_exc())
         return np.nan
-
-    # Unzip the pairs in one operation
-    references, hypotheses = zip(*valid_pairs) if valid_pairs else ([], [])
-
-    # Calculate WER
-    return jiwer.wer(references, hypotheses)
 
 # Get WER metrics by source and split
 def get_wer_metrics(dataset):
-    # Pre-process the data to avoid repeated filtering
-    train_by_source = {}
-    test_by_source = {}
-
-    # Group examples by source in a single pass for each split
-    for ex in dataset["train"]:
-        source = ex["source"]
-        if source not in train_by_source:
-            train_by_source[source] = []
-        train_by_source[source].append(ex)
-
-    for ex in dataset["test"]:
-        source = ex["source"]
-        if source not in test_by_source:
-            test_by_source[source] = []
-        test_by_source[source].append(ex)
-
-    # Get all unique sources
-    all_sources = sorted(set(train_by_source.keys()) | set(test_by_source.keys()))
-
-    # Calculate metrics for each source
-    results = []
-    for source in all_sources:
-        train_examples = train_by_source.get(source, [])
-        test_examples = test_by_source.get(source, [])
-
-        train_count = len(train_examples)
-        test_count = len(test_examples)
-
-        train_wer = calculate_wer(train_examples) if train_count > 0 else np.nan
-        test_wer = calculate_wer(test_examples) if test_count > 0 else np.nan
-
-        results.append({
-            "Source": source,
-            "Train Count": train_count,
-            "Train WER": train_wer,
-            "Test Count": test_count,
-            "Test WER": test_wer
-        })
-
-    # Calculate overall metrics once
-    train_wer = calculate_wer(dataset["train"])
-    test_wer = calculate_wer(dataset["test"])
-
-    results.append({
-        "Source": "OVERALL",
-        "Train Count": len(dataset["train"]),
-        "Train WER": train_wer,
-        "Test Count": len(dataset["test"]),
-        "Test WER": test_wer
-    })
-
-    return pd.DataFrame(results)
+    try:
+        # Pre-process the data to avoid repeated filtering
+        train_by_source = {}
+        test_by_source = {}
+
+        # Group examples by source in a single pass for each split
+        for ex in dataset["train"]:
+            try:
+                source = ex.get("source", "unknown")
+                if source not in train_by_source:
+                    train_by_source[source] = []
+                train_by_source[source].append(ex)
+            except Exception as e:
+                print(f"Error processing train example: {str(e)}")
+                continue
+
+        for ex in dataset["test"]:
+            try:
+                source = ex.get("source", "unknown")
+                if source not in test_by_source:
+                    test_by_source[source] = []
+                test_by_source[source].append(ex)
+            except Exception as e:
+                print(f"Error processing test example: {str(e)}")
+                continue
+
+        # Get all unique sources
+        all_sources = sorted(set(train_by_source.keys()) | set(test_by_source.keys()))
+
+        # Calculate metrics for each source
+        results = []
+        for source in all_sources:
+            try:
+                train_examples = train_by_source.get(source, [])
+                test_examples = test_by_source.get(source, [])
+
+                train_count = len(train_examples)
+                test_count = len(test_examples)
+
+                train_wer = calculate_wer(train_examples) if train_count > 0 else np.nan
+                test_wer = calculate_wer(test_examples) if test_count > 0 else np.nan
+
+                results.append({
+                    "Source": source,
+                    "Train Count": train_count,
+                    "Train WER": train_wer,
+                    "Test Count": test_count,
+                    "Test WER": test_wer
+                })
+            except Exception as e:
+                print(f"Error processing source {source}: {str(e)}")
+                results.append({
+                    "Source": source,
+                    "Train Count": 0,
+                    "Train WER": np.nan,
+                    "Test Count": 0,
+                    "Test WER": np.nan
+                })
+
+        # Calculate overall metrics once
+        try:
+            train_wer = calculate_wer(dataset["train"])
+            test_wer = calculate_wer(dataset["test"])
+
+            results.append({
+                "Source": "OVERALL",
+                "Train Count": len(dataset["train"]),
+                "Train WER": train_wer,
+                "Test Count": len(dataset["test"]),
+                "Test WER": test_wer
+            })
+        except Exception as e:
+            print(f"Error calculating overall metrics: {str(e)}")
+            results.append({
+                "Source": "OVERALL",
+                "Train Count": len(dataset["train"]),
+                "Train WER": np.nan,
+                "Test Count": len(dataset["test"]),
+                "Test WER": np.nan
+            })
+
+        return pd.DataFrame(results)
+
+    except Exception as e:
+        print(f"Error in get_wer_metrics: {str(e)}")
+        print(traceback.format_exc())
+        return pd.DataFrame([{"Error": str(e)}])
 
 # Format the dataframe for display
 def format_dataframe(df):
-    # Use vectorized operations instead of apply
-    df = df.copy()
-    mask = df["Train WER"].notna()
-    df.loc[mask, "Train WER"] = df.loc[mask, "Train WER"].map(lambda x: f"{x:.4f}")
-    df.loc[~mask, "Train WER"] = "N/A"
-
-    mask = df["Test WER"].notna()
-    df.loc[mask, "Test WER"] = df.loc[mask, "Test WER"].map(lambda x: f"{x:.4f}")
-    df.loc[~mask, "Test WER"] = "N/A"
-
-    return df
+    try:
+        # Use vectorized operations instead of apply
+        df = df.copy()
+
+        if "Train WER" in df.columns:
+            mask = df["Train WER"].notna()
+            df.loc[mask, "Train WER"] = df.loc[mask, "Train WER"].map(lambda x: f"{x:.4f}")
+            df.loc[~mask, "Train WER"] = "N/A"
+
+        if "Test WER" in df.columns:
+            mask = df["Test WER"].notna()
+            df.loc[mask, "Test WER"] = df.loc[mask, "Test WER"].map(lambda x: f"{x:.4f}")
+            df.loc[~mask, "Test WER"] = "N/A"
+
+        return df
+
+    except Exception as e:
+        print(f"Error in format_dataframe: {str(e)}")
+        print(traceback.format_exc())
+        return pd.DataFrame([{"Error": str(e)}])
 
 # Main function to create the leaderboard
 def create_leaderboard():
@@ -106,7 +169,9 @@ def create_leaderboard():
         metrics_df = get_wer_metrics(dataset)
         return format_dataframe(metrics_df)
     except Exception as e:
-        return pd.DataFrame({"Error": [str(e)]})
+        error_msg = f"Error creating leaderboard: {str(e)}\n{traceback.format_exc()}"
+        print(error_msg)
+        return pd.DataFrame([{"Error": error_msg}])
 
 # Create the Gradio interface
 with gr.Blocks(title="ASR Text Correction Leaderboard") as demo:
@@ -117,9 +182,28 @@ with gr.Blocks(title="ASR Text Correction Leaderboard") as demo:
         refresh_btn = gr.Button("Refresh Leaderboard")
 
     with gr.Row():
-        leaderboard = gr.DataFrame(create_leaderboard())
+        error_output = gr.Textbox(label="Errors (if any)")
 
-    refresh_btn.click(create_leaderboard, outputs=leaderboard)
+    with gr.Row():
+        try:
+            initial_df = create_leaderboard()
+            leaderboard = gr.DataFrame(initial_df)
+        except Exception as e:
+            error_msg = f"Error initializing leaderboard: {str(e)}\n{traceback.format_exc()}"
+            print(error_msg)
+            error_output.update(value=error_msg)
+            leaderboard = gr.DataFrame(pd.DataFrame([{"Error": error_msg}]))
+
+    def refresh_and_report():
+        try:
+            df = create_leaderboard()
+            return df, ""
+        except Exception as e:
+            error_msg = f"Error refreshing leaderboard: {str(e)}\n{traceback.format_exc()}"
+            print(error_msg)
+            return pd.DataFrame([{"Error": error_msg}]), error_msg
+
+    refresh_btn.click(refresh_and_report, outputs=[leaderboard, error_output])
 
 if __name__ == "__main__":
     demo.launch()
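
For reference, jiwer.wer accepts parallel lists of reference and hypothesis strings and returns one aggregate word error rate over all pairs, which is what calculate_wer feeds it after unzipping valid_pairs into references and hypotheses. A small self-contained check with made-up sentences:

import jiwer

references = ["the cat sat on the mat", "hello world"]
hypotheses = ["the cat sat on a mat", "hello word"]

# 2 substitutions across 8 reference words -> aggregate WER of 0.25
print(jiwer.wer(references, hypotheses))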