import gradio as gr
import pandas as pd
from datasets import load_dataset
import numpy as np
from functools import lru_cache
import re
from collections import Counter
import editdistance
import json
import os


@lru_cache(maxsize=1)
def load_data():
    """Load the post-ASR correction test split once and cache it."""
    try:
        dataset = load_dataset("GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction", split="test")
        return dataset
    except Exception:
        # Fall back to reading the parquet shard directly; parquet files load
        # under the default "train" split, so select it so a Dataset (not a
        # DatasetDict) is returned, matching what the rest of the code expects.
        return load_dataset(
            "parquet",
            data_files="https://huggingface.co/datasets/GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction/resolve/main/data/test-00000-of-00001.parquet",
            split="train",
        )


user_methods = []

USER_DATA_FILE = "user_methods.json"

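# Each entry in user_methods mirrors what add_user_method() stores: a method
# name plus WER values as fractions, e.g. (illustrative values only):
#
#     {"name": "MyCorrector", "test_chime4": 0.056, "test_cv": 0.102, "OVERALL": 0.079}
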
def load_user_methods():
    """Populate user_methods from USER_DATA_FILE if it exists."""
    global user_methods
    if os.path.exists(USER_DATA_FILE):
        try:
            with open(USER_DATA_FILE, 'r') as f:
                user_methods = json.load(f)
        except Exception as e:
            print(f"Error loading user methods: {e}")
            user_methods = []


def save_user_methods():
    """Write the current user_methods list to USER_DATA_FILE."""
    try:
        with open(USER_DATA_FILE, 'w') as f:
            json.dump(user_methods, f)
    except Exception as e:
        print(f"Error saving user methods: {e}")


# Load any previously submitted methods at startup.
try:
    load_user_methods()
except Exception:
    pass


def preprocess_text(text):
    """Lowercase, strip punctuation, and collapse whitespace for WER scoring."""
    if not text or not isinstance(text, str):
        return ""
    text = text.lower()
    text = re.sub(r'[^\w\s]', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text

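# Illustrative sanity check (not used by the app):
#   preprocess_text("  Hello, World!! ") -> "hello world"
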
def score_hypothesis(hypothesis, n=4):
    """Heuristic fluency score: word count plus a bonus for n-gram diversity.

    No language model is involved; longer hypotheses with fewer repeated
    n-grams score higher.
    """
    if not hypothesis:
        return 0

    words = hypothesis.split()
    if len(words) < n:
        return len(words)

    ngrams = []
    for i in range(len(words) - n + 1):
        ngram = ' '.join(words[i:i+n])
        ngrams.append(ngram)

    unique_ngrams = len(set(ngrams))
    total_ngrams = len(ngrams)
    score = len(words) + unique_ngrams / max(1, total_ngrams) * 5
    return score

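# Worked example (illustrative): score_hypothesis("a b c d e") with n=4 sees
# the 4-grams "a b c d" and "b c d e", both unique, so the score is
# 5 + (2/2) * 5 = 10.
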
def get_best_hypothesis_lm(hypotheses):
    """Pick the hypothesis with the highest score_hypothesis() value.

    Despite the name, this is n-gram-diversity ranking, not actual LM rescoring.
    """
    if not hypotheses:
        return ""

    if isinstance(hypotheses, str):
        return hypotheses

    hypothesis_list = [preprocess_text(h) for h in hypotheses if isinstance(h, str)]
    if not hypothesis_list:
        return ""

    scores = [(score_hypothesis(h), h) for h in hypothesis_list]
    best_hypothesis = max(scores, key=lambda x: x[0])[1]
    return best_hypothesis


def correct_hypotheses(hypotheses):
    """Combine N-best hypotheses by per-position majority voting.

    Only hypotheses with the most common word count take part in the vote.
    """
    if not hypotheses:
        return ""

    if isinstance(hypotheses, str):
        return hypotheses

    hypothesis_list = [preprocess_text(h) for h in hypotheses if isinstance(h, str)]
    if not hypothesis_list:
        return ""

    word_lists = [h.split() for h in hypothesis_list]
    lengths = [len(words) for words in word_lists]
    if not lengths:
        return ""

    most_common_length = Counter(lengths).most_common(1)[0][0]
    filtered_word_lists = [words for words in word_lists if len(words) == most_common_length]
    if not filtered_word_lists:
        return max(hypothesis_list, key=len)

    corrected_words = []
    for i in range(most_common_length):
        position_words = [words[i] for words in filtered_word_lists]
        most_common_word = Counter(position_words).most_common(1)[0][0]
        corrected_words.append(most_common_word)

    return ' '.join(corrected_words)

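# Worked example (illustrative): correct_hypotheses(
#     ["the cat sat", "the bat sat", "the cat sat"])
# keeps all three (same length) and per-position voting yields "the cat sat".
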
def calculate_simple_wer(reference, hypothesis):
    """Word error rate: word-level edit distance divided by reference length."""
    if not reference or not hypothesis:
        return 1.0

    ref_words = reference.split()
    hyp_words = hypothesis.split()

    distance = editdistance.eval(ref_words, hyp_words)

    if len(ref_words) == 0:
        return 1.0
    return float(distance) / float(len(ref_words))

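# Worked example (illustrative): calculate_simple_wer("the cat sat",
# "the cat sat down") needs one insertion, so WER = 1 / 3 ≈ 0.333.
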
def calculate_wer_methods(examples, max_samples=200):
    """Compute mean WER for the three correction strategies over a sample."""
    if not examples or len(examples) == 0:
        return np.nan, np.nan, np.nan

    # Works for both HF Dataset objects (select) and plain lists (slicing).
    if hasattr(examples, 'select'):
        items_to_process = examples.select(range(min(max_samples, len(examples))))
    else:
        items_to_process = examples[:max_samples]

    wer_values_no_lm = []
    wer_values_lm_ranking = []
    wer_values_n_best_correction = []

    for ex in items_to_process:
        transcription = ex.get("transcription")
        if not transcription or not isinstance(transcription, str):
            continue

        reference = preprocess_text(transcription)
        if not reference:
            continue

        # 1-best hypothesis: prefer "input1", otherwise the first N-best entry.
        input1 = ex.get("input1")
        if input1 is None and "hypothesis" in ex and ex["hypothesis"]:
            if isinstance(ex["hypothesis"], list) and len(ex["hypothesis"]) > 0:
                input1 = ex["hypothesis"][0]
            elif isinstance(ex["hypothesis"], str):
                input1 = ex["hypothesis"]

        n_best_hypotheses = ex.get("hypothesis", [])

        # Baseline: score the 1-best hypothesis as-is.
        if input1 and isinstance(input1, str):
            no_lm_hyp = preprocess_text(input1)
            if no_lm_hyp:
                wer_no_lm = calculate_simple_wer(reference, no_lm_hyp)
                wer_values_no_lm.append(wer_no_lm)

        # N-gram ranking over the N-best list.
        if n_best_hypotheses:
            lm_best_hyp = get_best_hypothesis_lm(n_best_hypotheses)
            if lm_best_hyp:
                wer_lm = calculate_simple_wer(reference, lm_best_hyp)
                wer_values_lm_ranking.append(wer_lm)

        # Positional majority voting over the N-best list.
        if n_best_hypotheses:
            corrected_hyp = correct_hypotheses(n_best_hypotheses)
            if corrected_hyp:
                wer_corrected = calculate_simple_wer(reference, corrected_hyp)
                wer_values_n_best_correction.append(wer_corrected)

    no_lm_wer = np.mean(wer_values_no_lm) if wer_values_no_lm else np.nan
    lm_ranking_wer = np.mean(wer_values_lm_ranking) if wer_values_lm_ranking else np.nan
    n_best_correction_wer = np.mean(wer_values_n_best_correction) if wer_values_n_best_correction else np.nan

    return no_lm_wer, lm_ranking_wer, n_best_correction_wer

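# Minimal usage sketch (assumes the dataset loaded with the expected columns):
#   ds = load_data()
#   no_lm, ranked, voted = calculate_wer_methods(ds, max_samples=50)
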
def get_wer_metrics(dataset):
    """Build the per-source WER table for all methods."""
    # Group examples by source, skipping the held-out all_et05_real subset.
    examples_by_source = {}
    for ex in dataset:
        source = ex.get("source", "unknown")
        if source == "all_et05_real":
            continue
        if source not in examples_by_source:
            examples_by_source[source] = []
        examples_by_source[source].append(ex)

    all_sources = sorted(examples_by_source.keys())

    # Per-source WER for the three built-in methods.
    source_results = {}
    for source in all_sources:
        examples = examples_by_source.get(source, [])
        count = len(examples)

        if count > 0:
            no_lm_wer, lm_ranking_wer, n_best_wer = calculate_wer_methods(examples)
        else:
            no_lm_wer, lm_ranking_wer, n_best_wer = np.nan, np.nan, np.nan

        source_results[source] = {
            "Count": count,
            "No LM Baseline": no_lm_wer,
            "N-best LM Ranking": lm_ranking_wer,
            "N-best Correction": n_best_wer
        }

    # Overall WER on a capped sample of the filtered dataset.
    filtered_dataset = [ex for ex in dataset if ex.get("source") != "all_et05_real"]
    total_count = len(filtered_dataset)

    sample_size = min(500, total_count)
    sample_dataset = filtered_dataset[:sample_size]
    no_lm_wer, lm_ranking_wer, n_best_wer = calculate_wer_methods(sample_dataset)

    source_results["OVERALL"] = {
        "Count": total_count,
        "No LM Baseline": no_lm_wer,
        "N-best LM Ranking": lm_ranking_wer,
        "N-best Correction": n_best_wer
    }

    # Assemble leaderboard rows: one row per method, one column per source.
    rows = []

    example_row = {"Methods": "Number of Examples"}
    for source in all_sources + ["OVERALL"]:
        example_row[source] = source_results[source]["Count"]
    rows.append(example_row)

    # Rows computed live from the dataset.
    no_lm_row = {"Methods": "No LM"}
    lm_ranking_row = {"Methods": "N-gram Ranking"}
    n_best_row = {"Methods": "Subwords Voting"}

    # Reference rows filled from previously reported results (hard-coded below).
    llama_lora_row = {"Methods": "LLaMA-7B-LoRA"}
    nb_oracle_row = {"Methods": "N-best Oracle (o_nb)"}
    cp_oracle_row = {"Methods": "Compositional Oracle (o_cp)"}

    for source in all_sources + ["OVERALL"]:
        no_lm_row[source] = source_results[source]["No LM Baseline"]
        lm_ranking_row[source] = source_results[source]["N-best LM Ranking"]
        n_best_row[source] = source_results[source]["N-best Correction"]

        # Reference rows default to NaN and are overwritten per source below.
        llama_lora_row[source] = np.nan
        nb_oracle_row[source] = np.nan
        cp_oracle_row[source] = np.nan

if "test_chime4" in all_sources: |
|
llama_lora_row["test_chime4"] = 6.6 / 100 |
|
nb_oracle_row["test_chime4"] = 9.1 / 100 |
|
cp_oracle_row["test_chime4"] = 2.8 / 100 |
|
|
|
|
|
if "test_td3" in all_sources: |
|
llama_lora_row["test_td3"] = 4.6 / 100 |
|
nb_oracle_row["test_td3"] = 3.0 / 100 |
|
cp_oracle_row["test_td3"] = 0.7 / 100 |
|
|
|
|
|
if "test_cv" in all_sources: |
|
llama_lora_row["test_cv"] = 11.0 / 100 |
|
nb_oracle_row["test_cv"] = 11.4 / 100 |
|
cp_oracle_row["test_cv"] = 7.9 / 100 |
|
|
|
|
|
if "test_swbd" in all_sources: |
|
llama_lora_row["test_swbd"] = 14.1 / 100 |
|
nb_oracle_row["test_swbd"] = 12.6 / 100 |
|
cp_oracle_row["test_swbd"] = 4.2 / 100 |
|
|
|
|
|
if "test_lrs2" in all_sources: |
|
llama_lora_row["test_lrs2"] = 8.8 / 100 |
|
nb_oracle_row["test_lrs2"] = 6.9 / 100 |
|
cp_oracle_row["test_lrs2"] = 2.6 / 100 |
|
|
|
|
|
if "test_coraal" in all_sources: |
|
llama_lora_row["test_coraal"] = 19.2 / 100 |
|
nb_oracle_row["test_coraal"] = 21.8 / 100 |
|
cp_oracle_row["test_coraal"] = 10.7 / 100 |
|
|
|
|
|
if "test_ls_clean" in all_sources: |
|
llama_lora_row["test_ls_clean"] = 1.7 / 100 |
|
nb_oracle_row["test_ls_clean"] = 1.0 / 100 |
|
cp_oracle_row["test_ls_clean"] = 0.6 / 100 |
|
|
|
|
|
if "test_ls_other" in all_sources: |
|
llama_lora_row["test_ls_other"] = 3.8 / 100 |
|
nb_oracle_row["test_ls_other"] = 2.7 / 100 |
|
cp_oracle_row["test_ls_other"] = 1.6 / 100 |
|
|
|
|
|
    # Average the reference rows over whichever sources are present; fall back
    # to fixed overall values if none are.
    llama_values = []
    nb_oracle_values = []
    cp_oracle_values = []

    for source in all_sources:
        if pd.notna(llama_lora_row[source]):
            llama_values.append(llama_lora_row[source])
        if pd.notna(nb_oracle_row[source]):
            nb_oracle_values.append(nb_oracle_row[source])
        if pd.notna(cp_oracle_row[source]):
            cp_oracle_values.append(cp_oracle_row[source])

    print(f"LLaMA values: {llama_values}")
    print(f"N-best Oracle values: {nb_oracle_values}")
    print(f"Compositional Oracle values: {cp_oracle_values}")

    llama_lora_row["OVERALL"] = np.mean(llama_values) if llama_values else 0.0873
    nb_oracle_row["OVERALL"] = np.mean(nb_oracle_values) if nb_oracle_values else 0.0856
    cp_oracle_row["OVERALL"] = np.mean(cp_oracle_values) if cp_oracle_values else 0.0389

    rows.append(no_lm_row)
    rows.append(lm_ranking_row)
    rows.append(n_best_row)
    rows.append(llama_lora_row)
    rows.append(nb_oracle_row)
    rows.append(cp_oracle_row)

    for user_method in user_methods:
        user_row = {"Methods": user_method["name"]}
        for source in all_sources + ["OVERALL"]:
            user_row[source] = user_method.get(source, np.nan)
        rows.append(user_row)

    result_df = pd.DataFrame(rows)

    return result_df, all_sources

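# get_wer_metrics() returns a DataFrame with one row per method and one column
# per source, e.g. columns like ["Methods", "test_chime4", ..., "OVERALL"]
# (the exact names depend on which sources appear in the data).
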
def format_dataframe(df):
    """Format WER values for display and sort methods by overall WER."""
    df = df.copy()

    # Format every WER cell to four decimals; leave the example-count row alone.
    wer_row_indices = []
    for i, method in enumerate(df["Methods"]):
        if method not in ["Number of Examples"]:
            wer_row_indices.append(i)

    for idx in wer_row_indices:
        for col in df.columns:
            if col != "Methods":
                value = df.loc[idx, col]
                if pd.notna(value):
                    df.loc[idx, col] = f"{value:.4f}"
                else:
                    df.loc[idx, col] = "N/A"

    examples_row = df[df["Methods"] == "Number of Examples"]
    performance_rows = df[df["Methods"] != "Number of Examples"].copy()

    # Sort methods by overall WER; "N/A" sorts last via a large sentinel.
    performance_rows["numeric_overall"] = performance_rows["OVERALL"].replace("N/A", "999")
    performance_rows["numeric_overall"] = performance_rows["numeric_overall"].astype(float)
    sorted_performance = performance_rows.sort_values(by="numeric_overall")
    sorted_performance = sorted_performance.drop(columns=["numeric_overall"])

    result = pd.concat([examples_row, sorted_performance], ignore_index=True)

    return result

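# Illustrative effect of format_dataframe: a raw value of 0.0856 is rendered as
# "0.0856", missing values as "N/A", and method rows are ordered best-to-worst
# by the "OVERALL" column, with "Number of Examples" kept as the first row.
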
def create_leaderboard():
    dataset = load_data()
    metrics_df, all_sources = get_wer_metrics(dataset)
    return format_dataframe(metrics_df), all_sources


with gr.Blocks(title="ASR Text Correction Leaderboard") as demo:
    gr.Markdown("# HyPoradise N-Best WER Leaderboard (Test Data)")
    gr.Markdown("Word Error Rate (WER) metrics for different speech sources with multiple correction approaches")

    with gr.Row():
        refresh_btn = gr.Button("Refresh Leaderboard")

    with gr.Row():
        gr.Markdown("### Word Error Rate (WER)")

    with gr.Row():
        try:
            initial_df, all_sources = create_leaderboard()
            leaderboard = gr.DataFrame(initial_df)
        except Exception as e:
            leaderboard = gr.DataFrame(pd.DataFrame([{"Error": f"Error initializing leaderboard: {e}"}]))
            all_sources = []

    gr.Markdown("### Submit Your Method")
    gr.Markdown("Enter WER values as percentages (e.g., 5.6 for 5.6% WER)")

    with gr.Row():
        method_name = gr.Textbox(label="Method Name", placeholder="Enter your method name")

    # One textbox per source, split across two columns.
    source_inputs = {}
    with gr.Row():
        with gr.Column():
            for i, source in enumerate(all_sources):
                if i < len(all_sources) // 2:
                    source_inputs[source] = gr.Textbox(label=f"WER for {source}", placeholder="e.g., 5.6")

        with gr.Column():
            for i, source in enumerate(all_sources):
                if i >= len(all_sources) // 2:
                    source_inputs[source] = gr.Textbox(label=f"WER for {source}", placeholder="e.g., 5.6")

    with gr.Row():
        submit_btn = gr.Button("Submit Results")

    status_box = gr.Textbox(label="Status")

    def submit_method(name, *args):
        if not name:
            return "Please enter a method name", gr.update()

        # Map the positional textbox values back to their sources.
        values = {}
        for i, source in enumerate(all_sources):
            if i < len(args):
                values[source] = args[i]

        success = add_user_method(name, **values)
        if success:
            updated_df, _ = create_leaderboard()
            return "Method added successfully!", updated_df
        else:
            return "Error adding method", gr.update()

    def refresh_and_report():
        updated_df, _ = create_leaderboard()
        return updated_df

    submit_btn.click(
        submit_method,
        inputs=[method_name] + list(source_inputs.values()),
        outputs=[status_box, leaderboard],
    )

    refresh_btn.click(refresh_and_report, outputs=[leaderboard])


def add_user_method(name, **values):
    """Store a user-submitted method and persist it to disk."""
    method = {"name": name}

    for source, value in values.items():
        if value and value.strip():
            try:
                float_value = float(value)
                # Inputs are percentages; store as fractions. Values <= 1 are
                # assumed to already be fractions.
                method[source] = float_value / 100 if float_value > 1 else float_value
            except ValueError:
                # Skip entries that are not valid numbers.
                continue

    # Overall WER is the mean of whatever per-source values were provided.
    if len(method) > 1:
        values_list = [v for k, v in method.items() if k != "name" and isinstance(v, (int, float))]
        if values_list:
            method["OVERALL"] = np.mean(values_list)

    user_methods.append(method)
    save_user_methods()

    return True

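# Illustrative call (hypothetical method name): add_user_method("MyCorrector",
# test_chime4="5.6") stores {"name": "MyCorrector", "test_chime4": 0.056,
# "OVERALL": 0.056} and persists it to user_methods.json.
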
if __name__ == "__main__": |
|
demo.launch() |