import gradio as gr
import pandas as pd
from datasets import load_dataset
import numpy as np
from functools import lru_cache
import re
from collections import Counter
import editdistance
import json
import os
# Cache the dataset loading to avoid reloading on refresh
@lru_cache(maxsize=1)
def load_data():
    try:
        dataset = load_dataset("GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction", split="test")
        return dataset
    except Exception:
        # Fallback to explicit file path if default loading fails
        return load_dataset(
            "parquet",
            data_files="https://huggingface.co/datasets/GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction/resolve/main/data/test-00000-of-00001.parquet",
            split="train",
        )
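# Each example is expected to provide the fields used below:
#   "source"        - subset name (e.g. "test_chime4")
#   "transcription" - reference transcript
#   "hypothesis"    - list of n-best ASR hypotheses (or a single string)
#   "input1"        - 1-best ASR hypothesis (optional; falls back to hypothesis[0])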
# Storage for user-submitted methods (in-memory for demo purposes)
user_methods = []
# Data file for persistence
USER_DATA_FILE = "user_methods.json"
# Load user methods from file if exists
def load_user_methods():
    global user_methods
    if os.path.exists(USER_DATA_FILE):
        try:
            with open(USER_DATA_FILE, 'r') as f:
                user_methods = json.load(f)
        except Exception as e:
            print(f"Error loading user methods: {e}")
            user_methods = []
# Save user methods to file
def save_user_methods():
    try:
        with open(USER_DATA_FILE, 'w') as f:
            json.dump(user_methods, f)
    except Exception as e:
        print(f"Error saving user methods: {e}")
# Try to load user methods at startup
try:
    load_user_methods()
except Exception:
    pass
# Preprocess text for better WER calculation
def preprocess_text(text):
    if not text or not isinstance(text, str):
        return ""
    text = text.lower()
    text = re.sub(r'[^\w\s]', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text
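# Example: preprocess_text("Hello, World!") -> "hello world"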
# N-gram scoring for hypothesis ranking
def score_hypothesis(hypothesis, n=4):
    if not hypothesis:
        return 0
    words = hypothesis.split()
    if len(words) < n:
        return len(words)
    ngrams = []
    for i in range(len(words) - n + 1):
        ngram = ' '.join(words[i:i+n])
        ngrams.append(ngram)
    unique_ngrams = len(set(ngrams))
    total_ngrams = len(ngrams)
    score = len(words) + unique_ngrams / max(1, total_ngrams) * 5
    return score
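# Example: score_hypothesis("the cat sat on the mat") -> 6 + (3/3) * 5 = 11.0
# (6 words, 3 unique 4-grams out of 3 total)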
# N-gram ranking approach
def get_best_hypothesis_lm(hypotheses):
    if not hypotheses:
        return ""
    if isinstance(hypotheses, str):
        return hypotheses
    hypothesis_list = [preprocess_text(h) for h in hypotheses if isinstance(h, str)]
    if not hypothesis_list:
        return ""
    scores = [(score_hypothesis(h), h) for h in hypothesis_list]
    best_hypothesis = max(scores, key=lambda x: x[0])[1]
    return best_hypothesis
# Subwords voting correction approach
def correct_hypotheses(hypotheses):
    if not hypotheses:
        return ""
    if isinstance(hypotheses, str):
        return hypotheses
    hypothesis_list = [preprocess_text(h) for h in hypotheses if isinstance(h, str)]
    if not hypothesis_list:
        return ""
    word_lists = [h.split() for h in hypothesis_list]
    lengths = [len(words) for words in word_lists]
    if not lengths:
        return ""
    most_common_length = Counter(lengths).most_common(1)[0][0]
    filtered_word_lists = [words for words in word_lists if len(words) == most_common_length]
    if not filtered_word_lists:
        return max(hypothesis_list, key=len)
    corrected_words = []
    for i in range(most_common_length):
        position_words = [words[i] for words in filtered_word_lists]
        most_common_word = Counter(position_words).most_common(1)[0][0]
        corrected_words.append(most_common_word)
    return ' '.join(corrected_words)
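# Example: correct_hypotheses(["the cat sat", "the cat sit", "the bat sat"]) -> "the cat sat"
# (all hypotheses have 3 words, so every position is decided by majority vote)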
# Calculate WER
def calculate_simple_wer(reference, hypothesis):
    if not reference or not hypothesis:
        return 1.0
    ref_words = reference.split()
    hyp_words = hypothesis.split()
    distance = editdistance.eval(ref_words, hyp_words)
    if len(ref_words) == 0:
        return 1.0
    return float(distance) / float(len(ref_words))
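# Example: calculate_simple_wer("the cat sat on the mat", "the cat sit on mat")
#   -> 2 word-level edits (1 substitution + 1 deletion) / 6 reference words = 0.3333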
# Calculate WER for a group of examples with multiple methods
def calculate_wer_methods(examples, max_samples=200):
    if not examples or len(examples) == 0:
        return np.nan, np.nan, np.nan
    # Limit sample size for efficiency
    if hasattr(examples, 'select'):
        items_to_process = examples.select(range(min(max_samples, len(examples))))
    else:
        items_to_process = examples[:max_samples]
    wer_values_no_lm = []
    wer_values_lm_ranking = []
    wer_values_n_best_correction = []
    for ex in items_to_process:
        # Get reference transcription
        transcription = ex.get("transcription")
        if not transcription or not isinstance(transcription, str):
            continue
        reference = preprocess_text(transcription)
        if not reference:
            continue
        # Get 1-best hypothesis for baseline
        input1 = ex.get("input1")
        if input1 is None and "hypothesis" in ex and ex["hypothesis"]:
            if isinstance(ex["hypothesis"], list) and len(ex["hypothesis"]) > 0:
                input1 = ex["hypothesis"][0]
            elif isinstance(ex["hypothesis"], str):
                input1 = ex["hypothesis"]
        # Get n-best hypotheses for other methods
        n_best_hypotheses = ex.get("hypothesis", [])
        # Method 1: No LM (1-best ASR output)
        if input1 and isinstance(input1, str):
            no_lm_hyp = preprocess_text(input1)
            if no_lm_hyp:
                wer_no_lm = calculate_simple_wer(reference, no_lm_hyp)
                wer_values_no_lm.append(wer_no_lm)
        # Method 2: N-gram ranking
        if n_best_hypotheses:
            lm_best_hyp = get_best_hypothesis_lm(n_best_hypotheses)
            if lm_best_hyp:
                wer_lm = calculate_simple_wer(reference, lm_best_hyp)
                wer_values_lm_ranking.append(wer_lm)
        # Method 3: Subwords voting correction
        if n_best_hypotheses:
            corrected_hyp = correct_hypotheses(n_best_hypotheses)
            if corrected_hyp:
                wer_corrected = calculate_simple_wer(reference, corrected_hyp)
                wer_values_n_best_correction.append(wer_corrected)
    # Calculate average WER for each method
    no_lm_wer = np.mean(wer_values_no_lm) if wer_values_no_lm else np.nan
    lm_ranking_wer = np.mean(wer_values_lm_ranking) if wer_values_lm_ranking else np.nan
    n_best_correction_wer = np.mean(wer_values_n_best_correction) if wer_values_n_best_correction else np.nan
    return no_lm_wer, lm_ranking_wer, n_best_correction_wer
# Get WER metrics by source
def get_wer_metrics(dataset):
    # Group examples by source
    examples_by_source = {}
    for ex in dataset:
        source = ex.get("source", "unknown")
        # Skip all_et05_real as requested
        if source == "all_et05_real":
            continue
        if source not in examples_by_source:
            examples_by_source[source] = []
        examples_by_source[source].append(ex)
    # Get all unique sources
    all_sources = sorted(examples_by_source.keys())
    # Calculate metrics for each source
    source_results = {}
    for source in all_sources:
        examples = examples_by_source.get(source, [])
        count = len(examples)
        if count > 0:
            no_lm_wer, lm_ranking_wer, n_best_wer = calculate_wer_methods(examples)
        else:
            no_lm_wer, lm_ranking_wer, n_best_wer = np.nan, np.nan, np.nan
        source_results[source] = {
            "Count": count,
            "No LM Baseline": no_lm_wer,
            "N-best LM Ranking": lm_ranking_wer,
            "N-best Correction": n_best_wer
        }
    # Calculate overall metrics
    filtered_dataset = [ex for ex in dataset if ex.get("source") != "all_et05_real"]
    total_count = len(filtered_dataset)
    sample_size = min(500, total_count)
    sample_dataset = filtered_dataset[:sample_size]
    no_lm_wer, lm_ranking_wer, n_best_wer = calculate_wer_methods(sample_dataset)
    source_results["OVERALL"] = {
        "Count": total_count,
        "No LM Baseline": no_lm_wer,
        "N-best LM Ranking": lm_ranking_wer,
        "N-best Correction": n_best_wer
    }
    # Create flat DataFrame with labels in the first column
    rows = []
    # First add row for number of examples
    example_row = {"Methods": "Number of Examples"}
    for source in all_sources + ["OVERALL"]:
        example_row[source] = source_results[source]["Count"]
    rows.append(example_row)
    # Then add rows for each WER method with simplified names
    no_lm_row = {"Methods": "No LM"}
    lm_ranking_row = {"Methods": "N-gram Ranking"}
    n_best_row = {"Methods": "Subwords Voting"}
    # Add the additional methods from the figure
    llama_lora_row = {"Methods": "LLaMA-7B-LoRA"}
    nb_oracle_row = {"Methods": "N-best Oracle (o_nb)"}
    cp_oracle_row = {"Methods": "Compositional Oracle (o_cp)"}
    # Populate the existing methods
    for source in all_sources + ["OVERALL"]:
        no_lm_row[source] = source_results[source]["No LM Baseline"]
        lm_ranking_row[source] = source_results[source]["N-best LM Ranking"]
        n_best_row[source] = source_results[source]["N-best Correction"]
        # The additional methods default to NaN for sources not in the figure
        llama_lora_row[source] = np.nan
        nb_oracle_row[source] = np.nan
        cp_oracle_row[source] = np.nan
    # Add hardcoded values from the figure for each source
    # CHiME-4
    if "test_chime4" in all_sources:
        llama_lora_row["test_chime4"] = 6.6 / 100  # Convert from percentage
        nb_oracle_row["test_chime4"] = 9.1 / 100
        cp_oracle_row["test_chime4"] = 2.8 / 100
    # Tedlium-3
    if "test_td3" in all_sources:
        llama_lora_row["test_td3"] = 4.6 / 100
        nb_oracle_row["test_td3"] = 3.0 / 100
        cp_oracle_row["test_td3"] = 0.7 / 100
    # CommonVoice (CV-accent)
    if "test_cv" in all_sources:
        llama_lora_row["test_cv"] = 11.0 / 100
        nb_oracle_row["test_cv"] = 11.4 / 100
        cp_oracle_row["test_cv"] = 7.9 / 100
    # SwitchBoard
    if "test_swbd" in all_sources:
        llama_lora_row["test_swbd"] = 14.1 / 100
        nb_oracle_row["test_swbd"] = 12.6 / 100
        cp_oracle_row["test_swbd"] = 4.2 / 100
    # LRS2
    if "test_lrs2" in all_sources:
        llama_lora_row["test_lrs2"] = 8.8 / 100
        nb_oracle_row["test_lrs2"] = 6.9 / 100
        cp_oracle_row["test_lrs2"] = 2.6 / 100
    # CORAAL
    if "test_coraal" in all_sources:
        llama_lora_row["test_coraal"] = 19.2 / 100
        nb_oracle_row["test_coraal"] = 21.8 / 100
        cp_oracle_row["test_coraal"] = 10.7 / 100
    # LibriSpeech Clean
    if "test_ls_clean" in all_sources:
        llama_lora_row["test_ls_clean"] = 1.7 / 100
        nb_oracle_row["test_ls_clean"] = 1.0 / 100
        cp_oracle_row["test_ls_clean"] = 0.6 / 100
    # LibriSpeech Other
    if "test_ls_other" in all_sources:
        llama_lora_row["test_ls_other"] = 3.8 / 100
        nb_oracle_row["test_ls_other"] = 2.7 / 100
        cp_oracle_row["test_ls_other"] = 1.6 / 100
    # Calculate overall averages for the three additional methods
    llama_values = []
    nb_oracle_values = []
    cp_oracle_values = []
    for source in all_sources:
        if pd.notna(llama_lora_row[source]):
            llama_values.append(llama_lora_row[source])
        if pd.notna(nb_oracle_row[source]):
            nb_oracle_values.append(nb_oracle_row[source])
        if pd.notna(cp_oracle_row[source]):
            cp_oracle_values.append(cp_oracle_row[source])
    # Print collected values for debugging
    print(f"LLaMA values: {llama_values}")
    print(f"N-best Oracle values: {nb_oracle_values}")
    print(f"Compositional Oracle values: {cp_oracle_values}")
    # Calculate overall values - with hardcoded fallbacks
    if llama_values:
        llama_overall = np.mean(llama_values)
    else:
        # Calculate from the table data: average of (6.6, 19.2, 11.0, 8.8, 1.7, 3.8, 14.1, 4.6) / 100
        llama_overall = 0.0873  # 8.73%
    llama_lora_row["OVERALL"] = llama_overall
    if nb_oracle_values:
        nb_oracle_overall = np.mean(nb_oracle_values)
    else:
        # Calculate from the table data: average of (9.1, 21.8, 11.4, 6.9, 1.0, 2.7, 12.6, 3.0) / 100
        nb_oracle_overall = 0.0856  # 8.56%
    nb_oracle_row["OVERALL"] = nb_oracle_overall
    if cp_oracle_values:
        cp_oracle_overall = np.mean(cp_oracle_values)
    else:
        # Calculate from the table data: average of (2.8, 10.7, 7.9, 2.6, 0.6, 1.6, 4.2, 0.7) / 100
        cp_oracle_overall = 0.0389  # 3.89%
    cp_oracle_row["OVERALL"] = cp_oracle_overall
    # Add rows in the desired order
    rows.append(no_lm_row)
    rows.append(lm_ranking_row)
    rows.append(n_best_row)
    rows.append(llama_lora_row)
    rows.append(nb_oracle_row)
    rows.append(cp_oracle_row)
    # Add user-submitted methods
    for user_method in user_methods:
        user_row = {"Methods": user_method["name"]}
        for source in all_sources + ["OVERALL"]:
            user_row[source] = user_method.get(source, np.nan)
        rows.append(user_row)
    # Create DataFrame from rows
    result_df = pd.DataFrame(rows)
    return result_df, all_sources
# Format the dataframe for display, and sort by performance
def format_dataframe(df):
    df = df.copy()
    # Find the rows containing WER values
    wer_row_indices = []
    for i, method in enumerate(df["Methods"]):
        if method not in ["Number of Examples"]:
            wer_row_indices.append(i)
    # Format WER values
    for idx in wer_row_indices:
        for col in df.columns:
            if col != "Methods":
                value = df.loc[idx, col]
                if pd.notna(value):
                    df.loc[idx, col] = f"{value:.4f}"
                else:
                    df.loc[idx, col] = "N/A"
    # Extract the examples row
    examples_row = df[df["Methods"] == "Number of Examples"]
    # Get the performance rows (copy to avoid writing to a slice of df)
    performance_rows = df[df["Methods"] != "Number of Examples"].copy()
    # Convert the OVERALL column to numeric for sorting
    # First, replace 'N/A' with a high value (worse than any real WER)
    performance_rows["numeric_overall"] = performance_rows["OVERALL"].replace("N/A", "999")
    # Convert to float for sorting
    performance_rows["numeric_overall"] = performance_rows["numeric_overall"].astype(float)
    # Sort by performance (ascending - lower WER is better)
    sorted_performance = performance_rows.sort_values(by="numeric_overall")
    # Drop the numeric column used for sorting
    sorted_performance = sorted_performance.drop(columns=["numeric_overall"])
    # Combine the examples row with the sorted performance rows
    result = pd.concat([examples_row, sorted_performance], ignore_index=True)
    return result
# Main function to create the leaderboard
def create_leaderboard():
    dataset = load_data()
    metrics_df, all_sources = get_wer_metrics(dataset)
    return format_dataframe(metrics_df), all_sources
# Create the Gradio interface
with gr.Blocks(title="ASR Text Correction Leaderboard") as demo:
gr.Markdown("# HyPoraside N-Best WER Leaderboard (Test Data)")
gr.Markdown("Word Error Rate (WER) metrics for different speech sources with multiple correction approaches")
with gr.Row():
refresh_btn = gr.Button("Refresh Leaderboard")
with gr.Row():
gr.Markdown("### Word Error Rate (WER)")
with gr.Row():
try:
initial_df, all_sources = create_leaderboard()
leaderboard = gr.DataFrame(initial_df)
except Exception:
leaderboard = gr.DataFrame(pd.DataFrame([{"Error": "Error initializing leaderboard"}]))
all_sources = []
gr.Markdown("### Submit Your Method")
gr.Markdown("Enter WER values as percentages (e.g., 5.6 for 5.6% WER)")
with gr.Row():
method_name = gr.Textbox(label="Method Name", placeholder="Enter your method name")
# Create input fields for each source
source_inputs = {}
with gr.Row():
with gr.Column():
for i, source in enumerate(all_sources):
if i < len(all_sources) // 2:
source_inputs[source] = gr.Textbox(label=f"WER for {source}", placeholder="e.g., 5.6")
with gr.Column():
for i, source in enumerate(all_sources):
if i >= len(all_sources) // 2:
source_inputs[source] = gr.Textbox(label=f"WER for {source}", placeholder="e.g., 5.6")
with gr.Row():
submit_btn = gr.Button("Submit Results")
def submit_method(name, *args):
if not name:
return "Please enter a method name", leaderboard
# Convert args to a dictionary of source:value pairs
values = {}
for i, source in enumerate(all_sources):
if i < len(args):
values[source] = args[i]
success = add_user_method(name, **values)
if success:
updated_df, _ = create_leaderboard()
return "Method added successfully!", updated_df
else:
return "Error adding method", leaderboard
def refresh_and_report():
updated_df, _ = create_leaderboard()
return updated_df
# Connect buttons to functions
submit_btn.click(
submit_method,
inputs=[method_name] + list(source_inputs.values()),
outputs=[gr.Textbox(label="Status"), leaderboard]
)
refresh_btn.click(refresh_and_report, outputs=[leaderboard])
# Add a new method to the leaderboard
def add_user_method(name, **values):
    # Create a new method entry
    method = {"name": name}
    # Add values for each source
    for source, value in values.items():
        if value and value.strip():
            try:
                # Convert to float and ensure it's a percentage
                float_value = float(value)
                # Store as decimal (divide by 100 if it's a percentage greater than 1)
                method[source] = float_value / 100 if float_value > 1 else float_value
            except ValueError:
                # Skip invalid values
                continue
    # Calculate overall average if we have values
    if len(method) > 1:  # More than just the name
        values_list = [v for k, v in method.items() if k != "name" and isinstance(v, (int, float))]
        if values_list:
            method["OVERALL"] = np.mean(values_list)
    # Add to user methods
    user_methods.append(method)
    # Save to file
    save_user_methods()
    return True
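# Example (values arrive as strings from the Gradio textboxes):
#   add_user_method("My Method", test_chime4="5.6")
#   stores {"name": "My Method", "test_chime4": 0.056, "OVERALL": 0.056}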
if __name__ == "__main__":
    demo.launch()