huckiyang's picture
user
bd783af
import gradio as gr
import pandas as pd
from datasets import load_dataset
import numpy as np
from functools import lru_cache
import re
from collections import Counter
import editdistance
import json
import os
# Cache the dataset loading to avoid reloading on refresh
@lru_cache(maxsize=1)
def load_data():
    """Load the test data of the GenSEC Task 1 (post-ASR text correction) dataset.

    First tries the hub loader for the ``test`` split. If that raises, it loads
    the test parquet file directly from its URL. The result is cached
    (``maxsize=1``) so leaderboard refreshes do not reload it.

    Returns:
        The dataset split selected by ``load_dataset(..., split=...)``. Both
        paths select a single split, so iteration yields example dicts.
    """
    try:
        return load_dataset("GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction", split="test")
    except Exception as e:
        print(f"Default dataset loading failed ({e}); falling back to the parquet file")
        # Without split=..., load_dataset returns a DatasetDict keyed by split
        # name, and iterating it yields split names rather than examples. Loose
        # data_files are placed in the "train" split, so select that split
        # explicitly to get the same kind of object as the primary path.
        return load_dataset(
            "parquet",
            data_files="https://huggingface.co/datasets/GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction/resolve/main/data/test-00000-of-00001.parquet",
            split="train",
        )
# JSON file (path relative to the working directory) that holds user
# submissions between restarts; read by load_user_methods, written by
# save_user_methods.
USER_DATA_FILE = "user_methods.json"
# User-submitted methods currently shown on the leaderboard. Each entry is a
# dict with a "name" key plus per-source WER values.
user_methods = []
# Load user methods from file if exists
def load_user_methods():
    """Replace the module-level ``user_methods`` with the contents of USER_DATA_FILE.

    Behavior by case:

    - The file does not exist: ``user_methods`` is left unchanged.
    - The file cannot be read or parsed, or does not contain a JSON list: an
      error is printed and ``user_methods`` is reset to ``[]``.
    - Otherwise: only the dict entries that have a ``"name"`` key are kept,
      because the leaderboard labels each user row by that name.
    """
    global user_methods
    if not os.path.exists(USER_DATA_FILE):
        return
    try:
        with open(USER_DATA_FILE, "r", encoding="utf-8") as f:
            loaded = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSON and Unicode decode errors
        print(f"Error loading user methods: {e}")
        user_methods = []
        return
    if not isinstance(loaded, list):
        print(f"Error loading user methods: expected a JSON list, got {type(loaded).__name__}")
        user_methods = []
        return
    valid = [m for m in loaded if isinstance(m, dict) and "name" in m]
    if len(valid) != len(loaded):
        print(f"Ignoring {len(loaded) - len(valid)} malformed user method entries")
    user_methods = valid
# Save user methods to file
def save_user_methods():
    """Persist the module-level ``user_methods`` list to USER_DATA_FILE as JSON.

    The data is first written to a temporary sibling file, which is then moved
    over the target with ``os.replace``. A crash in the middle of writing
    therefore cannot leave a truncated JSON file behind.

    This is best effort: errors are printed, not raised.
    """
    tmp_path = USER_DATA_FILE + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(user_methods, f)
        os.replace(tmp_path, USER_DATA_FILE)
    except Exception as e:
        print(f"Error saving user methods: {e}")
# Try to load user methods at startup. This is best effort: the leaderboard
# still works without previous submissions, but the failure is reported
# instead of being silently ignored. Catching Exception (not a bare except)
# lets KeyboardInterrupt / SystemExit propagate.
try:
    load_user_methods()
except Exception as e:
    print(f"Error loading user methods at startup: {e}")
# Preprocess text for better WER calculation
def preprocess_text(text):
    """Normalize ``text`` before WER scoring.

    The string is lower-cased. Every character that is neither a word
    character nor whitespace is deleted. Runs of whitespace are collapsed to a
    single space, and the ends are trimmed.

    Anything that is not a non-empty string yields "".
    """
    if isinstance(text, str) and text:
        without_punct = re.sub(r'[^\w\s]', '', text.lower())
        return re.sub(r'\s+', ' ', without_punct).strip()
    return ""
# N-gram scoring for hypothesis ranking
def score_hypothesis(hypothesis, n=4):
    """Heuristic score used to rank N-best hypotheses (higher is better).

    Returns:
        - 0 for an empty or falsy hypothesis.
        - The word count when there are fewer than ``n`` words.
        - Otherwise, the word count plus 5 times the fraction of distinct word
          ``n``-grams.
    """
    if not hypothesis:
        return 0
    tokens = hypothesis.split()
    word_count = len(tokens)
    if word_count < n:
        return word_count
    windows = [' '.join(tokens[start:start + n]) for start in range(word_count - n + 1)]
    distinct_ratio = len(set(windows)) / max(1, len(windows))
    return word_count + distinct_ratio * 5
# N-gram ranking approach
def get_best_hypothesis_lm(hypotheses):
    """Pick the highest-scoring hypothesis from an N-best list.

    Args:
        hypotheses: A list of candidate strings (non-strings are ignored) or a
            single string.

    Returns:
        The winning hypothesis after ``preprocess_text`` normalization, or ""
        when there is nothing to choose from. A single string is normalized the
        same way as list entries, so callers always receive text comparable to
        a preprocessed reference.
    """
    if not hypotheses:
        return ""
    if isinstance(hypotheses, str):
        return preprocess_text(hypotheses)
    candidates = [preprocess_text(h) for h in hypotheses if isinstance(h, str)]
    if not candidates:
        return ""
    # max() returns the first of equally scored candidates, i.e. the earliest
    # one in the input list.
    return max(candidates, key=score_hypothesis)
# Subwords voting correction approach
def correct_hypotheses(hypotheses):
    """Combine N-best hypotheses by position-wise majority voting.

    How the vote works:

    1. Each string hypothesis is normalized with ``preprocess_text``.
    2. Only the hypotheses whose word count is the most common one are kept.
       Ties go to the length seen first.
    3. For each word position, the most frequent word is chosen. Ties go to
       the word seen first.

    Args:
        hypotheses: A list of candidate strings (non-strings are ignored) or a
            single string.

    Returns:
        The voted hypothesis as a space-joined string, or "" when there is
        nothing to vote on. A single string is returned normalized, consistent
        with the list path.
    """
    if not hypotheses:
        return ""
    if isinstance(hypotheses, str):
        return preprocess_text(hypotheses)
    candidates = [preprocess_text(h) for h in hypotheses if isinstance(h, str)]
    if not candidates:
        return ""
    token_lists = [c.split() for c in candidates]
    target_len = Counter(len(tokens) for tokens in token_lists).most_common(1)[0][0]
    # Non-empty by construction: target_len is the length of at least one list.
    aligned = [tokens for tokens in token_lists if len(tokens) == target_len]
    voted = [Counter(column).most_common(1)[0][0] for column in zip(*aligned)]
    return ' '.join(voted)
# Calculate WER
def calculate_simple_wer(reference, hypothesis):
    """Return the word error rate of ``hypothesis`` against ``reference``.

    WER is the word-level edit distance divided by the number of reference
    words. It returns 1.0 when either text is empty/falsy, or when the
    reference contains no words.
    """
    if not (reference and hypothesis):
        return 1.0
    reference_tokens = reference.split()
    if not reference_tokens:
        return 1.0
    errors = editdistance.eval(reference_tokens, hypothesis.split())
    return float(errors) / float(len(reference_tokens))
# Calculate WER for a group of examples with multiple methods
def calculate_wer_methods(examples, max_samples=200):
    """Score the three correction approaches on a prefix of ``examples``.

    Args:
        examples: A list of example dicts, or an object with a ``select``
            method (e.g. a datasets.Dataset). Only the first ``max_samples``
            entries are scored.
        max_samples: Maximum number of examples to evaluate.

    Returns:
        Tuple ``(no_lm_wer, lm_ranking_wer, n_best_correction_wer)``, each the
        mean per-example WER of one approach:

        - the 1-best hypothesis;
        - the n-gram ranked hypothesis;
        - the position-wise voted hypothesis.

        Each value is NaN when no example produced a usable hypothesis for that
        approach.

    Examples are skipped when their "transcription" is not a non-empty string,
    or when it normalizes to "". Note that the result is a mean of
    per-utterance WERs, not a corpus-level WER (total errors / total
    reference words).
    """
    def _mean_or_nan(values):
        return np.mean(values) if values else np.nan

    if not (examples and len(examples)):
        return np.nan, np.nan, np.nan

    subset = (
        examples.select(range(min(max_samples, len(examples))))
        if hasattr(examples, 'select')
        else examples[:max_samples]
    )

    no_lm_wers, ranking_wers, voting_wers = [], [], []
    for ex in subset:
        raw_reference = ex.get("transcription")
        if not isinstance(raw_reference, str) or not raw_reference:
            continue
        reference = preprocess_text(raw_reference)
        if not reference:
            continue

        n_best = ex.get("hypothesis", [])
        # The 1-best comes from "input1"; when that key is absent, fall back to
        # the first N-best entry (or the hypothesis string itself).
        one_best = ex.get("input1")
        if one_best is None and n_best:
            if isinstance(n_best, list):
                one_best = n_best[0]
            elif isinstance(n_best, str):
                one_best = n_best

        if isinstance(one_best, str) and one_best:
            baseline = preprocess_text(one_best)
            if baseline:
                no_lm_wers.append(calculate_simple_wer(reference, baseline))

        if n_best:
            ranked = get_best_hypothesis_lm(n_best)
            if ranked:
                ranking_wers.append(calculate_simple_wer(reference, ranked))
            voted = correct_hypotheses(n_best)
            if voted:
                voting_wers.append(calculate_simple_wer(reference, voted))

    return tuple(_mean_or_nan(scores) for scores in (no_lm_wers, ranking_wers, voting_wers))
# Sources excluded from the leaderboard entirely.
_EXCLUDED_SOURCES = ("all_et05_real",)

# Row labels for the methods recomputed here, in the order of the values
# returned by calculate_wer_methods.
_COMPUTED_METHOD_LABELS = ("No LM", "N-gram Ranking", "Subwords Voting")

# WER values (in percent) taken from the reference figure for methods that are
# not recomputed here, keyed by dataset source name.
_PAPER_WER_PERCENT = {
    "LLaMA-7B-LoRA": {
        "test_chime4": 6.6,
        "test_td3": 4.6,
        "test_cv": 11.0,
        "test_swbd": 14.1,
        "test_lrs2": 8.8,
        "test_coraal": 19.2,
        "test_ls_clean": 1.7,
        "test_ls_other": 3.8,
    },
    "N-best Oracle (o_nb)": {
        "test_chime4": 9.1,
        "test_td3": 3.0,
        "test_cv": 11.4,
        "test_swbd": 12.6,
        "test_lrs2": 6.9,
        "test_coraal": 21.8,
        "test_ls_clean": 1.0,
        "test_ls_other": 2.7,
    },
    "Compositional Oracle (o_cp)": {
        "test_chime4": 2.8,
        "test_td3": 0.7,
        "test_cv": 7.9,
        "test_swbd": 4.2,
        "test_lrs2": 2.6,
        "test_coraal": 10.7,
        "test_ls_clean": 0.6,
        "test_ls_other": 1.6,
    },
}


def _nan_mean(values):
    """Return the mean of the non-NaN entries of ``values``, or NaN if there are none."""
    present = [v for v in values if pd.notna(v)]
    return np.mean(present) if present else np.nan


# Get WER metrics by source
def get_wer_metrics(dataset):
    """Build the raw (unformatted) leaderboard table.

    Args:
        dataset: An iterable of example dicts. Examples are grouped by their
            "source" key ("unknown" when the key is missing). Sources listed in
            _EXCLUDED_SOURCES are dropped.

    Returns:
        ``(DataFrame, all_sources)``, where ``all_sources`` is the sorted list
        of source names.

        The DataFrame has a "Methods" column, one column per source and an
        "OVERALL" column. Its rows, in order:

        1. Example counts per source.
        2. The three recomputed methods.
        3. The reference-figure methods.
        4. One row per entry of ``user_methods``.

        WER cells are fractions (0.05 == 5%) or NaN.

    Every method's OVERALL is the unweighted mean of its available per-source
    values, so all rows are ranked on the same basis.
    """
    examples_by_source = {}
    for ex in dataset:
        source = ex.get("source", "unknown")
        if source in _EXCLUDED_SOURCES:
            continue
        examples_by_source.setdefault(source, []).append(ex)
    all_sources = sorted(examples_by_source)
    columns = all_sources + ["OVERALL"]

    counts = {source: len(examples_by_source[source]) for source in all_sources}
    counts["OVERALL"] = sum(counts.values())
    rows = [{"Methods": "Number of Examples", **counts}]

    # Methods recomputed from the N-best lists. calculate_wer_methods caps
    # each source at its own sample limit. Averaging over sources avoids an
    # OVERALL that only reflects whichever source comes first in the dataset.
    computed = {label: {} for label in _COMPUTED_METHOD_LABELS}
    for source in all_sources:
        wers = calculate_wer_methods(examples_by_source[source])
        for label, wer in zip(_COMPUTED_METHOD_LABELS, wers):
            computed[label][source] = wer
    for label, per_source in computed.items():
        row = {"Methods": label, **per_source}
        row["OVERALL"] = _nan_mean(per_source.values())
        rows.append(row)

    # Methods whose numbers come from the reference figure.
    for label, paper in _PAPER_WER_PERCENT.items():
        row = {"Methods": label}
        for source in all_sources:
            row[source] = paper[source] / 100 if source in paper else np.nan
        shown = [row[source] for source in all_sources if source in paper]
        if shown:
            row["OVERALL"] = np.mean(shown)
        else:
            # None of the figure's sources is in the dataset: fall back to the
            # figure's own average over all of its sources.
            row["OVERALL"] = np.mean([value / 100 for value in paper.values()])
        rows.append(row)

    # User submissions are stored with per-source values (and their own
    # OVERALL) already expressed as fractions.
    for user_method in user_methods:
        row = {"Methods": user_method.get("name", "Unnamed method")}
        for column in columns:
            row[column] = user_method.get(column, np.nan)
        rows.append(row)

    return pd.DataFrame(rows), all_sources
# Format the dataframe for display, and sort by performance
def format_dataframe(df):
    """Return a display copy of the leaderboard, sorted by OVERALL WER.

    - The "Number of Examples" row stays first, with its counts shown as
      integers.
    - Every other row is a method row. Its WER cells are rendered as 4-decimal
      strings, with "N/A" for missing values.
    - Method rows are ordered by ascending OVERALL WER (lower is better).
      Missing or non-numeric OVERALL values go last, and ties keep their input
      order.

    The input DataFrame is not modified.
    """
    is_examples_row = df["Methods"] == "Number of Examples"
    value_columns = [col for col in df.columns if col != "Methods"]

    examples_rows = df.loc[is_examples_row].copy()
    method_rows = df.loc[~is_examples_row].copy()

    # Sort on the raw numbers before they are turned into strings. Missing
    # values get 999, which is worse than any real WER.
    if "OVERALL" in method_rows.columns:
        sort_key = pd.to_numeric(method_rows["OVERALL"], errors="coerce").fillna(999.0)
        method_rows = method_rows.loc[sort_key.sort_values(kind="stable").index]

    for col in value_columns:
        # Replace whole columns instead of writing strings into float cells.
        # Cell-wise writes trigger pandas' incompatible-dtype assignment
        # warning, and will be an error in future versions.
        method_rows[col] = [f"{v:.4f}" if pd.notna(v) else "N/A" for v in method_rows[col]]
        examples_rows[col] = [int(v) if pd.notna(v) else v for v in examples_rows[col]]

    return pd.concat([examples_rows, method_rows], ignore_index=True)
# Main function to create the leaderboard
def create_leaderboard():
    """Build the display-ready leaderboard table.

    Loads the (cached) dataset, computes the raw metrics table and formats it.

    Returns:
        ``(formatted DataFrame, list of source column names)``.
    """
    metrics_df, sources = get_wer_metrics(load_data())
    formatted = format_dataframe(metrics_df)
    return formatted, sources
# Create the Gradio interface
with gr.Blocks(title="ASR Text Correction Leaderboard") as demo:
    gr.Markdown("# HyPoradise N-Best WER Leaderboard (Test Data)")
    gr.Markdown("Word Error Rate (WER) metrics for different speech sources with multiple correction approaches")
    with gr.Row():
        refresh_btn = gr.Button("Refresh Leaderboard")
    with gr.Row():
        gr.Markdown("### Word Error Rate (WER)")
    with gr.Row():
        try:
            initial_df, all_sources = create_leaderboard()
            leaderboard = gr.DataFrame(initial_df)
        except Exception as e:
            # Keep the app usable (refresh can retry), but record why it failed.
            print(f"Error initializing leaderboard: {e}")
            leaderboard = gr.DataFrame(pd.DataFrame([{"Error": "Error initializing leaderboard"}]))
            all_sources = []
    gr.Markdown("### Submit Your Method")
    gr.Markdown("Enter WER values as percentages (e.g., 5.6 for 5.6% WER)")
    with gr.Row():
        method_name = gr.Textbox(label="Method Name", placeholder="Enter your method name")
    # One input field per source, split across two columns. source_inputs is
    # filled in all_sources order, which submit_method relies on.
    source_inputs = {}
    split_at = len(all_sources) // 2
    with gr.Row():
        with gr.Column():
            for source in all_sources[:split_at]:
                source_inputs[source] = gr.Textbox(label=f"WER for {source}", placeholder="e.g., 5.6")
        with gr.Column():
            for source in all_sources[split_at:]:
                source_inputs[source] = gr.Textbox(label=f"WER for {source}", placeholder="e.g., 5.6")
    with gr.Row():
        submit_btn = gr.Button("Submit Results")
    status_box = gr.Textbox(label="Status")

    def submit_method(name, *args):
        """Handle a submission: returns (status message, leaderboard update).

        ``args`` are the per-source textbox values, in all_sources order. On
        failure the leaderboard is left as it is (``gr.update()``). Returning the
        ``leaderboard`` component itself would reset the table to its startup
        contents.
        """
        if not name or not name.strip():
            return "Please enter a method name", gr.update()
        values = dict(zip(all_sources, args))
        if not add_user_method(name.strip(), **values):
            return "Error adding method", gr.update()
        try:
            updated_df, _ = create_leaderboard()
        except Exception as e:
            print(f"Error refreshing leaderboard: {e}")
            return f"Method added, but refreshing the leaderboard failed: {e}", gr.update()
        return "Method added successfully!", updated_df

    def refresh_and_report():
        """Recompute the leaderboard table (including user submissions)."""
        updated_df, _ = create_leaderboard()
        return updated_df

    # Connect buttons to functions
    submit_btn.click(
        submit_method,
        inputs=[method_name] + list(source_inputs.values()),
        outputs=[status_box, leaderboard],
    )
    refresh_btn.click(refresh_and_report, outputs=[leaderboard])
# Add a new method to the leaderboard
def add_user_method(name, **values):
    """Record a user-submitted method and persist it.

    Args:
        name: Display name for the method (shown in the "Methods" column).
        **values: Mapping of source name to the WER the user typed, as a
            percentage (e.g. "5.6" for 5.6% WER). Blank, unparsable, negative
            or non-finite entries are skipped.

    Returns:
        True once the method has been appended to ``user_methods`` and
        ``save_user_methods`` has been called.
    """
    method = {"name": name}
    for source, raw in values.items():
        if raw is None:
            continue
        text = str(raw).strip()
        if not text:
            continue
        try:
            percent = float(text)
        except ValueError:
            continue
        # NaN/inf would also be written as invalid JSON by save_user_methods.
        if not np.isfinite(percent) or percent < 0:
            continue
        # The form asks for percentages, so every entry is divided by 100,
        # including values <= 1 (0.6 means 0.6% WER, not 60%).
        method[source] = percent / 100
    per_source = [v for k, v in method.items() if k != "name"]
    if per_source:
        # Unweighted mean over the sources the user filled in.
        method["OVERALL"] = np.mean(per_source)
    user_methods.append(method)
    save_user_methods()
    return True
def _main():
    """Start the Gradio server for the leaderboard UI."""
    demo.launch()


if __name__ == "__main__":
    _main()