huckiyang committed on
Commit
0035b5e
·
1 Parent(s): 6a06457
Files changed (1) hide show
  1. app.py +112 -7
app.py CHANGED
@@ -6,6 +6,8 @@ from functools import lru_cache
6
  import re
7
  from collections import Counter
8
  import editdistance
 
 
9
 
10
  # Cache the dataset loading to avoid reloading on refresh
11
  @lru_cache(maxsize=1)
@@ -18,6 +20,37 @@ def load_data():
18
  return load_dataset("parquet",
19
  data_files="https://huggingface.co/datasets/GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction/resolve/main/data/test-00000-of-00001.parquet")
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  # Preprocess text for better WER calculation
22
  def preprocess_text(text):
23
  if not text or not isinstance(text, str):
@@ -353,12 +386,19 @@ def get_wer_metrics(dataset):
353
  rows.append(nb_oracle_row)
354
  rows.append(cp_oracle_row)
355
 
 
 
 
 
 
 
 
356
  # Create DataFrame from rows
357
  result_df = pd.DataFrame(rows)
358
 
359
- return result_df
360
 
361
- # Format the dataframe for display
362
  def format_dataframe(df):
363
  df = df.copy()
364
 
@@ -378,13 +418,35 @@ def format_dataframe(df):
378
  else:
379
  df.loc[idx, col] = "N/A"
380
 
381
- return df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
 
383
  # Main function to create the leaderboard
384
  def create_leaderboard():
385
  dataset = load_data()
386
- metrics_df = get_wer_metrics(dataset)
387
- return format_dataframe(metrics_df)
388
 
389
  # Create the Gradio interface
390
  with gr.Blocks(title="ASR Text Correction Leaderboard") as demo:
@@ -399,13 +461,56 @@ with gr.Blocks(title="ASR Text Correction Leaderboard") as demo:
399
 
400
  with gr.Row():
401
  try:
402
- initial_df = create_leaderboard()
403
  leaderboard = gr.DataFrame(initial_df)
404
  except Exception:
405
  leaderboard = gr.DataFrame(pd.DataFrame([{"Error": "Error initializing leaderboard"}]))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
406
 
407
  def refresh_and_report():
408
- return create_leaderboard()
 
 
 
 
 
 
 
 
 
409
 
410
  refresh_btn.click(refresh_and_report, outputs=[leaderboard])
411
 
 
6
  import re
7
  from collections import Counter
8
  import editdistance
9
+ import json
10
+ import os
11
 
12
  # Cache the dataset loading to avoid reloading on refresh
13
  @lru_cache(maxsize=1)
 
20
  return load_dataset("parquet",
21
  data_files="https://huggingface.co/datasets/GenSEC-LLM/SLT-Task1-Post-ASR-Text-Correction/resolve/main/data/test-00000-of-00001.parquet")
22
 
23
+ # Storage for user-submitted methods (in-memory for demo purposes)
24
+ user_methods = []
25
+
26
+ # Data file for persistence
27
+ USER_DATA_FILE = "user_methods.json"
28
+
29
+ # Load user methods from file if exists
30
+ def load_user_methods():
31
+ global user_methods
32
+ if os.path.exists(USER_DATA_FILE):
33
+ try:
34
+ with open(USER_DATA_FILE, 'r') as f:
35
+ user_methods = json.load(f)
36
+ except Exception as e:
37
+ print(f"Error loading user methods: {e}")
38
+ user_methods = []
39
+
40
+ # Save user methods to file
41
+ def save_user_methods():
42
+ try:
43
+ with open(USER_DATA_FILE, 'w') as f:
44
+ json.dump(user_methods, f)
45
+ except Exception as e:
46
+ print(f"Error saving user methods: {e}")
47
+
48
+ # Try to load user methods at startup
49
+ try:
50
+ load_user_methods()
51
+ except:
52
+ pass
53
+
54
  # Preprocess text for better WER calculation
55
  def preprocess_text(text):
56
  if not text or not isinstance(text, str):
 
386
  rows.append(nb_oracle_row)
387
  rows.append(cp_oracle_row)
388
 
389
+ # Add user-submitted methods
390
+ for user_method in user_methods:
391
+ user_row = {"Methods": user_method["name"]}
392
+ for source in all_sources + ["OVERALL"]:
393
+ user_row[source] = user_method.get(source, np.nan)
394
+ rows.append(user_row)
395
+
396
  # Create DataFrame from rows
397
  result_df = pd.DataFrame(rows)
398
 
399
+ return result_df, all_sources
400
 
401
+ # Format the dataframe for display, and sort by performance
402
  def format_dataframe(df):
403
  df = df.copy()
404
 
 
418
  else:
419
  df.loc[idx, col] = "N/A"
420
 
421
+ # Extract the examples row
422
+ examples_row = df[df["Methods"] == "Number of Examples"]
423
+
424
+ # Get the performance rows
425
+ performance_rows = df[df["Methods"] != "Number of Examples"]
426
+
427
+ # Convert the OVERALL column to numeric for sorting
428
+ # First, replace 'N/A' with a high value (worse than any real WER)
429
+ performance_rows["numeric_overall"] = performance_rows["OVERALL"].replace("N/A", "999")
430
+
431
+ # Convert to float for sorting
432
+ performance_rows["numeric_overall"] = performance_rows["numeric_overall"].astype(float)
433
+
434
+ # Sort by performance (ascending - lower WER is better)
435
+ sorted_performance = performance_rows.sort_values(by="numeric_overall")
436
+
437
+ # Drop the numeric column used for sorting
438
+ sorted_performance = sorted_performance.drop(columns=["numeric_overall"])
439
+
440
+ # Combine the examples row with the sorted performance rows
441
+ result = pd.concat([examples_row, sorted_performance], ignore_index=True)
442
+
443
+ return result
444
 
445
  # Main function to create the leaderboard
446
  def create_leaderboard():
447
  dataset = load_data()
448
+ metrics_df, all_sources = get_wer_metrics(dataset)
449
+ return format_dataframe(metrics_df), all_sources
450
 
451
  # Create the Gradio interface
452
  with gr.Blocks(title="ASR Text Correction Leaderboard") as demo:
 
461
 
462
  with gr.Row():
463
  try:
464
+ initial_df, all_sources = create_leaderboard()
465
  leaderboard = gr.DataFrame(initial_df)
466
  except Exception:
467
  leaderboard = gr.DataFrame(pd.DataFrame([{"Error": "Error initializing leaderboard"}]))
468
+ all_sources = []
469
+
470
+ gr.Markdown("### Submit Your Method")
471
+ gr.Markdown("Enter WER values as percentages (e.g., 5.6 for 5.6% WER)")
472
+
473
+ with gr.Row():
474
+ method_name = gr.Textbox(label="Method Name", placeholder="Enter your method name")
475
+
476
+ # Create input fields for each source
477
+ source_inputs = {}
478
+ with gr.Row():
479
+ with gr.Column():
480
+ for i, source in enumerate(all_sources):
481
+ if i < len(all_sources) // 2:
482
+ source_inputs[source] = gr.Textbox(label=f"WER for {source}", placeholder="e.g., 5.6")
483
+
484
+ with gr.Column():
485
+ for i, source in enumerate(all_sources):
486
+ if i >= len(all_sources) // 2:
487
+ source_inputs[source] = gr.Textbox(label=f"WER for {source}", placeholder="e.g., 5.6")
488
+
489
+ with gr.Row():
490
+ submit_btn = gr.Button("Submit Results")
491
+
492
+ def submit_method(name, **values):
493
+ if not name:
494
+ return "Please enter a method name", leaderboard
495
+
496
+ success = add_user_method(name, values)
497
+ if success:
498
+ updated_df, _ = create_leaderboard()
499
+ return "Method added successfully!", updated_df
500
+ else:
501
+ return "Error adding method", leaderboard
502
 
503
  def refresh_and_report():
504
+ updated_df, _ = create_leaderboard()
505
+ return updated_df
506
+
507
+ # Connect buttons to functions
508
+ submit_args = [method_name] + list(source_inputs.values())
509
+ submit_btn.click(
510
+ submit_method,
511
+ inputs=[method_name] + list(source_inputs.values()),
512
+ outputs=[gr.Textbox(label="Status"), leaderboard]
513
+ )
514
 
515
  refresh_btn.click(refresh_and_report, outputs=[leaderboard])
516