refines
app.py CHANGED
@@ -6,6 +6,8 @@ from functools import lru_cache
 import re
 from collections import Counter
 import editdistance
+import json
+import os
 
 # Cache the dataset loading to avoid reloading on refresh
 @lru_cache(maxsize=1)
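lru_cache(maxsize=1) on a zero-argument function memoizes its single result, so the parquet file is fetched once per process and every UI refresh reuses it. A minimal sketch of that behavior (the function and return value here are illustrative, not from app.py):

    from functools import lru_cache

    @lru_cache(maxsize=1)
    def load_once():
        print("loading...")        # printed only on the first call
        return {"rows": 9000}      # stands in for the Hugging Face dataset

    load_once()   # prints "loading..."
    load_once()   # served from the cache; nothing is reloaded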
@@ -18,6 +20,37 @@ def load_data():
     return load_dataset("parquet",
                         data_files="https://huggingface.co/datasets/GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction/resolve/main/data/test-00000-of-00001.parquet")
 
+# Storage for user-submitted methods (in-memory for demo purposes)
+user_methods = []
+
+# Data file for persistence
+USER_DATA_FILE = "user_methods.json"
+
+# Load user methods from file if it exists
+def load_user_methods():
+    global user_methods
+    if os.path.exists(USER_DATA_FILE):
+        try:
+            with open(USER_DATA_FILE, 'r') as f:
+                user_methods = json.load(f)
+        except Exception as e:
+            print(f"Error loading user methods: {e}")
+            user_methods = []
+
+# Save user methods to file
+def save_user_methods():
+    try:
+        with open(USER_DATA_FILE, 'w') as f:
+            json.dump(user_methods, f)
+    except Exception as e:
+        print(f"Error saving user methods: {e}")
+
+# Try to load user methods at startup
+try:
+    load_user_methods()
+except Exception:
+    pass
+
 # Preprocess text for better WER calculation
 def preprocess_text(text):
     if not text or not isinstance(text, str):
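The submit handler in the last hunk calls add_user_method, which is not part of this diff. Given the persistence helpers above and how get_wer_metrics consumes user_methods, it presumably looks roughly like the sketch below; the exact validation and blank-field handling are assumptions:

    # Hypothetical sketch -- add_user_method is referenced by this commit but not shown in it.
    def add_user_method(name, wer_values):
        try:
            entry = {"name": name}
            for source, value in wer_values.items():
                value = (value or "").strip()
                if value:                    # skip fields the user left blank
                    entry[source] = float(value)
            user_methods.append(entry)
            save_user_methods()              # persist to user_methods.json
            return True
        except Exception as e:
            print(f"Error adding method: {e}")
            return False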
@@ -353,12 +386,19 @@ def get_wer_metrics(dataset):
     rows.append(nb_oracle_row)
     rows.append(cp_oracle_row)
 
+    # Add user-submitted methods
+    for user_method in user_methods:
+        user_row = {"Methods": user_method["name"]}
+        for source in all_sources + ["OVERALL"]:
+            user_row[source] = user_method.get(source, np.nan)
+        rows.append(user_row)
+
     # Create DataFrame from rows
     result_df = pd.DataFrame(rows)
 
-    return result_df
+    return result_df, all_sources
 
-# Format the dataframe for display
+# Format the dataframe for display, and sort by performance
 def format_dataframe(df):
     df = df.copy()
 
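A stored entry only needs a "name" plus whatever per-source WERs were submitted; any missing source falls back to np.nan. For example (the source names here are illustrative):

    user_methods = [{"name": "my-corrector", "CHiME4": 5.6, "OVERALL": 7.2}]
    # yields the row:
    # {"Methods": "my-corrector", "CHiME4": 5.6, <other sources>: nan, "OVERALL": 7.2}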
@@ -378,13 +418,35 @@ def format_dataframe(df):
         else:
             df.loc[idx, col] = "N/A"
 
-
+    # Extract the examples row
+    examples_row = df[df["Methods"] == "Number of Examples"]
+
+    # Get the performance rows (as a copy, so adding the helper column below
+    # does not trigger pandas' SettingWithCopyWarning)
+    performance_rows = df[df["Methods"] != "Number of Examples"].copy()
+
+    # Convert the OVERALL column to numeric for sorting
+    # First, replace 'N/A' with a high sentinel (worse than any real WER)
+    performance_rows["numeric_overall"] = performance_rows["OVERALL"].replace("N/A", "999")
+
+    # Convert to float for sorting
+    performance_rows["numeric_overall"] = performance_rows["numeric_overall"].astype(float)
+
+    # Sort by performance (ascending - lower WER is better)
+    sorted_performance = performance_rows.sort_values(by="numeric_overall")
+
+    # Drop the helper column used for sorting
+    sorted_performance = sorted_performance.drop(columns=["numeric_overall"])
+
+    # Keep the examples row on top, followed by the sorted performance rows
+    result = pd.concat([examples_row, sorted_performance], ignore_index=True)
+
+    return result
 
 # Main function to create the leaderboard
 def create_leaderboard():
     dataset = load_data()
-    metrics_df = get_wer_metrics(dataset)
-    return format_dataframe(metrics_df)
+    metrics_df, all_sources = get_wer_metrics(dataset)
+    return format_dataframe(metrics_df), all_sources
 
 # Create the Gradio interface
 with gr.Blocks(title="ASR Text Correction Leaderboard") as demo:
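The "999" sentinel is safe here because real WER percentages stay far below it. An equivalent, sentinel-free variant (a sketch, not what this commit does) sorts on a coerced key and pushes N/A rows to the bottom:

    import pandas as pd

    key = pd.to_numeric(performance_rows["OVERALL"], errors="coerce")   # "N/A" -> NaN
    sorted_performance = performance_rows.loc[key.sort_values(na_position="last").index]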
@@ -399,13 +461,56 @@ with gr.Blocks(title="ASR Text Correction Leaderboard") as demo:
 
     with gr.Row():
         try:
-            initial_df = create_leaderboard()
+            initial_df, all_sources = create_leaderboard()
             leaderboard = gr.DataFrame(initial_df)
         except Exception:
             leaderboard = gr.DataFrame(pd.DataFrame([{"Error": "Error initializing leaderboard"}]))
+            all_sources = []
+
+    gr.Markdown("### Submit Your Method")
+    gr.Markdown("Enter WER values as percentages (e.g., 5.6 for 5.6% WER)")
+
+    with gr.Row():
+        method_name = gr.Textbox(label="Method Name", placeholder="Enter your method name")
+
+    # Create input fields for each source, split across two columns
+    source_inputs = {}
+    with gr.Row():
+        with gr.Column():
+            for i, source in enumerate(all_sources):
+                if i < len(all_sources) // 2:
+                    source_inputs[source] = gr.Textbox(label=f"WER for {source}", placeholder="e.g., 5.6")
+
+        with gr.Column():
+            for i, source in enumerate(all_sources):
+                if i >= len(all_sources) // 2:
+                    source_inputs[source] = gr.Textbox(label=f"WER for {source}", placeholder="e.g., 5.6")
+
+    with gr.Row():
+        submit_btn = gr.Button("Submit Results")
+        status = gr.Textbox(label="Status")
+
+    # Gradio passes one positional argument per input component
+    def submit_method(name, *values):
+        if not name:
+            return "Please enter a method name", leaderboard
+
+        success = add_user_method(name, dict(zip(source_inputs.keys(), values)))
+        if success:
+            updated_df, _ = create_leaderboard()
+            return "Method added successfully!", updated_df
+        else:
+            return "Error adding method", leaderboard
 
     def refresh_and_report():
-        return create_leaderboard()
+        updated_df, _ = create_leaderboard()
+        return updated_df
+
+    # Connect buttons to functions
+    submit_btn.click(
+        submit_method,
+        inputs=[method_name] + list(source_inputs.values()),
+        outputs=[status, leaderboard]
+    )
 
     refresh_btn.click(refresh_and_report, outputs=[leaderboard])
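Gradio calls the handler with the value of each component in inputs as a separate positional argument, in order; that is why submit_method takes *values rather than a **values keyword form, which Gradio never supplies. A stripped-down sketch of the same wiring (component names here are illustrative):

    import gradio as gr

    def handler(name, *values):        # one positional value per input component
        return f"{name}: {values}"

    with gr.Blocks() as demo:
        name = gr.Textbox(label="name")
        boxes = [gr.Textbox(label=f"wer{i}") for i in range(3)]
        out = gr.Textbox(label="result")
        gr.Button("go").click(handler, inputs=[name] + boxes, outputs=[out])

    demo.launch()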