import gradio as gr
import pandas as pd
from datasets import load_dataset
import jiwer
import numpy as np
from functools import lru_cache
import traceback
import re
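
# Note: calculate_simple_wer() below imports the optional `editdistance` package
# lazily and falls back to jiwer.wer() if it is not installed.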


# Cache the dataset loading to avoid reloading on refresh
@lru_cache(maxsize=1)
def load_data():
    try:
        # Load only the test dataset by specifying the split
        dataset = load_dataset("GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction", split="test")
        return dataset
    except Exception as e:
        print(f"Error loading dataset: {str(e)}")
        # Fall back to loading the parquet file directly if the default loading fails
        try:
            # split="train" is the default split name for raw data_files; requesting it
            # keeps the return type a Dataset rather than a DatasetDict
            dataset = load_dataset(
                "parquet",
                data_files="https://huggingface.co/datasets/GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction/resolve/main/data/test-00000-of-00001.parquet",
                split="train",
            )
            return dataset
        except Exception as e2:
            print(f"Error loading with explicit path: {str(e2)}")
            raise


# Preprocess text for better WER calculation
def preprocess_text(text):
    if not text or not isinstance(text, str):
        return ""
    # Convert to lowercase
    text = text.lower()
    # Remove punctuation
    text = re.sub(r'[^\w\s]', '', text)
    # Collapse extra whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    return text
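
# Illustrative example: preprocess_text("Hello,   World!") returns "hello world"
# (lowercased, punctuation removed, whitespace collapsed).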


# Word-level WER helper that avoids depending on jiwer internals
def calculate_simple_wer(reference, hypothesis):
    """Calculate WER using a simple word-based approach."""
    if not reference or not hypothesis:
        return 1.0  # Maximum error if either string is empty
    # Split into words
    ref_words = reference.split()
    hyp_words = hypothesis.split()
    # Prefer the editdistance package over jiwer internals
    try:
        import editdistance
        distance = editdistance.eval(ref_words, hyp_words)
    except ImportError:
        # Fall back to the standard jiwer implementation
        try:
            return jiwer.wer(reference, hypothesis)
        except Exception:
            # If all else fails, return 1.0 (maximum error)
            print("Error calculating WER - falling back to maximum error")
            return 1.0
    # WER = word-level edit distance divided by reference length
    if len(ref_words) == 0:
        return 1.0
    return float(distance) / float(len(ref_words))
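
# Illustrative example: calculate_simple_wer("the cat sat", "the cat sit") is one
# substitution over a three-word reference, so it returns 1/3 (about 0.3333).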


# Calculate WER for a group of examples
def calculate_wer(examples):
    if not examples:
        return 0.0
    try:
        # Check whether examples is a Dataset or a plain list
        is_dataset = hasattr(examples, 'features')
        # Get the first example for inspection
        if len(examples) > 0:
            example = examples[0]
        else:
            print("No examples found")
            return np.nan

        print("\n===== EXAMPLE DATA INSPECTION =====")
        print(f"Keys in example: {example.keys()}")

        # Try different possible field names
        possible_reference_fields = ["transcription", "reference", "ground_truth", "target"]
        possible_hypothesis_fields = ["input1", "hypothesis", "asr_output", "source_text"]
        for field in possible_reference_fields:
            if field in example:
                print(f"Reference field '{field}' found with value: {str(example[field])[:100]}...")
        for field in possible_hypothesis_fields:
            if field in example:
                print(f"Hypothesis field '{field}' found with value: {str(example[field])[:100]}...")

        # Process each example in the group
        wer_values = []
        valid_count = 0
        skipped_count = 0

        # Limit to the first 200 examples for efficiency
        if is_dataset:
            items_to_process = examples.select(range(min(200, len(examples))))
        else:
            items_to_process = examples[:200]

        for i, ex in enumerate(items_to_process):
            try:
                # Reference transcription
                transcription = ex.get("transcription")
                # Hypothesis: prefer input1, otherwise the first entry of "hypothesis"
                input1 = ex.get("input1")
                if input1 is None and "hypothesis" in ex and ex["hypothesis"]:
                    if isinstance(ex["hypothesis"], list) and len(ex["hypothesis"]) > 0:
                        input1 = ex["hypothesis"][0]
                    elif isinstance(ex["hypothesis"], str):
                        input1 = ex["hypothesis"]

                # Print debug info for a few examples
                if i < 3:
                    print(f"\nExample {i} inspection:")
                    print(f" transcription: {transcription}")
                    print(f" input1: {input1}")
                    print(f" type checks: transcription={type(transcription)}, input1={type(input1)}")

                # Skip if either field is missing
                if transcription is None or input1 is None:
                    skipped_count += 1
                    if i < 3:
                        print(f" SKIPPED: Missing field (transcription={transcription is None}, input1={input1 is None})")
                    continue

                # Skip if either field is empty after preprocessing
                reference = preprocess_text(transcription)
                hypothesis = preprocess_text(input1)
                if not reference or not hypothesis:
                    skipped_count += 1
                    if i < 3:
                        print(f" SKIPPED: Empty after preprocessing (reference='{reference}', hypothesis='{hypothesis}')")
                    continue

                # Calculate WER for this pair
                pair_wer = calculate_simple_wer(reference, hypothesis)
                wer_values.append(pair_wer)
                valid_count += 1
                if i < 3:
                    print(f" VALID PAIR: reference='{reference}', hypothesis='{hypothesis}', WER={pair_wer:.4f}")
            except Exception as ex_error:
                print(f"Error processing example {i}: {str(ex_error)}")
                skipped_count += 1
                continue

        # Average WER over all valid pairs
        print(f"\nProcessing summary: Valid pairs: {valid_count}, Skipped: {skipped_count}")
        if not wer_values:
            print("No valid pairs found for WER calculation")
            return np.nan
        avg_wer = np.mean(wer_values)
        print(f"Calculated {len(wer_values)} pairs with average WER: {avg_wer:.4f}")
        return avg_wer
    except Exception as e:
        print(f"Error in calculate_wer: {str(e)}")
        print(traceback.format_exc())
        return np.nan


# Get WER metrics grouped by source
def get_wer_metrics(dataset):
    try:
        # Print dataset info
        print("\n===== DATASET INFO =====")
        print(f"Dataset size: {len(dataset)}")
        print(f"Dataset features: {dataset.features}")

        # Group examples by source
        examples_by_source = {}
        for i, ex in enumerate(dataset):
            try:
                source = ex.get("source", "unknown")
                # Skip all_et05_real as requested
                if source == "all_et05_real":
                    continue
                if source not in examples_by_source:
                    examples_by_source[source] = []
                examples_by_source[source].append(ex)
            except Exception as e:
                print(f"Error processing example {i}: {str(e)}")
                continue

        # Get all unique sources
        all_sources = sorted(examples_by_source.keys())
        print(f"Found sources: {all_sources}")

        # Calculate metrics for each source
        source_results = {}
        for source in all_sources:
            try:
                examples = examples_by_source.get(source, [])
                count = len(examples)
                if count > 0:
                    print(f"\nCalculating WER for source {source} with {count} examples")
                    wer = calculate_wer(examples)  # Handles both lists and datasets
                else:
                    wer = np.nan
                source_results[source] = {
                    "Count": count,
                    "No LM Baseline": wer
                }
            except Exception as e:
                print(f"Error processing source {source}: {str(e)}")
                source_results[source] = {
                    "Count": 0,
                    "No LM Baseline": np.nan
                }

        # Calculate overall metrics on a sample, excluding all_et05_real
        try:
            filtered_dataset = [ex for ex in dataset if ex.get("source") != "all_et05_real"]
            total_count = len(filtered_dataset)
            print("\nCalculating overall WER with a sample of examples (excluding all_et05_real)")
            sample_size = min(500, total_count)
            sample_dataset = filtered_dataset[:sample_size]
            overall_wer = calculate_wer(sample_dataset)
            source_results["OVERALL"] = {
                "Count": total_count,
                "No LM Baseline": overall_wer
            }
        except Exception as e:
            print(f"Error calculating overall metrics: {str(e)}")
            print(traceback.format_exc())
            source_results["OVERALL"] = {
                # filtered_dataset may not exist if the error occurred while building it
                "Count": 0,
                "No LM Baseline": np.nan
            }

        # Create a transposed DataFrame with metrics as rows and sources as columns
        metrics = ["Count", "No LM Baseline"]
        result_df = pd.DataFrame(index=metrics, columns=["Metric"] + all_sources + ["OVERALL"])
        # Add a descriptive label column, then use it as the index for display
        result_df["Metric"] = ["Number of Examples", "Word Error Rate (WER)"]
        for source in all_sources + ["OVERALL"]:
            for metric in metrics:
                result_df.loc[metric, source] = source_results[source][metric]
        result_df = result_df.set_index("Metric")
        return result_df
    except Exception as e:
        print(f"Error in get_wer_metrics: {str(e)}")
        print(traceback.format_exc())
        return pd.DataFrame([{"Error": str(e)}])
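
# The returned frame has one column per source plus "OVERALL", with rows
# "Number of Examples" and "Word Error Rate (WER)". Purely illustrative layout
# (source names and values are made up):
#
# Metric                  <source_a>  <source_b>  OVERALL
# Number of Examples             100         250      350
# Word Error Rate (WER)       0.1234      0.2345   0.2020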


# Format the dataframe for display
def format_dataframe(df):
    try:
        df = df.copy()
        # Find the row containing WER values by its index label
        wer_row_index = None
        for idx in df.index:
            if isinstance(idx, str) and ("WER" in idx or "Error Rate" in idx):
                wer_row_index = idx
                break
        if wer_row_index:
            # Convert to object dtype first so string assignment does not raise warnings
            df.loc[wer_row_index] = df.loc[wer_row_index].astype(object)
            for col in df.columns:
                value = df.loc[wer_row_index, col]
                if pd.notna(value):
                    df.loc[wer_row_index, col] = f"{value:.4f}"
                else:
                    df.loc[wer_row_index, col] = "N/A"
        return df
    except Exception as e:
        print(f"Error in format_dataframe: {str(e)}")
        print(traceback.format_exc())
        return pd.DataFrame([{"Error": str(e)}])
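
# Illustrative example: after format_dataframe, a WER cell of 0.123456 is shown as
# the string "0.1235" and a missing value is shown as "N/A"; count cells are untouched.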


# Main function to create the leaderboard
def create_leaderboard():
    try:
        dataset = load_data()
        metrics_df = get_wer_metrics(dataset)
        return format_dataframe(metrics_df)
    except Exception as e:
        error_msg = f"Error creating leaderboard: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return pd.DataFrame([{"Error": error_msg}])


# Create the Gradio interface
with gr.Blocks(title="ASR Text Correction Test Leaderboard") as demo:
    gr.Markdown("# ASR Text Correction Baseline WER Leaderboard (Test Data)")
    gr.Markdown("Word Error Rate (WER) metrics for different speech sources with a No Language Model baseline")

    with gr.Row():
        refresh_btn = gr.Button("Refresh Leaderboard")

    with gr.Row():
        error_output = gr.Textbox(label="Debug Information", visible=True, lines=10)

    with gr.Row():
        try:
            initial_df = create_leaderboard()
            leaderboard = gr.DataFrame(initial_df)
        except Exception as e:
            error_msg = f"Error initializing leaderboard: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            # Set the initial value directly; calling .update() on a component
            # has no effect (and is unavailable) at Blocks build time
            error_output.value = error_msg
            leaderboard = gr.DataFrame(pd.DataFrame([{"Error": error_msg}]))

    def refresh_and_report():
        try:
            df = create_leaderboard()
            debug_info = "Leaderboard refreshed successfully. Check console for detailed debug information."
            return df, debug_info
        except Exception as e:
            error_msg = f"Error refreshing leaderboard: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            return pd.DataFrame([{"Error": error_msg}]), error_msg

    refresh_btn.click(refresh_and_report, outputs=[leaderboard, error_output])

if __name__ == "__main__":
    demo.launch()