import gradio as gr
import pandas as pd
from datasets import load_dataset
import jiwer
import numpy as np
from functools import lru_cache
import traceback
import re
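
# This app builds a simple WER leaderboard: it loads the test split of
# GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction, computes a no-LM baseline WER per
# source ("transcription" as reference, "input1" or the first "hypothesis" entry as
# hypothesis), and renders the result as a table in a Gradio interface.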

# Cache the dataset loading to avoid reloading on refresh
@lru_cache(maxsize=1)
def load_data():
    try:
        # Load only the test dataset by specifying the split
        dataset = load_dataset("GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction", split="test")
        return dataset
    except Exception as e:
        print(f"Error loading dataset: {str(e)}")
        # Try loading with explicit file path if the default loading fails
        try:
            dataset = load_dataset("parquet", 
                                  data_files="https://huggingface.co/datasets/GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction/resolve/main/data/test-00000-of-00001.parquet")
            return dataset
        except Exception as e2:
            print(f"Error loading with explicit path: {str(e2)}")
            raise

# Preprocess text for better WER calculation
def preprocess_text(text):
    if not text or not isinstance(text, str):
        return ""
    # Convert to lowercase
    text = text.lower()
    # Remove punctuation
    text = re.sub(r'[^\w\s]', '', text)
    # Remove extra whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    return text
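
# Illustrative example (not used by the app): the normalization above lowercases,
# strips punctuation, and collapses whitespace, e.g.
#   preprocess_text("Hello,   World!!") -> "hello world"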

# Word-level WER using the editdistance package, with jiwer as a fallback, so we do not rely on jiwer internals
def calculate_simple_wer(reference, hypothesis):
    """Calculate WER using a simple word-based approach"""
    if not reference or not hypothesis:
        return 1.0  # Maximum error if either is empty
        
    # Split into words
    ref_words = reference.split()
    hyp_words = hypothesis.split()
    
    # Use editdistance package instead of jiwer internals
    try:
        import editdistance
        distance = editdistance.eval(ref_words, hyp_words)
    except ImportError:
        # Fallback to simple jiwer calculation
        try:
            # Try using the standard jiwer implementation
            wer_value = jiwer.wer(reference, hypothesis)
            return wer_value
        except Exception:
            # If all else fails, return 1.0 (maximum error)
            print("Error calculating WER - fallback to maximum error")
            return 1.0
    
    # WER calculation
    if len(ref_words) == 0:
        return 1.0
    return float(distance) / float(len(ref_words))
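
# Worked example (illustrative): reference "the cat sat on the mat" vs. hypothesis
# "the cat sit on mat" requires 2 word edits (1 substitution, 1 deletion) against
# 6 reference words, so calculate_simple_wer returns 2 / 6 ≈ 0.3333.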

# Calculate WER for a group of examples
def calculate_wer(examples):
    if not examples:
        return 0.0
    
    try:
        # Check if examples is a Dataset or a list
        is_dataset = hasattr(examples, 'features')
        
        # Get the first example for inspection
        # Both datasets.Dataset and plain lists support integer indexing,
        # so a single branch covers both cases
        if len(examples) > 0:
            example = examples[0]
        else:
            print("No examples found")
            return np.nan
            
        print("\n===== EXAMPLE DATA INSPECTION =====")
        print(f"Keys in example: {example.keys()}")
        
        # Try different possible field names
        possible_reference_fields = ["transcription", "reference", "ground_truth", "target"]
        possible_hypothesis_fields = ["input1", "hypothesis", "asr_output", "source_text"]
        
        for field in possible_reference_fields:
            if field in example:
                print(f"Reference field '{field}' found with value: {str(example[field])[:100]}...")
        
        for field in possible_hypothesis_fields:
            if field in example:
                print(f"Hypothesis field '{field}' found with value: {str(example[field])[:100]}...")
        
        # Process each example in the dataset
        wer_values = []
        valid_count = 0
        skipped_count = 0
        
        # Determine how to iterate based on type
        items_to_process = examples
        if is_dataset:
            # Limit to first 200 examples for efficiency
            items_to_process = examples.select(range(min(200, len(examples))))
        else:
            items_to_process = examples[:200]  # First 200 examples
        
        for i, ex in enumerate(items_to_process):
            try:
                # Try to get transcription and input1
                transcription = ex.get("transcription")
                
                # First try input1, then use first element from hypothesis if available
                input1 = ex.get("input1")
                if input1 is None and "hypothesis" in ex and ex["hypothesis"]:
                    if isinstance(ex["hypothesis"], list) and len(ex["hypothesis"]) > 0:
                        input1 = ex["hypothesis"][0]
                    elif isinstance(ex["hypothesis"], str):
                        input1 = ex["hypothesis"]
                
                # Print debug info for a few examples
                if i < 3:
                    print(f"\nExample {i} inspection:")
                    print(f"  transcription: {transcription}")
                    print(f"  input1: {input1}")
                    print(f"  type checks: transcription={type(transcription)}, input1={type(input1)}")
                
                # Skip if either field is missing
                if transcription is None or input1 is None:
                    skipped_count += 1
                    if i < 3:
                        print(f"  SKIPPED: Missing field (transcription={transcription is None}, input1={input1 is None})")
                    continue
                
                # Skip if either field is empty after preprocessing
                reference = preprocess_text(transcription)
                hypothesis = preprocess_text(input1)
                
                if not reference or not hypothesis:
                    skipped_count += 1
                    if i < 3:
                        print(f"  SKIPPED: Empty after preprocessing (reference='{reference}', hypothesis='{hypothesis}')")
                    continue
                
                # Calculate WER for this pair
                pair_wer = calculate_simple_wer(reference, hypothesis)
                wer_values.append(pair_wer)
                valid_count += 1
                
                if i < 3:
                    print(f"  VALID PAIR: reference='{reference}', hypothesis='{hypothesis}', WER={pair_wer:.4f}")
                
            except Exception as ex_error:
                print(f"Error processing example {i}: {str(ex_error)}")
                skipped_count += 1
                continue
        
        # Calculate average WER
        print(f"\nProcessing summary: Valid pairs: {valid_count}, Skipped: {skipped_count}")
        
        if not wer_values:
            print("No valid pairs found for WER calculation")
            return np.nan
        
        avg_wer = np.mean(wer_values)
        print(f"Calculated {len(wer_values)} pairs with average WER: {avg_wer:.4f}")
        return avg_wer
    
    except Exception as e:
        print(f"Error in calculate_wer: {str(e)}")
        print(traceback.format_exc())
        return np.nan

# Get WER metrics by source 
def get_wer_metrics(dataset):
    try:
        # Print dataset info
        print(f"\n===== DATASET INFO =====")
        print(f"Dataset size: {len(dataset)}")
        print(f"Dataset features: {dataset.features}")
        
        # Group examples by source
        examples_by_source = {}
        
        # Process all examples
        for i, ex in enumerate(dataset):
            try:
                source = ex.get("source", "unknown")
                # Skip all_et05_real as requested
                if source == "all_et05_real":
                    continue
                    
                if source not in examples_by_source:
                    examples_by_source[source] = []
                examples_by_source[source].append(ex)
            except Exception as e:
                print(f"Error processing example {i}: {str(e)}")
                continue
        
        # Get all unique sources
        all_sources = sorted(examples_by_source.keys())
        print(f"Found sources: {all_sources}")
        
        # Calculate metrics for each source
        source_results = {}
        for source in all_sources:
            try:
                examples = examples_by_source.get(source, [])
                count = len(examples)
                
                if count > 0:
                    print(f"\nCalculating WER for source {source} with {count} examples")
                    wer = calculate_wer(examples)  # Now handles both lists and datasets
                else:
                    wer = np.nan
                
                source_results[source] = {
                    "Count": count,
                    "No LM Baseline": wer
                }
            except Exception as e:
                print(f"Error processing source {source}: {str(e)}")
                source_results[source] = {
                    "Count": 0,
                    "No LM Baseline": np.nan
                }
        
        # Calculate overall metrics on a sample of examples, excluding all_et05_real
        try:
            # Create a filtered dataset without all_et05_real
            filtered_dataset = [ex for ex in dataset if ex.get("source") != "all_et05_real"]
            total_count = len(filtered_dataset)
            print(f"\nCalculating overall WER with a sample of examples (excluding all_et05_real)")
            
            # Sample for calculation
            sample_size = min(500, total_count)
            sample_dataset = filtered_dataset[:sample_size]
            overall_wer = calculate_wer(sample_dataset)
            
            source_results["OVERALL"] = {
                "Count": total_count,
                "No LM Baseline": overall_wer
            }
        except Exception as e:
            print(f"Error calculating overall metrics: {str(e)}")
            print(traceback.format_exc())
            source_results["OVERALL"] = {
                "Count": len(filtered_dataset),
                "No LM Baseline": np.nan
            }
        
        # Create a transposed DataFrame with metrics as rows and sources as columns
        metrics = ["Count", "No LM Baseline"]
        result_df = pd.DataFrame(index=metrics, columns=["Metric"] + all_sources + ["OVERALL"])
        
        # Add descriptive column
        result_df["Metric"] = ["Number of Examples", "Word Error Rate (WER)"]
        
        for source in all_sources + ["OVERALL"]:
            for metric in metrics:
                result_df.loc[metric, source] = source_results[source][metric]
        
        # Set Metric as index for better display
        result_df = result_df.set_index("Metric")
        
        return result_df
    
    except Exception as e:
        print(f"Error in get_wer_metrics: {str(e)}")
        print(traceback.format_exc())
        return pd.DataFrame([{"Error": str(e)}])

# Format the dataframe for display
def format_dataframe(df):
    try:
        # Work on a copy so the caller's DataFrame is left unchanged
        df = df.copy()
        
        # Find the row containing WER values (its index label mentions "WER" or "Error Rate")
        wer_row_index = None
        for idx in df.index:
            if "WER" in idx or "Error Rate" in idx:
                wer_row_index = idx
                break
        
        if wer_row_index:
            # Convert to object type first to avoid warnings
            df.loc[wer_row_index] = df.loc[wer_row_index].astype(object)
            
            for col in df.columns:
                value = df.loc[wer_row_index, col]
                if pd.notna(value):
                    df.loc[wer_row_index, col] = f"{value:.4f}"
                else:
                    df.loc[wer_row_index, col] = "N/A"
        
        return df
    
    except Exception as e:
        print(f"Error in format_dataframe: {str(e)}")
        print(traceback.format_exc())
        return pd.DataFrame([{"Error": str(e)}])

# Main function to create the leaderboard
def create_leaderboard():
    try:
        dataset = load_data()
        metrics_df = get_wer_metrics(dataset)
        return format_dataframe(metrics_df)
    except Exception as e:
        error_msg = f"Error creating leaderboard: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return pd.DataFrame([{"Error": error_msg}])

# Create the Gradio interface
with gr.Blocks(title="ASR Text Correction Test Leaderboard") as demo:
    gr.Markdown("# ASR Text Correction Baseline WER Leaderboard (Test Data)")
    gr.Markdown("Word Error Rate (WER) metrics for different speech sources with No Language Model baseline")
    
    with gr.Row():
        refresh_btn = gr.Button("Refresh Leaderboard")
    
    with gr.Row():
        error_output = gr.Textbox(label="Debug Information", visible=True, lines=10)
    
    with gr.Row():
        try:
            initial_df = create_leaderboard()
            leaderboard = gr.DataFrame(initial_df)
        except Exception as e:
            error_msg = f"Error initializing leaderboard: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            error_output.value = error_msg  # set the initial value directly; calling .update() here has no effect at build time
            leaderboard = gr.DataFrame(pd.DataFrame([{"Error": error_msg}]))
    
    def refresh_and_report():
        try:
            df = create_leaderboard()
            debug_info = "Leaderboard refreshed successfully. Check console for detailed debug information."
            return df, debug_info
        except Exception as e:
            error_msg = f"Error refreshing leaderboard: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            return pd.DataFrame([{"Error": error_msg}]), error_msg
    
    refresh_btn.click(refresh_and_report, outputs=[leaderboard, error_output])

if __name__ == "__main__":
    demo.launch()