import gradio as gr import pandas as pd from datasets import load_dataset import numpy as np from functools import lru_cache import re from collections import Counter import editdistance import json import os # Cache the dataset loading to avoid reloading on refresh @lru_cache(maxsize=1) def load_data(): try: dataset = load_dataset("GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction", split="test") return dataset except Exception: # Fallback to explicit file path if default loading fails return load_dataset("parquet", data_files="https://huggingface.co/datasets/GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction/resolve/main/data/test-00000-of-00001.parquet") # Storage for user-submitted methods (in-memory for demo purposes) user_methods = [] # Data file for persistence USER_DATA_FILE = "user_methods.json" # Load user methods from file if exists def load_user_methods(): global user_methods if os.path.exists(USER_DATA_FILE): try: with open(USER_DATA_FILE, 'r') as f: user_methods = json.load(f) except Exception as e: print(f"Error loading user methods: {e}") user_methods = [] # Save user methods to file def save_user_methods(): try: with open(USER_DATA_FILE, 'w') as f: json.dump(user_methods, f) except Exception as e: print(f"Error saving user methods: {e}") # Try to load user methods at startup try: load_user_methods() except: pass # Preprocess text for better WER calculation def preprocess_text(text): if not text or not isinstance(text, str): return "" text = text.lower() text = re.sub(r'[^\w\s]', '', text) text = re.sub(r'\s+', ' ', text).strip() return text # N-gram scoring for hypothesis ranking def score_hypothesis(hypothesis, n=4): if not hypothesis: return 0 words = hypothesis.split() if len(words) < n: return len(words) ngrams = [] for i in range(len(words) - n + 1): ngram = ' '.join(words[i:i+n]) ngrams.append(ngram) unique_ngrams = len(set(ngrams)) total_ngrams = len(ngrams) score = len(words) + unique_ngrams/max(1, total_ngrams) * 5 return score # N-gram ranking approach def get_best_hypothesis_lm(hypotheses): if not hypotheses: return "" if isinstance(hypotheses, str): return hypotheses hypothesis_list = [preprocess_text(h) for h in hypotheses if isinstance(h, str)] if not hypothesis_list: return "" scores = [(score_hypothesis(h), h) for h in hypothesis_list] best_hypothesis = max(scores, key=lambda x: x[0])[1] return best_hypothesis # Subwords voting correction approach def correct_hypotheses(hypotheses): if not hypotheses: return "" if isinstance(hypotheses, str): return hypotheses hypothesis_list = [preprocess_text(h) for h in hypotheses if isinstance(h, str)] if not hypothesis_list: return "" word_lists = [h.split() for h in hypothesis_list] lengths = [len(words) for words in word_lists] if not lengths: return "" most_common_length = Counter(lengths).most_common(1)[0][0] filtered_word_lists = [words for words in word_lists if len(words) == most_common_length] if not filtered_word_lists: return max(hypothesis_list, key=len) corrected_words = [] for i in range(most_common_length): position_words = [words[i] for words in filtered_word_lists] most_common_word = Counter(position_words).most_common(1)[0][0] corrected_words.append(most_common_word) return ' '.join(corrected_words) # Calculate WER def calculate_simple_wer(reference, hypothesis): if not reference or not hypothesis: return 1.0 ref_words = reference.split() hyp_words = hypothesis.split() distance = editdistance.eval(ref_words, hyp_words) if len(ref_words) == 0: return 1.0 return float(distance) / float(len(ref_words)) # Calculate WER for a group of examples with multiple methods def calculate_wer_methods(examples, max_samples=200): if not examples or len(examples) == 0: return np.nan, np.nan, np.nan # Limit sample size for efficiency if hasattr(examples, 'select'): items_to_process = examples.select(range(min(max_samples, len(examples)))) else: items_to_process = examples[:max_samples] wer_values_no_lm = [] wer_values_lm_ranking = [] wer_values_n_best_correction = [] for ex in items_to_process: # Get reference transcription transcription = ex.get("transcription") if not transcription or not isinstance(transcription, str): continue reference = preprocess_text(transcription) if not reference: continue # Get 1-best hypothesis for baseline input1 = ex.get("input1") if input1 is None and "hypothesis" in ex and ex["hypothesis"]: if isinstance(ex["hypothesis"], list) and len(ex["hypothesis"]) > 0: input1 = ex["hypothesis"][0] elif isinstance(ex["hypothesis"], str): input1 = ex["hypothesis"] # Get n-best hypotheses for other methods n_best_hypotheses = ex.get("hypothesis", []) # Method 1: No LM (1-best ASR output) if input1 and isinstance(input1, str): no_lm_hyp = preprocess_text(input1) if no_lm_hyp: wer_no_lm = calculate_simple_wer(reference, no_lm_hyp) wer_values_no_lm.append(wer_no_lm) # Method 2: N-gram ranking if n_best_hypotheses: lm_best_hyp = get_best_hypothesis_lm(n_best_hypotheses) if lm_best_hyp: wer_lm = calculate_simple_wer(reference, lm_best_hyp) wer_values_lm_ranking.append(wer_lm) # Method 3: Subwords voting correction if n_best_hypotheses: corrected_hyp = correct_hypotheses(n_best_hypotheses) if corrected_hyp: wer_corrected = calculate_simple_wer(reference, corrected_hyp) wer_values_n_best_correction.append(wer_corrected) # Calculate average WER for each method no_lm_wer = np.mean(wer_values_no_lm) if wer_values_no_lm else np.nan lm_ranking_wer = np.mean(wer_values_lm_ranking) if wer_values_lm_ranking else np.nan n_best_correction_wer = np.mean(wer_values_n_best_correction) if wer_values_n_best_correction else np.nan return no_lm_wer, lm_ranking_wer, n_best_correction_wer # Get WER metrics by source def get_wer_metrics(dataset): # Group examples by source examples_by_source = {} for ex in dataset: source = ex.get("source", "unknown") # Skip all_et05_real as requested if source == "all_et05_real": continue if source not in examples_by_source: examples_by_source[source] = [] examples_by_source[source].append(ex) # Get all unique sources all_sources = sorted(examples_by_source.keys()) # Calculate metrics for each source source_results = {} for source in all_sources: examples = examples_by_source.get(source, []) count = len(examples) if count > 0: no_lm_wer, lm_ranking_wer, n_best_wer = calculate_wer_methods(examples) else: no_lm_wer, lm_ranking_wer, n_best_wer = np.nan, np.nan, np.nan source_results[source] = { "Count": count, "No LM Baseline": no_lm_wer, "N-best LM Ranking": lm_ranking_wer, "N-best Correction": n_best_wer } # Calculate overall metrics filtered_dataset = [ex for ex in dataset if ex.get("source") != "all_et05_real"] total_count = len(filtered_dataset) sample_size = min(500, total_count) sample_dataset = filtered_dataset[:sample_size] no_lm_wer, lm_ranking_wer, n_best_wer = calculate_wer_methods(sample_dataset) source_results["OVERALL"] = { "Count": total_count, "No LM Baseline": no_lm_wer, "N-best LM Ranking": lm_ranking_wer, "N-best Correction": n_best_wer } # Create flat DataFrame with labels in the first column rows = [] # First add row for number of examples example_row = {"Methods": "Number of Examples"} for source in all_sources + ["OVERALL"]: example_row[source] = source_results[source]["Count"] rows.append(example_row) # Then add rows for each WER method with simplified names no_lm_row = {"Methods": "No LM"} lm_ranking_row = {"Methods": "N-gram Ranking"} n_best_row = {"Methods": "Subwords Voting"} # Add the additional methods from the figure llama_lora_row = {"Methods": "LLaMA-7B-LoRA"} nb_oracle_row = {"Methods": "N-best Oracle (o_nb)"} cp_oracle_row = {"Methods": "Compositional Oracle (o_cp)"} # Populate the existing methods for source in all_sources + ["OVERALL"]: no_lm_row[source] = source_results[source]["No LM Baseline"] lm_ranking_row[source] = source_results[source]["N-best LM Ranking"] n_best_row[source] = source_results[source]["N-best Correction"] # Add hardcoded values for the additional methods based on the figure # Default to NaN for sources not in the figure llama_lora_row[source] = np.nan nb_oracle_row[source] = np.nan cp_oracle_row[source] = np.nan # Add hardcoded values from the figure for each source # CHiME-4 if "test_chime4" in all_sources: llama_lora_row["test_chime4"] = 6.6 / 100 # Convert from percentage nb_oracle_row["test_chime4"] = 9.1 / 100 cp_oracle_row["test_chime4"] = 2.8 / 100 # Tedlium-3 if "test_td3" in all_sources: llama_lora_row["test_td3"] = 4.6 / 100 nb_oracle_row["test_td3"] = 3.0 / 100 cp_oracle_row["test_td3"] = 0.7 / 100 # CommonVoice (CV-accent) if "test_cv" in all_sources: llama_lora_row["test_cv"] = 11.0 / 100 nb_oracle_row["test_cv"] = 11.4 / 100 cp_oracle_row["test_cv"] = 7.9 / 100 # SwitchBoard if "test_swbd" in all_sources: llama_lora_row["test_swbd"] = 14.1 / 100 nb_oracle_row["test_swbd"] = 12.6 / 100 cp_oracle_row["test_swbd"] = 4.2 / 100 # LRS2 if "test_lrs2" in all_sources: llama_lora_row["test_lrs2"] = 8.8 / 100 nb_oracle_row["test_lrs2"] = 6.9 / 100 cp_oracle_row["test_lrs2"] = 2.6 / 100 # CORAAL if "test_coraal" in all_sources: llama_lora_row["test_coraal"] = 19.2 / 100 nb_oracle_row["test_coraal"] = 21.8 / 100 cp_oracle_row["test_coraal"] = 10.7 / 100 # LibriSpeech Clean if "test_ls_clean" in all_sources: llama_lora_row["test_ls_clean"] = 1.7 / 100 nb_oracle_row["test_ls_clean"] = 1.0 / 100 cp_oracle_row["test_ls_clean"] = 0.6 / 100 # LibriSpeech Other if "test_ls_other" in all_sources: llama_lora_row["test_ls_other"] = 3.8 / 100 nb_oracle_row["test_ls_other"] = 2.7 / 100 cp_oracle_row["test_ls_other"] = 1.6 / 100 # Calculate overall averages for the three additional methods llama_values = [] nb_oracle_values = [] cp_oracle_values = [] for source in all_sources: if pd.notna(llama_lora_row[source]): llama_values.append(llama_lora_row[source]) if pd.notna(nb_oracle_row[source]): nb_oracle_values.append(nb_oracle_row[source]) if pd.notna(cp_oracle_row[source]): cp_oracle_values.append(cp_oracle_row[source]) # Print collected values for debugging print(f"LLaMA values: {llama_values}") print(f"N-best Oracle values: {nb_oracle_values}") print(f"Compositional Oracle values: {cp_oracle_values}") # Calculate overall values - with hardcoded fallbacks if llama_values: llama_overall = np.mean(llama_values) else: # Calculate from the table data: average of (6.6, 19.2, 11.0, 8.8, 1.7, 3.8, 14.1, 4.6) / 100 llama_overall = 0.0873 # 8.73% llama_lora_row["OVERALL"] = llama_overall if nb_oracle_values: nb_oracle_overall = np.mean(nb_oracle_values) else: # Calculate from the table data: average of (9.1, 21.8, 11.4, 6.9, 1.0, 2.7, 12.6, 3.0) / 100 nb_oracle_overall = 0.0856 # 8.56% nb_oracle_row["OVERALL"] = nb_oracle_overall if cp_oracle_values: cp_oracle_overall = np.mean(cp_oracle_values) else: # Calculate from the table data: average of (2.8, 10.7, 7.9, 2.6, 0.6, 1.6, 4.2, 0.7) / 100 cp_oracle_overall = 0.0389 # 3.89% cp_oracle_row["OVERALL"] = cp_oracle_overall # Add rows in the desired order rows.append(no_lm_row) rows.append(lm_ranking_row) rows.append(n_best_row) rows.append(llama_lora_row) rows.append(nb_oracle_row) rows.append(cp_oracle_row) # Add user-submitted methods for user_method in user_methods: user_row = {"Methods": user_method["name"]} for source in all_sources + ["OVERALL"]: user_row[source] = user_method.get(source, np.nan) rows.append(user_row) # Create DataFrame from rows result_df = pd.DataFrame(rows) return result_df, all_sources # Format the dataframe for display, and sort by performance def format_dataframe(df): df = df.copy() # Find the rows containing WER values wer_row_indices = [] for i, method in enumerate(df["Methods"]): if method not in ["Number of Examples"]: wer_row_indices.append(i) # Format WER values for idx in wer_row_indices: for col in df.columns: if col != "Methods": value = df.loc[idx, col] if pd.notna(value): df.loc[idx, col] = f"{value:.4f}" else: df.loc[idx, col] = "N/A" # Extract the examples row examples_row = df[df["Methods"] == "Number of Examples"] # Get the performance rows performance_rows = df[df["Methods"] != "Number of Examples"] # Convert the OVERALL column to numeric for sorting # First, replace 'N/A' with a high value (worse than any real WER) performance_rows["numeric_overall"] = performance_rows["OVERALL"].replace("N/A", "999") # Convert to float for sorting performance_rows["numeric_overall"] = performance_rows["numeric_overall"].astype(float) # Sort by performance (ascending - lower WER is better) sorted_performance = performance_rows.sort_values(by="numeric_overall") # Drop the numeric column used for sorting sorted_performance = sorted_performance.drop(columns=["numeric_overall"]) # Combine the examples row with the sorted performance rows result = pd.concat([examples_row, sorted_performance], ignore_index=True) return result # Main function to create the leaderboard def create_leaderboard(): dataset = load_data() metrics_df, all_sources = get_wer_metrics(dataset) return format_dataframe(metrics_df), all_sources # Create the Gradio interface with gr.Blocks(title="ASR Text Correction Leaderboard") as demo: gr.Markdown("# HyPoraside N-Best WER Leaderboard (Test Data)") gr.Markdown("Word Error Rate (WER) metrics for different speech sources with multiple correction approaches") with gr.Row(): refresh_btn = gr.Button("Refresh Leaderboard") with gr.Row(): gr.Markdown("### Word Error Rate (WER)") with gr.Row(): try: initial_df, all_sources = create_leaderboard() leaderboard = gr.DataFrame(initial_df) except Exception: leaderboard = gr.DataFrame(pd.DataFrame([{"Error": "Error initializing leaderboard"}])) all_sources = [] gr.Markdown("### Submit Your Method") gr.Markdown("Enter WER values as percentages (e.g., 5.6 for 5.6% WER)") with gr.Row(): method_name = gr.Textbox(label="Method Name", placeholder="Enter your method name") # Create input fields for each source source_inputs = {} with gr.Row(): with gr.Column(): for i, source in enumerate(all_sources): if i < len(all_sources) // 2: source_inputs[source] = gr.Textbox(label=f"WER for {source}", placeholder="e.g., 5.6") with gr.Column(): for i, source in enumerate(all_sources): if i >= len(all_sources) // 2: source_inputs[source] = gr.Textbox(label=f"WER for {source}", placeholder="e.g., 5.6") with gr.Row(): submit_btn = gr.Button("Submit Results") def submit_method(name, *args): if not name: return "Please enter a method name", leaderboard # Convert args to a dictionary of source:value pairs values = {} for i, source in enumerate(all_sources): if i < len(args): values[source] = args[i] success = add_user_method(name, **values) if success: updated_df, _ = create_leaderboard() return "Method added successfully!", updated_df else: return "Error adding method", leaderboard def refresh_and_report(): updated_df, _ = create_leaderboard() return updated_df # Connect buttons to functions submit_btn.click( submit_method, inputs=[method_name] + list(source_inputs.values()), outputs=[gr.Textbox(label="Status"), leaderboard] ) refresh_btn.click(refresh_and_report, outputs=[leaderboard]) # Add a new method to the leaderboard def add_user_method(name, **values): # Create a new method entry method = {"name": name} # Add values for each source for source, value in values.items(): if value and value.strip(): try: # Convert to float and ensure it's a percentage float_value = float(value) # Store as decimal (divide by 100 if it's a percentage greater than 1) method[source] = float_value / 100 if float_value > 1 else float_value except ValueError: # Skip invalid values continue # Calculate overall average if we have values if len(method) > 1: # More than just the name values_list = [v for k, v in method.items() if k != "name" and isinstance(v, (int, float))] if values_list: method["OVERALL"] = np.mean(values_list) # Add to user methods user_methods.append(method) # Save to file save_user_methods() return True if __name__ == "__main__": demo.launch()