import re
import traceback
from functools import lru_cache

import gradio as gr
import jiwer
import numpy as np
import pandas as pd
from datasets import load_dataset
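# Dependency note (assumption): gradio, pandas, numpy, datasets, and jiwer are expected to be
# installed (e.g. via the Space's requirements file); `editdistance` is optional and is only
# used as a faster path in calculate_simple_wer below.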
@lru_cache(maxsize=1)
def load_data():
    """Load the GenSEC SLT Task 1 test split, cached for the lifetime of the process."""
    try:
        dataset = load_dataset("GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction", split="test")
        return dataset
    except Exception as e:
        print(f"Error loading dataset: {str(e)}")
        # Fallback: read the parquet file directly. split="train" is required here because
        # loading raw parquet files yields a DatasetDict with a single "train" split.
        try:
            dataset = load_dataset(
                "parquet",
                data_files="https://huggingface.co/datasets/GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction/resolve/main/data/test-00000-of-00001.parquet",
                split="train",
            )
            return dataset
        except Exception as e2:
            print(f"Error loading with explicit path: {str(e2)}")
            raise
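# Because load_data() is cached with lru_cache(maxsize=1), the dataset is fetched once per
# process; clicking "Refresh Leaderboard" recomputes metrics on the cached copy.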
def preprocess_text(text):
    """Lowercase, strip punctuation, and collapse whitespace before WER scoring."""
    if not text or not isinstance(text, str):
        return ""
    text = text.lower()
    text = re.sub(r'[^\w\s]', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text
def calculate_simple_wer(reference, hypothesis):
    """Calculate WER for a single reference/hypothesis pair.

    WER is the word-level edit distance divided by the number of reference words.
    """
    if not reference or not hypothesis:
        return 1.0

    ref_words = reference.split()
    hyp_words = hypothesis.split()

    try:
        # Preferred path: word-level Levenshtein distance via the optional `editdistance` package.
        import editdistance
        distance = editdistance.eval(ref_words, hyp_words)
    except ImportError:
        # Fallback: let jiwer compute WER directly on the preprocessed strings.
        try:
            return jiwer.wer(reference, hypothesis)
        except Exception:
            print("Error calculating WER - fallback to maximum error")
            return 1.0

    # Defensive guard; preprocess_text() should never hand an empty reference to this point.
    if len(ref_words) == 0:
        return 1.0
    return float(distance) / float(len(ref_words))
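# Quick sanity check for calculate_simple_wer: reference "the cat sat" vs. hypothesis
# "the cat sat down" is one insertion over three reference words, so WER = 1/3 ≈ 0.3333.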
def calculate_wer(examples):
    """Average per-pair WER over a batch of examples (a datasets.Dataset or a list of dicts)."""
    if not examples:
        return 0.0

    try:
        # A datasets.Dataset exposes `.features`; a plain list of dicts does not.
        is_dataset = hasattr(examples, 'features')

        if len(examples) > 0:
            example = examples[0]
        else:
            print("No examples found")
            return np.nan

        print("\n===== EXAMPLE DATA INSPECTION =====")
        print(f"Keys in example: {example.keys()}")

        # Log which of the candidate reference/hypothesis fields are present.
        possible_reference_fields = ["transcription", "reference", "ground_truth", "target"]
        possible_hypothesis_fields = ["input1", "hypothesis", "asr_output", "source_text"]

        for field in possible_reference_fields:
            if field in example:
                print(f"Reference field '{field}' found with value: {str(example[field])[:100]}...")

        for field in possible_hypothesis_fields:
            if field in example:
                print(f"Hypothesis field '{field}' found with value: {str(example[field])[:100]}...")

        wer_values = []
        valid_count = 0
        skipped_count = 0

        # Cap the work per call at 200 examples to keep refreshes fast.
        if is_dataset:
            items_to_process = examples.select(range(min(200, len(examples))))
        else:
            items_to_process = examples[:200]
        for i, ex in enumerate(items_to_process):
            try:
                # The reference is `transcription`; the hypothesis is `input1`, falling back
                # to the `hypothesis` field (first entry if it is a list).
                transcription = ex.get("transcription")

                input1 = ex.get("input1")
                if input1 is None and "hypothesis" in ex and ex["hypothesis"]:
                    if isinstance(ex["hypothesis"], list) and len(ex["hypothesis"]) > 0:
                        input1 = ex["hypothesis"][0]
                    elif isinstance(ex["hypothesis"], str):
                        input1 = ex["hypothesis"]

                if i < 3:
                    print(f"\nExample {i} inspection:")
                    print(f"  transcription: {transcription}")
                    print(f"  input1: {input1}")
                    print(f"  type checks: transcription={type(transcription)}, input1={type(input1)}")

                if transcription is None or input1 is None:
                    skipped_count += 1
                    if i < 3:
                        print(f"  SKIPPED: Missing field (transcription={transcription is None}, input1={input1 is None})")
                    continue

                reference = preprocess_text(transcription)
                hypothesis = preprocess_text(input1)

                if not reference or not hypothesis:
                    skipped_count += 1
                    if i < 3:
                        print(f"  SKIPPED: Empty after preprocessing (reference='{reference}', hypothesis='{hypothesis}')")
                    continue

                pair_wer = calculate_simple_wer(reference, hypothesis)
                wer_values.append(pair_wer)
                valid_count += 1

                if i < 3:
                    print(f"  VALID PAIR: reference='{reference}', hypothesis='{hypothesis}', WER={pair_wer:.4f}")

            except Exception as ex_error:
                print(f"Error processing example {i}: {str(ex_error)}")
                skipped_count += 1
                continue

        print(f"\nProcessing summary: Valid pairs: {valid_count}, Skipped: {skipped_count}")

        if not wer_values:
            print("No valid pairs found for WER calculation")
            return np.nan

        avg_wer = np.mean(wer_values)
        print(f"Calculated {len(wer_values)} pairs with average WER: {avg_wer:.4f}")
        return avg_wer

    except Exception as e:
        print(f"Error in calculate_wer: {str(e)}")
        print(traceback.format_exc())
        return np.nan
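# Note: the 200-example cap per source above, and the 500-example sample used for the
# OVERALL row below, keep refreshes quick; raising those limits would give a fuller
# (but slower) estimate of WER.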
def get_wer_metrics(dataset):
    """Compute per-source WER results plus an OVERALL row, excluding the all_et05_real source."""
    try:
        print("\n===== DATASET INFO =====")
        print(f"Dataset size: {len(dataset)}")
        print(f"Dataset features: {dataset.features}")

        # Group examples by their "source" field.
        examples_by_source = {}

        for i, ex in enumerate(dataset):
            try:
                source = ex.get("source", "unknown")

                if source == "all_et05_real":
                    continue

                examples_by_source.setdefault(source, []).append(ex)
            except Exception as e:
                print(f"Error processing example {i}: {str(e)}")
                continue

        all_sources = sorted(examples_by_source.keys())
        print(f"Found sources: {all_sources}")

        # Per-source WER.
        source_results = {}
        for source in all_sources:
            try:
                examples = examples_by_source.get(source, [])
                count = len(examples)

                if count > 0:
                    print(f"\nCalculating WER for source {source} with {count} examples")
                    wer = calculate_wer(examples)
                else:
                    wer = np.nan

                source_results[source] = {
                    "Count": count,
                    "No LM Baseline": wer
                }
            except Exception as e:
                print(f"Error processing source {source}: {str(e)}")
                source_results[source] = {
                    "Count": 0,
                    "No LM Baseline": np.nan
                }
        # Overall row: recompute on a sample drawn from the whole dataset (minus all_et05_real).
        filtered_dataset = [ex for ex in dataset if ex.get("source") != "all_et05_real"]
        total_count = len(filtered_dataset)
        try:
            print("\nCalculating overall WER with a sample of examples (excluding all_et05_real)")

            sample_size = min(500, total_count)
            sample_dataset = filtered_dataset[:sample_size]
            overall_wer = calculate_wer(sample_dataset)

            source_results["OVERALL"] = {
                "Count": total_count,
                "No LM Baseline": overall_wer
            }
        except Exception as e:
            print(f"Error calculating overall metrics: {str(e)}")
            print(traceback.format_exc())
            # total_count is computed before the try block, so this fallback is always well-defined.
            source_results["OVERALL"] = {
                "Count": total_count,
                "No LM Baseline": np.nan
            }

        # Assemble the leaderboard table: one column per source plus OVERALL.
        metrics = ["Count", "No LM Baseline"]
        result_df = pd.DataFrame(index=metrics, columns=["Metric"] + all_sources + ["OVERALL"])
        result_df["Metric"] = ["Number of Examples", "Word Error Rate (WER)"]

        for source in all_sources + ["OVERALL"]:
            for metric in metrics:
                result_df.loc[metric, source] = source_results[source][metric]

        result_df = result_df.set_index("Metric")

        return result_df

    except Exception as e:
        print(f"Error in get_wer_metrics: {str(e)}")
        print(traceback.format_exc())
        return pd.DataFrame([{"Error": str(e)}])
def format_dataframe(df):
    """Format the WER row as 4-decimal strings (or "N/A") for display."""
    try:
        df = df.copy()

        # Find the row that holds WER values.
        wer_row_index = None
        for idx in df.index:
            if isinstance(idx, str) and ("WER" in idx or "Error Rate" in idx):
                wer_row_index = idx
                break

        if wer_row_index:
            # Cast the row to object so formatted strings can be assigned in place.
            df.loc[wer_row_index] = df.loc[wer_row_index].astype(object)

            for col in df.columns:
                value = df.loc[wer_row_index, col]
                if pd.notna(value):
                    df.loc[wer_row_index, col] = f"{value:.4f}"
                else:
                    df.loc[wer_row_index, col] = "N/A"

        return df

    except Exception as e:
        print(f"Error in format_dataframe: {str(e)}")
        print(traceback.format_exc())
        return pd.DataFrame([{"Error": str(e)}])
def create_leaderboard():
    """Load the dataset, compute per-source WER, and format the table for display."""
    try:
        dataset = load_data()
        metrics_df = get_wer_metrics(dataset)
        return format_dataframe(metrics_df)
    except Exception as e:
        error_msg = f"Error creating leaderboard: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return pd.DataFrame([{"Error": error_msg}])
with gr.Blocks(title="ASR Text Correction Test Leaderboard") as demo:
    gr.Markdown("# ASR Text Correction Baseline WER Leaderboard (Test Data)")
    gr.Markdown("Word Error Rate (WER) metrics for each speech source, using the no-language-model baseline.")

    with gr.Row():
        refresh_btn = gr.Button("Refresh Leaderboard")

    with gr.Row():
        error_output = gr.Textbox(label="Debug Information", visible=True, lines=10)

    with gr.Row():
        try:
            initial_df = create_leaderboard()
            leaderboard = gr.DataFrame(initial_df)
        except Exception as e:
            error_msg = f"Error initializing leaderboard: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            # A component's value can't be set via .update() at build time, so the error is
            # surfaced in the DataFrame itself (and printed to the console) instead.
            leaderboard = gr.DataFrame(pd.DataFrame([{"Error": error_msg}]))

    def refresh_and_report():
        try:
            df = create_leaderboard()
            debug_info = "Leaderboard refreshed successfully. Check console for detailed debug information."
            return df, debug_info
        except Exception as e:
            error_msg = f"Error refreshing leaderboard: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            return pd.DataFrame([{"Error": error_msg}]), error_msg

    refresh_btn.click(refresh_and_report, outputs=[leaderboard, error_output])


if __name__ == "__main__":
    demo.launch()